2 changes: 1 addition & 1 deletion docs/api-pages.yaml
@@ -147,7 +147,7 @@
       - skyrl.train.config.config.MegatronConfig
       - skyrl.train.config.config.MegatronDDPConfig
       - skyrl.train.config.config.MegatronLoraConfig
-      - skyrl.train.config.config.MegatronTorchProfilerConfig
+      - skyrl.train.config.config.TorchProfilerConfig
   - heading: Placement
     description: ""
     objects:
33 changes: 33 additions & 0 deletions docs/content/docs/configuration/config.mdx
@@ -271,6 +271,23 @@ policy:
   use_torch_compile: false # Enable torch compile for the entropy calculation
   record_memory: false # Dump memory snapshot for debugging

+  torch_profiler_config: # torch.profiler-based training-loop profiler (see below)
+    enable: false
+    ranks: [0]
+    save_path: null # defaults to {ckpt_path}/profiler_traces
+    skip_first: 10 # torch.profiler.schedule
+    wait: 0
+    warmup: 1
+    active: 1
+    repeat: 1 # 0 = profile for the whole run
+    activities: ["cpu", "cuda"]
+    record_shapes: true
+    profile_memory: false
+    with_stack: true
+    with_flops: false
+    with_modules: false
+    export_type: "chrome_trace" # chrome_trace | stacks
+
   model_config_kwargs: {} # pass through kwargs to the HuggingFace model config for FSDP training backends (i.e. for overriding vocab size, etc) - for megatron, use policy.megatron_config.transformer_config_kwargs instead

 ```
@@ -281,6 +298,22 @@ policy:
 - `policy.use_torch_compile`: Whether to enable torch compile for entropy calculation
 - `policy.record_memory`: Whether to record memory usage. If `True`, this will use PyTorch's [memory snapshotting utility](https://docs.pytorch.org/docs/stable/torch_cuda_memory.html) to record memory usage and dump memory snapshots after each policy model training step.

+### Torch Profiler Configuration
+
+`policy.torch_profiler_config` enables a [`torch.profiler`](https://docs.pytorch.org/docs/stable/profiler.html)-based profiler that the trainer drives around the training loop. It works with both the FSDP and Megatron backends, and with both RL and SFT. When enabled, it writes one Kineto/[HTA](https://github.com/facebookresearch/HolisticTraceAnalysis)-friendly `*.pt.trace.json` file per active window per profiled rank into `save_path`.
+
+**Scope:** the profiler captures **only the policy model's training step** (forward/backward + optimizer). In an RL run it does **not** profile the critic or reference models, and it does **not** profile generation/inference; it covers only policy training compute on the configured `ranks`.
+
+Which steps are recorded is controlled entirely by [`torch.profiler.schedule`](https://docs.pytorch.org/docs/stable/profiler.html#torch.profiler.schedule): the profiler skips the first `skip_first` steps, then repeats a cycle of `wait` (idle) + `warmup` (tracing, results discarded) + `active` (tracing, results recorded) steps `repeat` times (`repeat: 0` keeps cycling for the whole run). Use these knobs to profile several steps at a regular interval, repeatedly.
+
+- `policy.torch_profiler_config.enable`: Master switch. When `false` (default), no profiler RPCs are dispatched and there is zero overhead.
+- `policy.torch_profiler_config.ranks`: List of global ranks to profile (e.g. `[0]`). Add a mid-pipeline-stage rank to diagnose pipeline bubbles.
+- `policy.torch_profiler_config.save_path`: Output directory for traces. Defaults to `{ckpt_path}/profiler_traces`.
+- `policy.torch_profiler_config.{skip_first,wait,warmup,active,repeat}`: Passed directly to `torch.profiler.schedule`.
+- `policy.torch_profiler_config.activities`: Subset of `["cpu", "cuda"]` to record.
+- `policy.torch_profiler_config.{record_shapes,profile_memory,with_stack,with_flops,with_modules}`: Passed directly to `torch.profiler.profile`. `with_stack` and `record_shapes` add overhead; `profile_memory` is heavier still and off by default.
+- `policy.torch_profiler_config.export_type`: `chrome_trace` writes `*.pt.trace.json` (Kineto/HTA-friendly); `stacks` writes flamegraph-style self-CUDA-time stacks (requires `with_stack: true`).
+
 ### LoRA Configuration

 LoRA (Low-Rank Adaptation) enables parameter-efficient fine-tuning by training only a small number of additional low-rank matrices instead of the full model weights:
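
Reviewer note on the schedule knobs documented under "Torch Profiler Configuration" above: a small pure-Python sketch can make the `skip_first` / `wait` / `warmup` / `active` / `repeat` arithmetic concrete. It re-implements the phase logic as documented for `torch.profiler.schedule` for illustration only; `schedule_phase` and `recorded_steps` are hypothetical helper names, not part of SkyRL or PyTorch, and the step index is assumed to count profiler steps from 0 (how that lines up with the trainer's `global_step` is not shown in this diff).

```python
from typing import List


def schedule_phase(step: int, *, skip_first: int, wait: int, warmup: int,
                   active: int, repeat: int) -> str:
    """Return "none", "warmup", or "active" for a 0-indexed profiler step.

    Illustrative only: mirrors the documented behavior of
    torch.profiler.schedule, but is not the PyTorch or SkyRL source.
    """
    if step < skip_first:
        return "none"
    step -= skip_first
    cycle = wait + warmup + active
    # repeat == 0 means the wait/warmup/active cycle never stops.
    if repeat > 0 and step >= repeat * cycle:
        return "none"
    pos = step % cycle
    if pos < wait:
        return "none"
    if pos < wait + warmup:
        return "warmup"
    return "active"


def recorded_steps(num_steps: int, **schedule_kwargs: int) -> List[int]:
    """List the step indices whose trace would actually be recorded."""
    return [s for s in range(num_steps)
            if schedule_phase(s, **schedule_kwargs) == "active"]


# Defaults from the config block in this PR.
defaults = dict(skip_first=10, wait=0, warmup=1, active=1, repeat=1)
print(recorded_steps(20, **defaults))  # [11]

# Record 2 steps out of every 5 for the whole run (repeat=0).
periodic = dict(skip_first=0, wait=2, warmup=1, active=2, repeat=0)
print(recorded_steps(12, **periodic))  # [3, 4, 8, 9]
```

With the defaults, a single step is recorded after ten skipped steps and one warmup step, so runs shorter than twelve steps would produce no trace under this reading.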
127 changes: 70 additions & 57 deletions examples/train/async/async_trainer.py
@@ -47,63 +47,76 @@ async def train(self):
         start_epoch = self.global_step // len(self.train_dataloader)
         # Start from step 1
         self.global_step += 1
-        for epoch in range(start_epoch, self.cfg.trainer.epochs):
-            # while this is just off by one, you can imagine a more general queue based approach
-            # where the generation buffer holds a list of objects that the trainer can read from
-            # bit by bit.
-            generation_buffer = asyncio.Queue(maxsize=1)
-            self.sync_finished = asyncio.Event()
-            self.generation_ack = asyncio.Event()
-
-            # start generator task
-            generator_task = asyncio.create_task(self._run_generate_loop(generation_buffer))
-
-            for idx in range(len(self.train_dataloader)):
-                with Timer("step", self.all_timings):
-                    status = await self._run_training(generation_buffer)
-
-                    # request the generation loop that we should sync sometime soon.
-                    if idx != len(self.train_dataloader) - 1:
-                        await self.generation_ack.wait()
-
-                    # sync weights
-                    async with Timer("sync_weights", self.all_timings):
-                        await self.dispatch.save_weights_for_sampler()
-
-                    self.sync_finished.set()
-                    self.generation_ack.clear()
-
-                # 5. set logs
-                logger.info(status)
-                # log epoch info
-                self.all_metrics.update({"trainer/epoch": epoch, "trainer/global_step": self.global_step})
-                self.tracker.log(self.all_metrics, step=self.global_step)
-                self.all_metrics = {}
-                pbar.update(1)
-
-                if self.cfg.trainer.eval_interval > 0 and (
-                    self.global_step % self.cfg.trainer.eval_interval == 0
-                    or self.global_step == self.total_training_steps
-                ):
-                    with Timer("eval", self.all_timings):
-                        eval_metrics = await self.eval()
-                        self.all_metrics.update(eval_metrics)
-                if self.cfg.trainer.ckpt_interval > 0 and self.global_step % self.cfg.trainer.ckpt_interval == 0:
-                    with Timer("save_checkpoints", self.all_timings):
-                        self.save_checkpoints()
-                if self.cfg.trainer.hf_save_interval > 0 and self.global_step % self.cfg.trainer.hf_save_interval == 0:
-                    with Timer("save_hf_model", self.all_timings):
-                        self.save_models()
-                self.tracker.log({"timing/" + k: v for k, v in self.all_timings.items()}, step=self.global_step)
-                self.all_timings = {}
-                self.global_step += 1
-
-            if self.cfg.trainer.update_ref_every_epoch and self.ref_model is not None:
-                with Timer("update_ref_with_policy", self.all_timings):
-                    await asyncio.to_thread(self.update_ref_with_policy)
-
-            # cancel generation task for this epoch
-            generator_task.cancel()
+        self._profiler_start()
+        try:
+            for epoch in range(start_epoch, self.cfg.trainer.epochs):
+                # while this is just off by one, you can imagine a more general queue based approach
+                # where the generation buffer holds a list of objects that the trainer can read from
+                # bit by bit.
+                generation_buffer = asyncio.Queue(maxsize=1)
+                self.sync_finished = asyncio.Event()
+                self.generation_ack = asyncio.Event()
+
+                # start generator task
+                generator_task = asyncio.create_task(self._run_generate_loop(generation_buffer))
+
+                for idx in range(len(self.train_dataloader)):
+                    with Timer("step", self.all_timings):
+                        status = await self._run_training(generation_buffer)
+
+                        # request the generation loop that we should sync sometime soon.
+                        if idx != len(self.train_dataloader) - 1:
+                            await self.generation_ack.wait()
+
+                        # sync weights
+                        async with Timer("sync_weights", self.all_timings):
+                            await self.dispatch.save_weights_for_sampler()
+
+                        self.sync_finished.set()
+                        self.generation_ack.clear()
+
+                    # Advance the torch profiler schedule once per global step
+                    # (no-op unless profiling is enabled).
+                    self._profiler_step()
+
+                    # 5. set logs
+                    logger.info(status)
+                    # log epoch info
+                    self.all_metrics.update({"trainer/epoch": epoch, "trainer/global_step": self.global_step})
+                    self.tracker.log(self.all_metrics, step=self.global_step)
+                    self.all_metrics = {}
+                    pbar.update(1)
+
+                    if self.cfg.trainer.eval_interval > 0 and (
+                        self.global_step % self.cfg.trainer.eval_interval == 0
+                        or self.global_step == self.total_training_steps
+                    ):
+                        with Timer("eval", self.all_timings):
+                            eval_metrics = await self.eval()
+                            self.all_metrics.update(eval_metrics)
+                    if self.cfg.trainer.ckpt_interval > 0 and self.global_step % self.cfg.trainer.ckpt_interval == 0:
+                        with Timer("save_checkpoints", self.all_timings):
+                            self.save_checkpoints()
+                    if (
+                        self.cfg.trainer.hf_save_interval > 0
+                        and self.global_step % self.cfg.trainer.hf_save_interval == 0
+                    ):
+                        with Timer("save_hf_model", self.all_timings):
+                            self.save_models()
+                    self.tracker.log({"timing/" + k: v for k, v in self.all_timings.items()}, step=self.global_step)
+                    self.all_timings = {}
+                    self.global_step += 1
+
+                if self.cfg.trainer.update_ref_every_epoch and self.ref_model is not None:
+                    with Timer("update_ref_with_policy", self.all_timings):
+                        await asyncio.to_thread(self.update_ref_with_policy)
+
+                # cancel generation task for this epoch
+                generator_task.cancel()
+        finally:
+            # Always stop/flush the profiler when the loop exits (incl. on error)
+            # so the open kineto trace window isn't leaked. No-op when disabled.
+            self._profiler_stop()

         pbar.close()
         if self.cfg.trainer.ckpt_interval > 0:
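
Reviewer note on the start / step / stop wrapping in `async_trainer.py`: the bodies of `_profiler_start`, `_profiler_step`, and `_profiler_stop` are not part of this diff, so the sketch below uses a hypothetical `RecordingProfiler` stand-in that only logs calls. It is synchronous for brevity (the real loop is `async`) and is meant only to show why the `try`/`finally` matters: `stop` still runs when a step raises.

```python
from typing import List, Optional


class RecordingProfiler:
    """Hypothetical stand-in for the trainer's profiler hooks; it only logs calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def start(self) -> None:
        self.calls.append("start")

    def step(self) -> None:
        self.calls.append("step")

    def stop(self) -> None:
        self.calls.append("stop")


def run_loop(profiler: RecordingProfiler, num_steps: int,
             fail_at: Optional[int] = None) -> None:
    """Drive the profiler around a toy training loop."""
    profiler.start()
    try:
        for idx in range(num_steps):
            if idx == fail_at:
                # Simulate a training step that crashes mid-run.
                raise RuntimeError(f"training step {idx} failed")
            # Advance the schedule once per completed step.
            profiler.step()
    finally:
        # Runs on normal exit and on error, so the trace is always closed.
        profiler.stop()


ok = RecordingProfiler()
run_loop(ok, num_steps=2)
print(ok.calls)  # ['start', 'step', 'step', 'stop']

failed = RecordingProfiler()
try:
    run_loop(failed, num_steps=3, fail_at=1)
except RuntimeError:
    pass
print(failed.calls)  # ['start', 'step', 'stop']
```

Placing `start` outside the `try` means a failure while starting the profiler is not followed by a `stop` call, which matches the ordering in the hunk above.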
10 changes: 6 additions & 4 deletions examples/train/megatron/run_megatron.sh
@@ -17,7 +17,9 @@ MEGATRON_TP=2
 MEGATRON_PP=2
 MEGATRON_CP=1

-# torch profiler config
+# torch profiler config. Profiles ONLY the policy model's training step
+# (forward/backward + optim) -- not the critic or ref models, and not generation/
+# inference. See trainer.policy.torch_profiler_config for schedule/capture knobs.
 ENABLE_TORCH_PROFILER=false
 RANKS_TO_PROFILE="[0]"
 SAVE_PATH="$HOME/megatron_prof/tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}"
@@ -33,9 +35,9 @@ uv run --isolated --extra megatron -m skyrl.train.entrypoints.main_base \
   trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
   generator.inference_engine.num_engines=$NUM_GPUS \
   generator.inference_engine.tensor_parallel_size=1 \
-  trainer.policy.megatron_config.torch_profiler_config.enable=$ENABLE_TORCH_PROFILER \
-  trainer.policy.megatron_config.torch_profiler_config.ranks=$RANKS_TO_PROFILE \
-  trainer.policy.megatron_config.torch_profiler_config.save_path=$SAVE_PATH \
+  trainer.policy.torch_profiler_config.enable=$ENABLE_TORCH_PROFILER \
+  trainer.policy.torch_profiler_config.ranks=$RANKS_TO_PROFILE \
+  trainer.policy.torch_profiler_config.save_path=$SAVE_PATH \
   trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
   trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
   trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
6 changes: 3 additions & 3 deletions examples/train/megatron/run_megatron_nemotron_mini_4b.sh
@@ -53,9 +53,9 @@ uv run --isolated --extra megatron -m skyrl.train.entrypoints.main_base \
   trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
   generator.inference_engine.num_engines=$NUM_INFERENCE_ENGINES \
   generator.inference_engine.tensor_parallel_size=$INFERENCE_TP \
-  trainer.policy.megatron_config.torch_profiler_config.enable=$ENABLE_TORCH_PROFILER \
-  trainer.policy.megatron_config.torch_profiler_config.ranks=$RANKS_TO_PROFILE \
-  trainer.policy.megatron_config.torch_profiler_config.save_path=$SAVE_PATH \
+  trainer.policy.torch_profiler_config.enable=$ENABLE_TORCH_PROFILER \
+  trainer.policy.torch_profiler_config.ranks=$RANKS_TO_PROFILE \
+  trainer.policy.torch_profiler_config.save_path=$SAVE_PATH \
   trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
   trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
   trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \