Add loss_fn parameterization to forward_backward #924
tyler-griggs wants to merge 4 commits into NovaSky-AI:main from
Conversation
Code Review
This pull request adds a `loss_fn` parameter to `forward_backward` for Tinker API compatibility and refactors weight synchronization into a new `save_weights_for_sampler` method. The changes are generally positive: they simplify the API and improve test coverage. However, I've identified a critical issue in `trainer.py` where the critic training path appears to be broken due to incorrect arguments being passed to `_execute_training_step`. There is also an inconsistency in a new test file regarding a test utility function call.
```python
with Timer("critic_train", self.all_timings):
    critic_status = self._execute_training_step("critic", data, "critic")
with Timer("policy_train", self.all_timings):
    policy_status = self._execute_training_step("policy", data, "policy")
```
There appears to be a critical issue with the arguments passed to _execute_training_step for both the critic and policy models. The function signature for _execute_training_step is (self, model: str, data: TrainingInputBatch), but it's being called with three arguments here (e.g., self._execute_training_step("critic", data, "critic")). This will result in a TypeError at runtime.
While the intent seems to be to pass a loss_fn, the implementation appears incomplete. Specifically:
- The signature of `_execute_training_step` hasn't been updated to accept a third argument.
- Even if it were updated, the critic training path would likely fail: the `loss_fn` would be `"critic"`, which is not handled by `PolicyWorkerBase._get_loss_fn`, and `CriticWorkerBase` doesn't have a comparable method to handle a parameterized loss function.
To fix this, you'll need to update the signature of _execute_training_step and ensure that both policy and critic workers can correctly handle the new loss_fn parameter. For the critic, you might want to pass None as the loss_fn if it's not meant to be parameterized, and handle that case in _execute_training_step.
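A minimal sketch of the suggested fix, assuming an optional third parameter that defaults to `None` for the critic path; the worker class below is a stand-in for illustration, not SkyRL's actual API:

```python
from typing import Optional


class _FakeWorker:
    """Minimal stand-in for PolicyWorkerBase / CriticWorkerBase."""

    def forward_backward(self, data, loss_fn: Optional[str] = None):
        # Record which loss was requested and how much data arrived.
        return {"loss_fn": loss_fn, "n": len(data)}


class Trainer:
    def __init__(self):
        self.workers = {"policy": _FakeWorker(), "critic": _FakeWorker()}

    def _execute_training_step(self, model: str, data, loss_fn: Optional[str] = None):
        # New optional third parameter: callers that don't parameterize the
        # loss (the critic path) simply omit it.
        worker = self.workers[model]
        if loss_fn is None:
            return worker.forward_backward(data)
        return worker.forward_backward(data, loss_fn=loss_fn)


trainer = Trainer()
policy_status = trainer._execute_training_step("policy", [1, 2, 3], loss_fn="ppo")
critic_status = trainer._execute_training_step("critic", [1, 2, 3])
```

This keeps a single dispatch path while letting the critic fall back to its default loss.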
```python
# === Step 1: Do a training step ===
dp_size = policy_group.actor_infos[0].rank.dp_size
dummy_batch = make_dummy_training_batch(batch_size=dp_size)
```
The call to make_dummy_training_batch here and on line 190 seems inconsistent with changes in other test files. In other files like test_save_load_checkpoint.py and test_training_step.py, the batch_size argument was removed from this call (e.g., make_dummy_training_batch()).
If the signature of make_dummy_training_batch has changed, this could lead to test failures. For consistency across the test suite, please update this call to match the new pattern.
```diff
- dummy_batch = make_dummy_training_batch(batch_size=dp_size)
+ dummy_batch = make_dummy_training_batch()
```
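If the two call patterns are both meant to stay valid, one reconciliation is to give the helper a default `batch_size`. This is a hypothetical shape — the field names and defaults are illustrative, not the real SkyRL test utility:

```python
# Hypothetical version of the test helper with an optional batch_size, so both
# make_dummy_training_batch() and make_dummy_training_batch(batch_size=dp_size)
# remain valid call sites. Field names here are invented for illustration.
def make_dummy_training_batch(batch_size: int = 1, seq_len: int = 8):
    return {
        "input_ids": [[0] * seq_len for _ in range(batch_size)],
        "attention_mask": [[1] * seq_len for _ in range(batch_size)],
    }


default_batch = make_dummy_training_batch()
sized_batch = make_dummy_training_batch(batch_size=4)
```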
- Remove `ppo_train()` from PolicyWorkerBase and CriticWorkerBase
- Workers now use `forward_backward()` + `optim_step()` with gradient scaling
- Trainer branches on strategy: Megatron uses `ppo_train`, FSDP uses `forward_backward` + `optim_step`
- WorkerDispatch `forward_backward` no longer takes Tinker params (`loss_fn`, `loss_fn_config`)
- Update tests to use TrainingInputBatch and remove `ppo_train` tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…anges
- Megatron: Remove redundant `batch_to_experience` call (iterator already yields Experience)
- test_save_load_model.py: Use TrainingInputBatch, remove extra `forward_backward` arg
- test_worker_offload.py: Use TrainingInputBatch, remove extra `forward_backward` arg

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
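The gradient-scaling pattern these commits describe — accumulate per-micro-batch gradients in `forward_backward()`, then scale by the accumulated count at `optim_step()` — can be sketched numerically as follows. All names here are illustrative stand-ins, not SkyRL's worker API:

```python
# Toy 1-parameter worker: gradients from each micro batch accumulate unscaled,
# and optim_step() divides by the micro-batch count so the update matches a
# single forward_backward over the full batch.
class Worker:
    def __init__(self, lr: float = 0.1):
        self.param = 1.0
        self.grad = 0.0
        self.lr = lr
        self.accumulated_micro_batches = 0

    def forward_backward(self, micro_batch):
        # Gradient of 0.5 * (param - target)^2, averaged over the micro batch.
        grads = [self.param - target for target in micro_batch]
        self.grad += sum(grads) / len(grads)
        self.accumulated_micro_batches += 1

    def optim_step(self):
        # Scale by the number of accumulated micro batches, then apply and reset.
        scaled = self.grad / max(self.accumulated_micro_batches, 1)
        self.param -= self.lr * scaled
        self.grad = 0.0
        self.accumulated_micro_batches = 0


w = Worker()
for micro in ([0.0, 0.0], [2.0, 2.0]):
    w.forward_backward(micro)
w.optim_step()
```

The two micro batches pull in opposite directions with equal magnitude, so the scaled update is zero — the same result a single pass over the combined batch would give.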
7a0f4c3 to 70fb844
Summary
- Add `loss_fn` and `loss_fn_config` parameters to `forward_backward()` for Tinker API compatibility
- Remove `ppo_train()` from FSDP workers - uses gradient scaling at `optim_step` instead

Changes
WorkerDispatch (`worker_dispatch.py`):
- Add `loss_fn` and `loss_fn_config` parameters to `forward_backward()`

PolicyWorkerBase (`worker.py`):
- Add `convert_tinker_loss_config()` static method to convert Tinker's absolute ratio bounds to SkyRL's offset format
- Scale gradients at `optim_step` time based on accumulated micro batches
- Remove `ppo_train()` path for FSDP workers

Tests:
- Add `test_convert_tinker_loss_config` for Tinker config conversion
- Update tests for `pass_through` routing and positional batch parameters

Test Plan
- `test_normalize_mini_batch_size`, `test_convert_tinker_loss_config`
- `pytest tests/gpu/gpu_ci/test_training_step.py`

Stack
🤖 Generated with Claude Code
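As a closing illustration of the `convert_tinker_loss_config()` change described in the summary: a hedged sketch, assuming Tinker expresses PPO clip bounds as absolute importance-ratio limits (e.g. 0.8 and 1.2) while SkyRL expects offsets from 1.0 (e.g. eps of 0.2 on each side). The key names are invented for illustration, not the PR's actual schema:

```python
# Hypothetical conversion from absolute ratio bounds to offset form.
# Keys "ratio_low"/"ratio_high" and "clip_eps_low"/"clip_eps_high" are
# illustrative assumptions, not SkyRL's real config fields.
def convert_tinker_loss_config(tinker_config: dict) -> dict:
    ratio_low = tinker_config["ratio_low"]
    ratio_high = tinker_config["ratio_high"]
    return {
        "clip_eps_low": 1.0 - ratio_low,    # e.g. 0.8 -> 0.2
        "clip_eps_high": ratio_high - 1.0,  # e.g. 1.2 -> 0.2
    }


converted = convert_tinker_loss_config({"ratio_low": 0.8, "ratio_high": 1.2})
```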