diff --git a/skyrl-train/skyrl_train/distributed/fsdp_strategy.py b/skyrl-train/skyrl_train/distributed/fsdp_strategy.py
index f2e12a3b9..6f87507a5 100644
--- a/skyrl-train/skyrl_train/distributed/fsdp_strategy.py
+++ b/skyrl-train/skyrl_train/distributed/fsdp_strategy.py
@@ -39,7 +39,6 @@
     fsdp_version,
     fsdp2_load_full_state_dict,
 )
-from transformers.trainer import get_scheduler
 from packaging import version


@@ -186,8 +185,6 @@ def optimizer_step(
             return grad_norm

         optimizer.step()
-        if scheduler is not None:
-            scheduler.step()
         optimizer.zero_grad()
         return grad_norm

@@ -288,23 +285,15 @@ def _fsdp_init_train_model(self, model, optimizer, scheduler):
                 betas=optim_config.adam_betas,
                 weight_decay=optim_config.weight_decay,
             )
-
-            lr_scheduler = get_scheduler(
-                optim_config.scheduler,
-                new_optimizer,
-                num_warmup_steps=optim_config.num_warmup_steps,
-                num_training_steps=self.total_training_steps,
-            )
         else:
             new_optimizer = None
-            lr_scheduler = None

         if is_wrapped:
             model.model = fsdp_module
         else:
             model = fsdp_module
-
-        return model, new_optimizer, lr_scheduler
+        # backwards compatibility, return None for lr_scheduler (now controlled by Tinker)
+        return model, new_optimizer, None

     def _fsdp_init_eval_model(self, model):
         """Initialize a model for evaluation with FSDP"""
@@ -453,14 +442,9 @@ def save_checkpoint(
         with io.open_file(optim_path, "wb") as f:
             torch.save(optimizer_state_dict, f)

-        # Get scheduler state dict if scheduler is provided
-        lr_scheduler_state_dict = {}
-        if scheduler is not None:
-            lr_scheduler_state_dict = scheduler.state_dict()
-
         # Create extra state dict with client state and any additional info
         extra_state_dict = {
-            "lr_scheduler": lr_scheduler_state_dict,
+            "lr_scheduler": {},
             "client_state": client_state,
             "tag": tag,
             "fsdp_strategy": self.fsdp_strategy,
@@ -556,7 +540,7 @@ def load_checkpoint(
             optimizer_state_dict = torch.load(f, map_location="cpu", weights_only=False)

         # Extract scheduler state from extra state
-        lr_scheduler_state_dict = extra_state_dict.get("lr_scheduler", {})
+        # lr_scheduler_state_dict = extra_state_dict.get("lr_scheduler", {})

         # Set up state dict configurations for sharded loading
         state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
@@ -575,11 +559,6 @@ def load_checkpoint(
             optimizer.load_state_dict(optimizer_state_dict)
             self.print(f"[rank-{rank}]: Successfully loaded optimizer state")

-        # Load scheduler state dict if scheduler object is provided and loading is requested
-        if scheduler is not None and load_lr_scheduler_states:
-            scheduler.load_state_dict(lr_scheduler_state_dict)
-            self.print(f"[rank-{rank}]: Successfully loaded scheduler state")
-
         # Load RNG state for reproducibility
         if "rng" in extra_state_dict:
             self.load_rng_state(extra_state_dict["rng"])
diff --git a/skyrl-train/skyrl_train/distributed/megatron/megatron_strategy.py b/skyrl-train/skyrl_train/distributed/megatron/megatron_strategy.py
index 6884feb2c..3749d656c 100644
--- a/skyrl-train/skyrl_train/distributed/megatron/megatron_strategy.py
+++ b/skyrl-train/skyrl_train/distributed/megatron/megatron_strategy.py
@@ -120,13 +120,13 @@ def optimizer_step(
         self,
         optimizer: optim.Optimizer,
         model,
-        scheduler,
+        scheduler=None,
         name="model",
         **kwargs,
     ) -> Optional[Float[torch.Tensor, "1"]]:
         """Perform optimizer step"""
         _, grad_norm, _ = optimizer.step()
-        scheduler.step(1)
+        # Do not call scheduler.step() - Tinker controls LR
         optimizer.zero_grad()
         return grad_norm

@@ -158,15 +158,13 @@ def save_checkpoint(
         # All ranks wait for the checkpoint directory to be created before saving.
         dist.barrier()

-        # Collect the sharded state dicts for model and optimizer, and full state dict for the scheduler.
+        # Collect the sharded state dicts for model and optimizer.
         sharded_state_dict = {}
         model_sharded_state_dict = unwrapped_model.sharded_state_dict()
         if not self.is_lora:
             sharded_state_dict["model"] = model_sharded_state_dict
         if optimizer:
             sharded_state_dict["optimizer"] = optimizer.sharded_state_dict(model_sharded_state_dict)
-        if scheduler:
-            sharded_state_dict["lr_scheduler"] = scheduler.state_dict()

         # Save RNG state.
         sharded_state_dict["rng"] = self.get_rng_state()
@@ -257,8 +255,6 @@ def load_checkpoint(
             sharded_state_dict["model"] = model_sharded_state_dict
         if optimizer and load_optimizer_states:
             sharded_state_dict["optimizer"] = optimizer.sharded_state_dict(model_sharded_state_dict)
-        if scheduler and load_lr_scheduler_states:
-            sharded_state_dict["lr_scheduler"] = scheduler.state_dict()

         with io.local_read_dir(ckpt_dir) as read_dir:
             # Load the checkpoint in parallel.
@@ -286,13 +282,6 @@ def load_checkpoint(
             optimizer.load_state_dict(state_dict["optimizer"])
             self.print("Loaded optimizer state dict.")

-        if scheduler and load_lr_scheduler_states:
-            assert (
-                "lr_scheduler" in state_dict
-            ), f"LR scheduler state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}"
-            scheduler.load_state_dict(state_dict["lr_scheduler"])
-            self.print("Loaded LR scheduler state dict.")
-
         # Load RNG state, if present.
         if "rng" in state_dict:
             self.load_rng_state(state_dict["rng"])
diff --git a/skyrl-train/skyrl_train/distributed/strategy.py b/skyrl-train/skyrl_train/distributed/strategy.py
index acceccb45..ebf1463db 100644
--- a/skyrl-train/skyrl_train/distributed/strategy.py
+++ b/skyrl-train/skyrl_train/distributed/strategy.py
@@ -30,7 +30,7 @@ def optimizer_step(
         self,
         optimizer: optim.Optimizer,
         model,
-        scheduler,
+        scheduler=None,
         name="model",
         **kwargs,
     ) -> Optional[Float[torch.Tensor, "1"]]:
@@ -38,13 +38,20 @@
         pass

     @abstractmethod
-    def save_checkpoint(self, model, ckpt_dir, node_local_rank, optimizer, scheduler, tokenizer):
+    def save_checkpoint(self, model, ckpt_dir, node_local_rank, optimizer, scheduler=None, tokenizer=None):
         """Save checkpoint"""
         pass

     @abstractmethod
     def load_checkpoint(
-        self, model, ckpt_dir, optimizer, scheduler, load_module_strict, load_optimizer_states, load_lr_scheduler_states
+        self,
+        model,
+        ckpt_dir,
+        optimizer,
+        scheduler=None,
+        load_module_strict=True,
+        load_optimizer_states=True,
+        load_lr_scheduler_states=False,
     ):
         """Load checkpoint"""
         pass
diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py
index 4882007e1..d0ab1d7ab 100644
--- a/skyrl-train/skyrl_train/trainer.py
+++ b/skyrl-train/skyrl_train/trainer.py
@@ -1327,7 +1327,6 @@ def load_checkpoints(self) -> Tuple[int, str]:
                 "policy",
                 policy_ckpt_dir,
                 load_optimizer_states=True,
-                load_lr_scheduler_states=True,
             )
             logger.info("Successfully loaded policy checkpoint")

@@ -1338,7 +1337,6 @@ def load_checkpoints(self) -> Tuple[int, str]:
                 "critic",
                 critic_ckpt_dir,
                 load_optimizer_states=True,
-                load_lr_scheduler_states=True,
             )
             logger.info("Successfully loaded critic checkpoint")

diff --git a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
index a3ade9879..6e927fd24 100644
--- a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
+++ b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
@@ -148,12 +148,10 @@ def init_model(self, model_path, num_training_steps: int = None):
             }
         )

-        self.model, self.optimizer, self.scheduler = strategy.prepare(
+        self.model, self.optimizer, _ = strategy.prepare(
             (wrapped_model, None, None),
         )
-        assert (
-            self.optimizer is not None and self.scheduler is not None
-        ), "FSDP preparation should create optimizer and scheduler"
+        assert self.optimizer is not None, "FSDP preparation should create optimizer"

         # Initialize weight extractor
         # TODO(haochen): Now module grouping (in order to support FlashRL) is only enabled for the CUDA IPC
@@ -306,7 +304,7 @@ def init_model(self, model_path, num_training_steps: int = None):
         )

         # prepare models/optimizers...
-        self.model, self.optimizer, self.scheduler = strategy.prepare(
+        self.model, self.optimizer, _ = strategy.prepare(
             (critic, None, None),
         )
         assert self.optimizer is not None
diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
index 60c591d34..d6b8abe66 100644
--- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
+++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
@@ -16,12 +16,10 @@
 from megatron.bridge.peft.canonical_lora import CanonicalLoRA
 import megatron.core.parallel_state as mpu
 from megatron.core.optimizer import DistributedOptimizer, ChainedOptimizer
-from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler

 from skyrl_train.distributed.megatron.optimizer import (
     init_megatron_optim_config,
     get_megatron_optimizer,
-    get_megatron_optimizer_param_scheduler,
 )
 from skyrl_train.distributed.dispatch import MeshRank
 from skyrl_train.distributed.megatron.megatron_strategy import MegatronStrategy
@@ -372,7 +370,6 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.model: MegatronModelWrapper = None
         self.actor_module: List[nn.Module] = None
-        self.scheduler: OptimizerParamScheduler = None
         self.optimizer: DistributedOptimizer = None
         self.profiler: Profiler = None
         self._is_lora = self.cfg.trainer.policy.model.lora.rank > 0
@@ -432,7 +429,8 @@ def _broadcast_no_grad(*args, **kwargs):

     def init_model(self, model_path, num_training_steps: int = 1e9):
         """
-        Initialize the model, optimizer, and scheduler for the policy worker.
+        Initialize the model and optimizer for the policy worker.
+        Tinker controls the learning rate, so no LR scheduler is created here.
         """
         # initialize the bridge and provider objects
         self.init_configs(
@@ -472,13 +470,6 @@ def init_model(self, model_path, num_training_steps: int = 1e9):
         )
         self.optimizer = get_megatron_optimizer(self.actor_module, optim_config)

-        # create scheduler
-        self.scheduler = get_megatron_optimizer_param_scheduler(
-            optimizer=self.optimizer,
-            config=self.cfg.trainer.policy.optimizer_config,
-            num_training_steps=num_training_steps,
-        )
-
         # create worker model
         self.model = MegatronModelWrapper(
             config=self.cfg,
@@ -608,7 +599,7 @@ def optim_step(self) -> Optional[float]:
         Returns:
             The gradient norm (before scaling, after clipping), or None if unavailable.
         """
-        grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor")
+        grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, None, name="actor")

         # Reset counter for next accumulation cycle
         self._micro_batches_accumulated = 0
@@ -635,9 +626,7 @@ def set_lr(self, learning_rate: float) -> None:
         distributed optimizer). Updates all param_groups across all underlying
         optimizers.

-        Note: This bypasses the scheduler. The next scheduler.step() call
-        will override this value unless the scheduler is configured for
-        constant LR.
+        Tinker uses this function to set the learning rate.
         """
         if isinstance(self.optimizer, ChainedOptimizer):
             # ChainedOptimizer wraps multiple optimizers (e.g., for different param groups)
diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 10883b94a..af4e54f0e 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -22,7 +22,6 @@
     placement_group_table,
 )
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import LRScheduler
 from transformers import PreTrainedModel

 from skyrl_train.dataset.replay_buffer import Experience
@@ -633,7 +632,6 @@ class PolicyWorkerBase(Worker):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.model: nn.Module = None
-        self.scheduler: LRScheduler = None
         self.optimizer: Optimizer = None
         self.strategy: DistributedStrategy = None
         self.record_memory: bool = False
@@ -824,7 +822,7 @@ def _forward_backward_micro(
             status = {
                 "loss": loss.item(),
                 "response_length": num_actions,
-                "lr": self.scheduler.get_last_lr()[0],
+                "lr": self.get_lr(),
                 "loss_fn_outputs": loss_fn_outputs,
             }
         else:
@@ -863,7 +861,7 @@ def _forward_backward_micro(
                 "ppo_clip_ratio": clip_ratio,
                 "policy_entropy": entropy.item(),
                 "response_length": num_actions,
-                "policy_lr": self.scheduler.get_last_lr()[0],
+                "policy_lr": self.get_lr(),
             }
             if self.cfg.trainer.algorithm.use_kl_loss:
                 status["policy_kl"] = kl_loss.item()
@@ -895,7 +893,7 @@ def optim_step(self) -> float:
                     param.grad.mul_(scale)

         # Perform optimizer step (includes gradient clipping)
-        grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor")
+        grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, None, name="actor")

         # Reset counter for next accumulation cycle
         self._micro_batches_accumulated = 0
@@ -936,22 +934,20 @@ def save_checkpoint(self, ckpt_dir: Path, tokenizer=None):
         self.strategy.save_checkpoint(
             model=self.model,
             optimizer=self.optimizer,
-            scheduler=self.scheduler,
+            scheduler=None,
             ckpt_dir=ckpt_dir,
             node_local_rank=self.get_node_local_rank(),
             tokenizer=tokenizer,
         )

-    def load_checkpoint(
-        self, ckpt_dir: Path, load_optimizer_states: bool = True, load_lr_scheduler_states: bool = True
-    ):
+    def load_checkpoint(self, ckpt_dir: Path, load_optimizer_states: bool = True):
         _, states = self.strategy.load_checkpoint(
             model=self.model,
             optimizer=self.optimizer if load_optimizer_states else None,
-            scheduler=self.scheduler if load_lr_scheduler_states else None,
+            scheduler=None,
             ckpt_dir=ckpt_dir,
             load_optimizer_states=load_optimizer_states,
-            load_lr_scheduler_states=load_lr_scheduler_states,
+            load_lr_scheduler_states=False,
         )
         return states

@@ -994,7 +990,6 @@ class CriticWorkerBase(Worker):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.model: nn.Module = None
-        self.scheduler: LRScheduler = None
         self.optimizer: Optimizer = None
         self.strategy: DistributedStrategy = None
         self.record_memory: bool = False
@@ -1092,7 +1087,7 @@ def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]:
             "critic_loss": loss.item(),
             "values_mean": masked_mean(values, loss_mask).item(),
             "values_clipfrac": clipfrac,
-            "critic_lr": self.scheduler.get_last_lr()[0],
+            "critic_lr": self.get_lr(),
         }

         # All-reduce metrics across DP workers
@@ -1115,7 +1110,7 @@ def optim_step(self) -> float:
                     param.grad.mul_(scale)

         # Perform optimizer step (includes gradient clipping)
-        grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="critic")
+        grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, None, name="critic")

         # Reset counter for next accumulation cycle
         self._micro_batches_accumulated = 0
@@ -1189,20 +1184,20 @@ def save_checkpoint(self, ckpt_dir: str, tokenizer=None):
         self.strategy.save_checkpoint(
             model=self.model,
             optimizer=self.optimizer,
-            scheduler=self.scheduler,
+            scheduler=None,
             ckpt_dir=ckpt_dir,
             node_local_rank=self.get_node_local_rank(),
             tokenizer=tokenizer,
         )

-    def load_checkpoint(self, ckpt_dir=None, load_optimizer_states=True, load_lr_scheduler_states=True):
+    def load_checkpoint(self, ckpt_dir=None, load_optimizer_states=True):
         _, states = self.strategy.load_checkpoint(
             model=self.model,
             optimizer=self.optimizer if load_optimizer_states else None,
-            scheduler=self.scheduler if load_lr_scheduler_states else None,
+            scheduler=None,
             ckpt_dir=ckpt_dir,
             load_optimizer_states=load_optimizer_states,
-            load_lr_scheduler_states=load_lr_scheduler_states,
+            load_lr_scheduler_states=False,
         )
         return states

diff --git a/skyrl-train/skyrl_train/workers/worker_dispatch.py b/skyrl-train/skyrl_train/workers/worker_dispatch.py
index 42070985a..6242b7d96 100644
--- a/skyrl-train/skyrl_train/workers/worker_dispatch.py
+++ b/skyrl-train/skyrl_train/workers/worker_dispatch.py
@@ -302,7 +302,6 @@ def load_checkpoint(
         model: str,
         ckpt_dir: str,
         load_optimizer_states: bool = True,
-        load_lr_scheduler_states: bool = True,
     ) -> None:
         """Load checkpoint for model."""
         self._ensure_on_gpu(model, need_optimizer=load_optimizer_states, need_model=True)
@@ -313,7 +312,6 @@ def load_checkpoint(
                 "load_checkpoint",
                 ckpt_dir=ckpt_dir,
                 load_optimizer_states=load_optimizer_states,
-                load_lr_scheduler_states=load_lr_scheduler_states,
            )
         )
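
For context on how the learning rate is expected to flow once the scheduler is removed, here is a minimal caller-side sketch. It assumes a hypothetical compute_lr schedule and driver loop; only set_lr and optim_step are worker methods touched by this patch, everything else is illustrative and not part of the change.

def compute_lr(step: int, base_lr: float = 1e-6, warmup_steps: int = 10) -> float:
    # Hypothetical schedule computed by the caller (Tinker), not by the workers:
    # linear warmup, then constant.
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr


def run_step(policy_worker, step: int) -> None:
    # The caller sets the LR directly on the optimizer before stepping;
    # workers no longer create or step an LR scheduler.
    policy_worker.set_lr(compute_lr(step))
    # ... forward/backward over micro-batches accumulates gradients ...
    policy_worker.optim_step()  # optimizer.step() only; no scheduler.step()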