To implement callbacks in this training library, we have to worry about two major pieces:
Taking callbacks that were provided to the run_training API and passing them to the torchrun subprocesses which run the actual training loop
Receiving the callbacks inside of the training loop (child process of torchrun) and executing them at the appropriate points without interfering with the training process
Eventually we will also want to implement callbacks that can modify the actions performed by the training loop itself; however, that will require additional effort to properly synchronize the global process group. Such callbacks are out of scope for this issue.
The callback API
We can define a general callback as a generic function which accepts a general context object along with arbitrary keyword arguments:
def my_callback(context, **kwargs) -> None:
    # do stuff
    return
Callbacks may also have arguments for certain events, e.g. on_save may have the following interface:
def on_save(context, checkpoint_path: str, **kwargs) -> None:
    # do stuff
    return
While the on_evaluate callback would look like this:
def on_evaluate(context, validation_loss: float, **kwargs) -> None:
    # do stuff
    return
We expect these callbacks to be entirely self-contained, so that any and all packages are imported inside of the callback function and must be present within whatever venv the training loop runs in.
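As an illustration, a fully self-contained on_save-style callback might look like this. The manifest-writing behavior is purely hypothetical; the point is that all imports happen inside the function body:

```python
def record_checkpoint(context, checkpoint_path: str, **kwargs) -> None:
    # All imports live inside the function body, so the callback still
    # works after its source is serialized and re-created inside the
    # torchrun child process.
    import json
    import os

    # Hypothetical behavior: append each checkpoint path to a JSONL
    # manifest next to the checkpoint directory.
    manifest = os.path.join(os.path.dirname(checkpoint_path), "manifest.jsonl")
    with open(manifest, "a") as f:
        f.write(json.dumps({"checkpoint": checkpoint_path}) + "\n")
```

Because nothing is referenced from the enclosing module, the function behaves identically whether it runs in the parent process or in a torchrun worker.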
Additionally, since we expect the initial callbacks to be read-only and inconsequential to the training loop, the training loop will wrap them as asynchronous tasks which run on the global event loop. For this reason, callbacks should be careful with their usage of async functionality.
Warning
To prevent unexpected behavior and complex testing scenarios, callbacks will not be allowed to propagate exceptions into the training loop. Any exception must be handled by the callback itself; otherwise it will fail silently.
Adding callbacks to TrainingArgs
To maintain a simple interface, the TrainingArgs Pydantic model (in src/instructlab/training/config.py) would expose registration for each callback with a flat interface like this:
from instructlab.training.config import TrainingArgs

args = TrainingArgs(
    on_save=evaluate_gsm8k,
    on_evaluate=save_if_improved,
    ...
)

# then we call training
run_training(args, torch_args)
The API for these event hooks should also allow providing multiple callbacks per event.
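A minimal sketch of the normalization this implies, where each hook field accepts either a single callable or a list. The helper name `normalize_callbacks` is hypothetical; in practice this would likely live in a Pydantic validator on TrainingArgs:

```python
from typing import Callable, Optional, Union


def normalize_callbacks(value: Optional[Union[Callable, list]]) -> list:
    """Accept None, a single callable, or a list of callables,
    and always return a list (possibly empty)."""
    if value is None:
        return []
    if callable(value):
        return [value]
    return list(value)
```

With this in place, `on_save=evaluate_gsm8k` and `on_save=[evaluate_gsm8k, another_cb]` both register cleanly.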
Passing callbacks to torchrun procs
In order to execute callbacks during training, they must first be passed from the main process, where run_training is called, to the child processes spawned to run the training loop in parallel. When run_training is first called, it uses the subprocess.Popen API (via StreamablePopen from utils.py) to invoke torchrun, which then creates a worker for each GPU. The training loop itself expects to be invoked as a CLI, so run_training maps all of the fields in TrainingArgs to the CLI flags expected by the actual training script. You can see this in action here: https://github.com/instructlab/training/blob/main/src/instructlab/training/main_ds.py#L624-L801
We can therefore provide these callbacks to the training script by first serializing them with base64, passing them as string arguments to the CLI, then deserializing and exec-ing each callback to register it in the child's memory space. As long as a callback is self-contained, it should execute on the child process just as it would in the main process.
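A rough sketch of that round-trip, assuming source capture via inspect.getsource plus base64. The function names here are hypothetical, and dill/cloudpickle would be an alternative serialization mechanism:

```python
import base64
import inspect
from typing import Callable


def serialize_callback(fn: Callable) -> str:
    # Capture the callback's full source text and base64-encode it so
    # it can travel as a plain CLI string argument to the torchrun
    # subprocess.
    return base64.b64encode(inspect.getsource(fn).encode()).decode()


def deserialize_callback(payload: str, name: str) -> Callable:
    # Runs in the child process: decode and exec the source (exec, not
    # eval, since a def is a statement), then pull the function object
    # out of the resulting namespace.
    namespace: dict = {}
    exec(base64.b64decode(payload).decode(), namespace)
    return namespace[name]
```

Note this only works for self-contained callbacks: inspect.getsource captures the function body but none of its closure or module-level state.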
Registering and executing callbacks
Once deserialized, each callback will be registered by a callback manager on the global main/rank-0 process only.
This avoids race conditions and a flood of evaluation requests from every rank after each checkpoint save. Per-rank callback execution can be added later, but we should avoid it to start so the initial implementation can be straightforward.
In order to manage callbacks, we should have a class like CallbacksManager which registers the callbacks and is responsible for providing them with the relevant context when they're called.
The callback manager would have an on_<event> method for each of the following events:
on_train_begin: Called after model.train() and state initialization, before the epoch loop (~line 200 in main_ds.py:train())
on_epoch_begin: Called at the top of the epoch for loop, after sampler.set_epoch(epoch) (~line 208)
on_step_begin: Called at the beginning of each training step, after start = time.time() (~line 221)
on_before_forward: Called inside BatchLossManager.process_batch(), before model.compute_loss() (in batch_loss_manager.py)
on_after_forward: Called inside BatchLossManager.process_batch(), after model.compute_loss() returns (scaled_loss, raw_losses), before backward
on_before_backward: Called inside BatchLossManager.process_batch(), before accelerator.backward(scaled_loss)
on_after_backward: Called inside BatchLossManager.process_batch(), after accelerator.backward(scaled_loss)
on_before_optimizer_step: Called before accelerator.take_optimizer_step() (~line 236), which internally does clip_grad_norm_, optimizer.step(), lr_scheduler.step(), optimizer.zero_grad()
on_after_optimizer_step: Called after accelerator.take_optimizer_step() (~line 236)
on_log: Called around metric_logger.info() (~lines 258-279), after metrics dict is built with elapsed_time, overall_throughput, current_lr, cuda_mem_allocated, global_grad_norm, etc.
on_evaluate: Called after compute_validation_loss() (~line 295), which returns {"val_loss": float, "val_num_tokens": int}
on_save: Called after save_checkpoint() and dist.barrier() (~lines 297-309)
on_step_end: Called after global_step += 1 and torch.cuda.empty_cache() (~line 314)
on_epoch_end: Called after the inner step loop exits, before/after epoch checkpoint (~line 315)
on_train_end: Called after the epoch loop exits, before final save (~line 331)
Internally, the callbacks should be treated as asynchronous coroutines. When the manager's event hook is triggered, it would create a task for each of the registered callbacks and schedule them for execution on the global event loop. There should be no expectation of reading the return values of the callbacks or that they pass values from one to another.
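A sketch of what such a manager could look like; everything beyond the CallbacksManager name (method names, logging behavior) is an assumption. It registers callbacks per event, schedules each as a fire-and-forget task on the event loop, and swallows exceptions as required by the warning above:

```python
import asyncio
import logging
from typing import Any, Callable

logger = logging.getLogger(__name__)


class CallbacksManager:
    def __init__(self) -> None:
        # event name -> list of registered callbacks
        self._callbacks: dict[str, list[Callable]] = {}

    def register(self, event: str, callback: Callable) -> None:
        self._callbacks.setdefault(event, []).append(callback)

    async def _run_one(self, cb: Callable, context: Any, **kwargs) -> None:
        # Exceptions must never propagate into the training loop; they
        # are logged and otherwise fail silently.
        try:
            result = cb(context, **kwargs)
            if asyncio.iscoroutine(result):
                await result
        except Exception:
            logger.exception("callback %r failed", cb)

    def fire(self, event: str, context: Any, **kwargs) -> list[asyncio.Task]:
        # Schedule every callback registered for this event as a task
        # on the event loop; the training loop never awaits them or
        # reads their return values.
        loop = asyncio.get_event_loop()
        return [
            loop.create_task(self._run_one(cb, context, **kwargs))
            for cb in self._callbacks.get(event, [])
        ]
```

The train() loop would then just call, e.g., `manager.fire("on_save", context, checkpoint_path=path)` at the corresponding hook point and continue immediately.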
Context available to callbacks
The train() function has several key variables that should be exposed to callbacks via a context object:
global_step
epoch
samples_seen
avg_loss_across_ranks
batch_metrics: a BatchMetrics dataclass with total_samples, total_length, num_loss_counted_tokens, accumulated_loss, accumulated_aux_loss, grad_accum_steps, num_minibatches
val_metrics: val_loss and val_num_tokens (when validation is enabled)
elapsed_time
overall_throughput
current_lr
global_grad_norm
cuda_mem_allocated
args
Key files involved
src/instructlab/training/config.py: config models (TrainingArgs, TorchrunArgs, etc.)
src/instructlab/training/main_ds.py: run_training(), main() init, train() loop
src/instructlab/training/batch_loss_manager.py: BatchLossManager, the forward/backward loop over minibatches and loss reduction. Forward/backward hooks go here.
src/instructlab/training/model.py: Model wrapper, CausalLMModel, LigerModel, compute_loss(), setup_optimizer()
src/instructlab/training/accelerator.py: Accelerator wrapper around HF Accelerate, FSDP/DeepSpeed config, take_optimizer_step()
src/instructlab/training/utils.py: save_checkpoint(), save_hf_format_accelerate(), load_latest_full_state(), StreamablePopen
Architectural notes
The train() function is the core loop in main_ds.py (~lines 170-340). It receives model, accelerator, train/val loaders, and args. The CallbacksManager would be passed as an additional parameter.
The run_training() API launches a subprocess (main_ds.py, lines 571-843) via StreamablePopen. TrainingArgs fields are mapped 1:1 to CLI flags and parsed back via argparse in the subprocess. Callbacks must be serialized (e.g., base64-encoded source via inspect.getsource or dill/cloudpickle) to cross this process boundary.
Forward/backward passes are inside BatchLossManager (batch_loss_manager.py), not inline in train(). The process_batch() method loops over minibatches, calls model.compute_loss(), then accelerator.backward(). Hooks around forward/backward need to be injected into this class.
Optimizer step is inside Accelerator (accelerator.py). The take_optimizer_step() method bundles clip_grad_norm_, optimizer.step(), lr_scheduler.step(), and optimizer.zero_grad().
Validation is a separate function, compute_validation_loss() (~lines 87-167 in main_ds.py). It switches to model.eval(), runs a forward-only loop over the val loader, does all_reduce for loss/tokens across ranks, then restores model.train().
Logging uses Python's logging module with custom handlers (JSONL async writer, TensorBoard, W&B, MLflow) configured at setup time. The callback system should coexist with this, not replace it.
The loop is distributed (FSDP or DeepSpeed via HF Accelerate). Callbacks should fire on rank-0 only by default to avoid race conditions. The context object should expose is_main_process for callbacks that need to check.
Overview
Goal: Add callbacks to instructlab-training.
Why: We want to let users hook their own logic into the training loop, for example evaluating a benchmark after each checkpoint save, or saving only when validation loss improves.
Expected usage
Users calling the run_training API would provide their callbacks directly on the TrainingArgs object through a flat API, as in the TrainingArgs example above.