29 changes: 4 additions & 25 deletions skyrl-train/skyrl_train/distributed/fsdp_strategy.py
@@ -39,7 +39,6 @@
fsdp_version,
fsdp2_load_full_state_dict,
)
from transformers.trainer import get_scheduler

from packaging import version

@@ -186,8 +185,6 @@ def optimizer_step(
return grad_norm

optimizer.step()
if scheduler is not None:
scheduler.step()
optimizer.zero_grad()
return grad_norm

@@ -288,23 +285,15 @@ def _fsdp_init_train_model(self, model, optimizer, scheduler):
betas=optim_config.adam_betas,
weight_decay=optim_config.weight_decay,
)

lr_scheduler = get_scheduler(
optim_config.scheduler,
new_optimizer,
num_warmup_steps=optim_config.num_warmup_steps,
num_training_steps=self.total_training_steps,
)
else:
new_optimizer = None
lr_scheduler = None

if is_wrapped:
model.model = fsdp_module
else:
model = fsdp_module

return model, new_optimizer, lr_scheduler
# backwards compatibility, return None for lr_scheduler (now controlled by Tinker)
return model, new_optimizer, None

def _fsdp_init_eval_model(self, model):
"""Initialize a model for evaluation with FSDP"""
@@ -453,14 +442,9 @@ def save_checkpoint(
with io.open_file(optim_path, "wb") as f:
torch.save(optimizer_state_dict, f)

# Get scheduler state dict if scheduler is provided
lr_scheduler_state_dict = {}
if scheduler is not None:
lr_scheduler_state_dict = scheduler.state_dict()

# Create extra state dict with client state and any additional info
extra_state_dict = {
"lr_scheduler": lr_scheduler_state_dict,
"lr_scheduler": {},
"client_state": client_state,
"tag": tag,
"fsdp_strategy": self.fsdp_strategy,
@@ -556,7 +540,7 @@ def load_checkpoint(
optimizer_state_dict = torch.load(f, map_location="cpu", weights_only=False)

# Extract scheduler state from extra state
lr_scheduler_state_dict = extra_state_dict.get("lr_scheduler", {})
# lr_scheduler_state_dict = extra_state_dict.get("lr_scheduler", {})

# Set up state dict configurations for sharded loading
state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
@@ -575,11 +559,6 @@
optimizer.load_state_dict(optimizer_state_dict)
self.print(f"[rank-{rank}]: Successfully loaded optimizer state")

# Load scheduler state dict if scheduler object is provided and loading is requested
if scheduler is not None and load_lr_scheduler_states:
scheduler.load_state_dict(lr_scheduler_state_dict)
self.print(f"[rank-{rank}]: Successfully loaded scheduler state")

# Load RNG state for reproducibility
if "rng" in extra_state_dict:
self.load_rng_state(extra_state_dict["rng"])
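A hedged sketch of the backwards-compatible extra-state layout used above: the `lr_scheduler` key is still written, but always as an empty dict, so older checkpoints that carry real scheduler state load cleanly and the entry is simply never applied. The file name and `client_state` contents below are placeholder assumptions, not the actual SkyRL checkpoint layout.

```python
# Sketch of the extra-state dict written/read above; values are placeholders.
import torch

extra_state_dict = {
    "lr_scheduler": {},  # always empty now; kept only for key compatibility
    "client_state": {"global_step": 100},
    "tag": "global_step_100",
    "fsdp_strategy": "fsdp2",
}
torch.save(extra_state_dict, "extra_state.pt")

# Loading tolerates both old checkpoints (real scheduler state under
# "lr_scheduler") and new ones (empty dict); the entry is never applied.
loaded = torch.load("extra_state.pt", map_location="cpu", weights_only=False)
_ignored_scheduler_state = loaded.get("lr_scheduler", {})
print(loaded["tag"])
```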
17 changes: 3 additions & 14 deletions skyrl-train/skyrl_train/distributed/megatron/megatron_strategy.py
@@ -120,13 +120,13 @@ def optimizer_step(
self,
optimizer: optim.Optimizer,
model,
scheduler,
scheduler=None,
name="model",
**kwargs,
) -> Optional[Float[torch.Tensor, "1"]]:
"""Perform optimizer step"""
_, grad_norm, _ = optimizer.step()
scheduler.step(1)
# Do not call scheduler.step() - Tinker controls LR
optimizer.zero_grad()
return grad_norm

@@ -158,15 +158,13 @@ def save_checkpoint(
# All ranks wait for the checkpoint directory to be created before saving.
dist.barrier()

# Collect the sharded state dicts for model and optimizer, and full state dict for the scheduler.
# Collect the sharded state dicts for model and optimizer.
sharded_state_dict = {}
model_sharded_state_dict = unwrapped_model.sharded_state_dict()
if not self.is_lora:
sharded_state_dict["model"] = model_sharded_state_dict
if optimizer:
sharded_state_dict["optimizer"] = optimizer.sharded_state_dict(model_sharded_state_dict)
if scheduler:
sharded_state_dict["lr_scheduler"] = scheduler.state_dict()

# Save RNG state.
sharded_state_dict["rng"] = self.get_rng_state()
@@ -257,8 +255,6 @@ def load_checkpoint(
sharded_state_dict["model"] = model_sharded_state_dict
if optimizer and load_optimizer_states:
sharded_state_dict["optimizer"] = optimizer.sharded_state_dict(model_sharded_state_dict)
if scheduler and load_lr_scheduler_states:
sharded_state_dict["lr_scheduler"] = scheduler.state_dict()

with io.local_read_dir(ckpt_dir) as read_dir:
# Load the checkpoint in parallel.
@@ -286,13 +282,6 @@
optimizer.load_state_dict(state_dict["optimizer"])
self.print("Loaded optimizer state dict.")

if scheduler and load_lr_scheduler_states:
assert (
"lr_scheduler" in state_dict
), f"LR scheduler state dict not found in checkpoint loaded from {ckpt_dir}. Available keys: {state_dict.keys()}"
scheduler.load_state_dict(state_dict["lr_scheduler"])
self.print("Loaded LR scheduler state dict.")

# Load RNG state, if present.
if "rng" in state_dict:
self.load_rng_state(state_dict["rng"])
13 changes: 10 additions & 3 deletions skyrl-train/skyrl_train/distributed/strategy.py
@@ -30,21 +30,28 @@ def optimizer_step(
self,
optimizer: optim.Optimizer,
model,
scheduler,
scheduler=None,
name="model",
**kwargs,
) -> Optional[Float[torch.Tensor, "1"]]:
"""Perform optimizer step"""
pass

@abstractmethod
def save_checkpoint(self, model, ckpt_dir, node_local_rank, optimizer, scheduler, tokenizer):
def save_checkpoint(self, model, ckpt_dir, node_local_rank, optimizer, scheduler=None, tokenizer=None):
"""Save checkpoint"""
pass

@abstractmethod
def load_checkpoint(
self, model, ckpt_dir, optimizer, scheduler, load_module_strict, load_optimizer_states, load_lr_scheduler_states
self,
model,
ckpt_dir,
optimizer,
scheduler=None,
load_module_strict=True,
load_optimizer_states=True,
load_lr_scheduler_states=False,
):
"""Load checkpoint"""
pass
2 changes: 0 additions & 2 deletions skyrl-train/skyrl_train/trainer.py
@@ -1327,7 +1327,6 @@ def load_checkpoints(self) -> Tuple[int, str]:
"policy",
policy_ckpt_dir,
load_optimizer_states=True,
load_lr_scheduler_states=True,
)
logger.info("Successfully loaded policy checkpoint")

@@ -1338,7 +1337,6 @@ def load_checkpoints(self) -> Tuple[int, str]:
"critic",
critic_ckpt_dir,
load_optimizer_states=True,
load_lr_scheduler_states=True,
)
logger.info("Successfully loaded critic checkpoint")

8 changes: 3 additions & 5 deletions skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
@@ -148,12 +148,10 @@ def init_model(self, model_path, num_training_steps: int = None):
}
)

self.model, self.optimizer, self.scheduler = strategy.prepare(
self.model, self.optimizer, _ = strategy.prepare(
(wrapped_model, None, None),
)
assert (
self.optimizer is not None and self.scheduler is not None
), "FSDP preparation should create optimizer and scheduler"
assert self.optimizer is not None, "FSDP preparation should create optimizer"

# Initialize weight extractor
# TODO(haochen): Now module grouping (in order to support FlashRL) is only enabled for the CUDA IPC
@@ -306,7 +304,7 @@ def init_model(self, model_path, num_training_steps: int = None):
)

# prepare models/optimizers...
self.model, self.optimizer, self.scheduler = strategy.prepare(
self.model, self.optimizer, _ = strategy.prepare(
(critic, None, None),
)
assert self.optimizer is not None
19 changes: 4 additions & 15 deletions skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
@@ -16,12 +16,10 @@
from megatron.bridge.peft.canonical_lora import CanonicalLoRA
import megatron.core.parallel_state as mpu
from megatron.core.optimizer import DistributedOptimizer, ChainedOptimizer
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler

from skyrl_train.distributed.megatron.optimizer import (
init_megatron_optim_config,
get_megatron_optimizer,
get_megatron_optimizer_param_scheduler,
)
Comment on lines 20 to 23
Contributor (severity: medium)

With the removal of get_megatron_optimizer_param_scheduler from the imports, the function itself in skyrl_train/skyrl_train/distributed/megatron/optimizer.py appears to be dead code. Consider removing it in a follow-up to improve code maintainability.
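One hedged way to back the follow-up suggestion is a quick reference scan before deleting the helper; the repository layout below is assumed from the paths shown in this diff, not verified against the actual checkout.

```python
# Hedged sketch: confirm the reviewer's dead-code claim before the follow-up
# removal. The repo layout is an assumption taken from paths in this diff.
import re
from pathlib import Path

pattern = re.compile(r"\bget_megatron_optimizer_param_scheduler\b")
repo_root = Path("skyrl-train/skyrl_train")  # assumed checkout location

for path in sorted(repo_root.rglob("*.py")):
    for lineno, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
        if pattern.search(line):
            print(f"{path}:{lineno}: {line.strip()}")

# If the only remaining hits are the definition in
# distributed/megatron/optimizer.py, the helper can be deleted safely.
```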

from skyrl_train.distributed.dispatch import MeshRank
from skyrl_train.distributed.megatron.megatron_strategy import MegatronStrategy
@@ -372,7 +370,6 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
self.model: MegatronModelWrapper = None
self.actor_module: List[nn.Module] = None
self.scheduler: OptimizerParamScheduler = None
self.optimizer: DistributedOptimizer = None
self.profiler: Profiler = None
self._is_lora = self.cfg.trainer.policy.model.lora.rank > 0
@@ -432,7 +429,8 @@ def _broadcast_no_grad(*args, **kwargs):

def init_model(self, model_path, num_training_steps: int = 1e9):
"""
Initialize the model, optimizer, and scheduler for the policy worker.
Initialize the model and optimizer for the policy worker.
Tinker controls the LR, so no scheduler is initialized here.
"""
# initialize the bridge and provider objects
self.init_configs(
@@ -472,13 +470,6 @@ def init_model(self, model_path, num_training_steps: int = 1e9):
)
self.optimizer = get_megatron_optimizer(self.actor_module, optim_config)

# create scheduler
self.scheduler = get_megatron_optimizer_param_scheduler(
optimizer=self.optimizer,
config=self.cfg.trainer.policy.optimizer_config,
num_training_steps=num_training_steps,
)

# create worker model
self.model = MegatronModelWrapper(
config=self.cfg,
@@ -608,7 +599,7 @@ def optim_step(self) -> Optional[float]:
Returns:
The gradient norm (before scaling, after clipping), or None if unavailable.
"""
grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor")
grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, None, name="actor")

# Reset counter for next accumulation cycle
self._micro_batches_accumulated = 0
@@ -635,9 +626,7 @@ def set_lr(self, learning_rate: float) -> None:
distributed optimizer). Updates all param_groups across all
underlying optimizers.

Note: This bypasses the scheduler. The next scheduler.step() call
will override this value unless the scheduler is configured for
constant LR.
Tinker uses this function to set the LR.
"""
if isinstance(self.optimizer, ChainedOptimizer):
# ChainedOptimizer wraps multiple optimizers (e.g., for different param groups)
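For context on the worker changes above, here is a minimal, hedged sketch of scheduler-free LR control: an external controller (a plain loop standing in for Tinker) computes the learning rate and pushes it into the optimizer's param groups via a `set_lr`-style method before each step. The `PolicyWorker` class and the decay formula are illustrative assumptions; only the `torch.optim` param-group mechanics are standard behavior.

```python
# Illustrative sketch only: mirrors the scheduler-free flow this PR moves to.
# PolicyWorker here is a stand-in, not the actual SkyRL worker class.
import torch


class PolicyWorker:
    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    def set_lr(self, learning_rate: float) -> None:
        # Write the externally computed LR into every param group; since no
        # scheduler.step() ever runs, nothing overrides this value.
        for group in self.optimizer.param_groups:
            group["lr"] = learning_rate

    def train_step(self, batch: torch.Tensor) -> float:
        loss = self.model(batch).pow(2).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        return loss.item()


worker = PolicyWorker(torch.nn.Linear(8, 4))
for step in range(3):
    lr = 1e-5 * (1.0 - step / 10)  # LR computed by the external controller
    worker.set_lr(lr)
    worker.train_step(torch.randn(16, 8))
```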
31 changes: 13 additions & 18 deletions skyrl-train/skyrl_train/workers/worker.py
@@ -22,7 +22,6 @@
placement_group_table,
)
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel

from skyrl_train.dataset.replay_buffer import Experience
@@ -633,7 +632,6 @@ class PolicyWorkerBase(Worker):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.model: nn.Module = None
self.scheduler: LRScheduler = None
self.optimizer: Optimizer = None
self.strategy: DistributedStrategy = None
self.record_memory: bool = False
@@ -824,7 +822,7 @@ def _forward_backward_micro(
status = {
"loss": loss.item(),
"response_length": num_actions,
"lr": self.scheduler.get_last_lr()[0],
"lr": self.get_lr(),
"loss_fn_outputs": loss_fn_outputs,
}
else:
@@ -863,7 +861,7 @@ def _forward_backward_micro(
"ppo_clip_ratio": clip_ratio,
"policy_entropy": entropy.item(),
"response_length": num_actions,
"policy_lr": self.scheduler.get_last_lr()[0],
"policy_lr": self.get_lr(),
}
if self.cfg.trainer.algorithm.use_kl_loss:
status["policy_kl"] = kl_loss.item()
@@ -895,7 +893,7 @@ def optim_step(self) -> float:
param.grad.mul_(scale)

# Perform optimizer step (includes gradient clipping)
grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor")
grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, None, name="actor")

# Reset counter for next accumulation cycle
self._micro_batches_accumulated = 0
@@ -936,22 +934,20 @@ def save_checkpoint(self, ckpt_dir: Path, tokenizer=None):
self.strategy.save_checkpoint(
model=self.model,
optimizer=self.optimizer,
scheduler=self.scheduler,
scheduler=None,
ckpt_dir=ckpt_dir,
node_local_rank=self.get_node_local_rank(),
tokenizer=tokenizer,
)

def load_checkpoint(
self, ckpt_dir: Path, load_optimizer_states: bool = True, load_lr_scheduler_states: bool = True
):
def load_checkpoint(self, ckpt_dir: Path, load_optimizer_states: bool = True):
_, states = self.strategy.load_checkpoint(
model=self.model,
optimizer=self.optimizer if load_optimizer_states else None,
scheduler=self.scheduler if load_lr_scheduler_states else None,
scheduler=None,
ckpt_dir=ckpt_dir,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states,
load_lr_scheduler_states=False,
)
return states

@@ -994,7 +990,6 @@ class CriticWorkerBase(Worker):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.model: nn.Module = None
self.scheduler: LRScheduler = None
self.optimizer: Optimizer = None
self.strategy: DistributedStrategy = None
self.record_memory: bool = False
@@ -1092,7 +1087,7 @@ def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]:
"critic_loss": loss.item(),
"values_mean": masked_mean(values, loss_mask).item(),
"values_clipfrac": clipfrac,
"critic_lr": self.scheduler.get_last_lr()[0],
"critic_lr": self.get_lr(),
}

# All-reduce metrics across DP workers
@@ -1115,7 +1110,7 @@ def optim_step(self) -> float:
param.grad.mul_(scale)

# Perform optimizer step (includes gradient clipping)
grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="critic")
grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, None, name="critic")

# Reset counter for next accumulation cycle
self._micro_batches_accumulated = 0
@@ -1189,20 +1184,20 @@ def save_checkpoint(self, ckpt_dir: str, tokenizer=None):
self.strategy.save_checkpoint(
model=self.model,
optimizer=self.optimizer,
scheduler=self.scheduler,
scheduler=None,
ckpt_dir=ckpt_dir,
node_local_rank=self.get_node_local_rank(),
tokenizer=tokenizer,
)

def load_checkpoint(self, ckpt_dir=None, load_optimizer_states=True, load_lr_scheduler_states=True):
def load_checkpoint(self, ckpt_dir=None, load_optimizer_states=True):
_, states = self.strategy.load_checkpoint(
model=self.model,
optimizer=self.optimizer if load_optimizer_states else None,
scheduler=self.scheduler if load_lr_scheduler_states else None,
scheduler=None,
ckpt_dir=ckpt_dir,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states,
load_lr_scheduler_states=False,
)
return states
