feat(profiler): drive torch.profiler around the training loop#1750
feat(profiler): drive torch.profiler around the training loop#1750dyurk-lila wants to merge 1 commit into
Conversation
SkyRL constructed a Profiler object on the Megatron policy worker but never
drove it -- start()/step()/stop() were called nowhere, so torch_profiler_config
was dead code. This wires torch.profiler end to end for both Megatron and FSDP,
both RL and SFT, with the full torch.profiler surface exposed as config.
- start_profile / profile_step / stop_profile RPCs on the shared Worker base,
dispatched via pass_through thin wrappers in WorkerDispatch.
- Trainers bracket the loop: start before, one profile_step per global step,
stop after -- all gated on torch_profiler_config.enable so non-profiling runs
dispatch zero extra RPCs.
- TorchProfilerConfig hoisted to PolicyConfig (also wired through the SFT config
bridge), exposing schedule (skip_first/wait/warmup/active/repeat) and capture
(activities/record_shapes/profile_memory/with_stack/with_flops/with_modules/
export_type) knobs. Defaults reproduce prior effective behavior; enable=false
by default. Traces written by tensorboard_trace_handler as HTA/Kineto-friendly
*.pt.trace.json under {ckpt_path}/profiler_traces by default.
- Profiles only the policy model's training step (fwd/bwd + optimizer), not the
critic/ref models and not generation/inference.
- All profiler paths are exception-isolated: a fault disables profiling for the
rest of the run rather than crashing it.
- Removes the redundant Megatron-only torch_profiler_config.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request introduces a unified, backend-agnostic TorchProfilerConfig and Profiler class to replace the previous Megatron-specific profiler configuration, enabling profiling across both FSDP and Megatron backends. The profiler is now driven by the trainer via start, step, and stop RPCs dispatched to workers, supported by comprehensive validation and unit tests. The review feedback recommends several robustness enhancements: tracking the profiler's running state to prevent PyTorch RuntimeErrors during stop operations, using getattr with fallbacks for backward compatibility with older PyTorch versions, adding null-checks to configuration validation to prevent TypeErrors from YAML inputs, and adding a unit test to verify safe stop behavior.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| self._last_pairs: list = [] | ||
| self._window_count: int = 0 | ||
| if not config.enable: | ||
| return |
There was a problem hiding this comment.
Initialize a _running state variable to track whether the profiler has been successfully started. This ensures we can safely guard calls to stop() and prevent PyTorch from raising a RuntimeError if stop() is called on a non-running profiler.
| self._last_pairs: list = [] | |
| self._window_count: int = 0 | |
| if not config.enable: | |
| return | |
| self._last_pairs: list = [] | |
| self._window_count: int = 0 | |
| self._running = False | |
| if not config.enable: | |
| return |
| try: | ||
| # ``self_device_time_total`` is torch 2.11's field (the older | ||
| # ``self_cuda_time_total`` was removed). Microseconds, self time. | ||
| self._last_pairs = [(str(e.key), float(e.self_device_time_total)) for e in prof.key_averages()] |
There was a problem hiding this comment.
Using e.self_device_time_total directly can raise an AttributeError on older PyTorch versions (where it was named self_cuda_time_total or self_cpu_time_total). Using getattr with fallbacks makes the kernel summary extraction robust across different PyTorch versions and profiling activities.
self._last_pairs = [
(
str(e.key),
float(
getattr(
e,
"self_device_time_total",
getattr(e, "self_cuda_time_total", getattr(e, "self_cpu_time_total", 0.0)),
)
),
)
for e in prof.key_averages()
]| def start(self) -> None: | ||
| if self.check(): | ||
| logger.info(f"[Profiler] stopped for rank {self.rank}") | ||
| self.prof.stop() | ||
|
|
||
| def save(self): | ||
| if self.prof is not None and not self.saved: | ||
| if not os.path.exists(self.save_path): | ||
| os.makedirs(self.save_path) | ||
| save_file_name = f"/prof_rank_{self.rank}.json" | ||
| logger.info(f"[Profiler] Saving trace to {self.save_path + save_file_name}") | ||
| self.prof.export_chrome_trace(self.save_path + save_file_name) | ||
| self.enable = False | ||
| self.saved = True | ||
| try: | ||
| logger.info(f"[Profiler] started for rank {self.rank}") | ||
| self.prof.start() | ||
| except Exception as e: | ||
| self._disable("start", e) |
There was a problem hiding this comment.
Set self._running = True upon successful start of the profiler.
| def start(self) -> None: | |
| if self.check(): | |
| logger.info(f"[Profiler] stopped for rank {self.rank}") | |
| self.prof.stop() | |
| def save(self): | |
| if self.prof is not None and not self.saved: | |
| if not os.path.exists(self.save_path): | |
| os.makedirs(self.save_path) | |
| save_file_name = f"/prof_rank_{self.rank}.json" | |
| logger.info(f"[Profiler] Saving trace to {self.save_path + save_file_name}") | |
| self.prof.export_chrome_trace(self.save_path + save_file_name) | |
| self.enable = False | |
| self.saved = True | |
| try: | |
| logger.info(f"[Profiler] started for rank {self.rank}") | |
| self.prof.start() | |
| except Exception as e: | |
| self._disable("start", e) | |
| def start(self) -> None: | |
| if self.check(): | |
| try: | |
| logger.info(f"[Profiler] started for rank {self.rank}") | |
| self.prof.start() | |
| self._running = True | |
| except Exception as e: | |
| self._disable("start", e) |
| def stop(self) -> None: | ||
| if self.check(): | ||
| logger.info(f"[Profiler] Trace stopped for rank {self.rank}") | ||
| self.enable = False | ||
| try: | ||
| logger.info(f"[Profiler] stopped for rank {self.rank}") | ||
| self.prof.stop() | ||
| except Exception as e: | ||
| self._disable("stop", e) |
There was a problem hiding this comment.
Only call self.prof.stop() if the profiler is currently running (self._running is True), and reset the running state in a finally block. This prevents PyTorch from raising a RuntimeError if stop() is called on a non-running profiler (e.g., if start() failed or was never called).
| def stop(self) -> None: | |
| if self.check(): | |
| logger.info(f"[Profiler] Trace stopped for rank {self.rank}") | |
| self.enable = False | |
| try: | |
| logger.info(f"[Profiler] stopped for rank {self.rank}") | |
| self.prof.stop() | |
| except Exception as e: | |
| self._disable("stop", e) | |
| def stop(self) -> None: | |
| if self.check() and getattr(self, "_running", False): | |
| try: | |
| logger.info(f"[Profiler] stopped for rank {self.rank}") | |
| self.prof.stop() | |
| except Exception as e: | |
| self._disable("stop", e) | |
| finally: | |
| self._running = False |
| bad_activities = [a for a in self.activities if a.lower() not in TORCH_PROFILER_ACTIVITIES] | ||
| if bad_activities: | ||
| raise ValueError( | ||
| f"invalid `torch_profiler_config.activities` entries {bad_activities}. " | ||
| f"Each must be one of {list(TORCH_PROFILER_ACTIVITIES)}." | ||
| ) |
There was a problem hiding this comment.
Add a check to ensure self.activities is not None before iterating over it. If a user sets activities: null in their YAML configuration, this prevents a TypeError during validation.
if self.activities is None:
raise ValueError("`torch_profiler_config.activities` cannot be None.")
bad_activities = [a for a in self.activities if a.lower() not in TORCH_PROFILER_ACTIVITIES]
if bad_activities:
raise ValueError(
f"invalid `torch_profiler_config.activities` entries {bad_activities}. "
f"Each must be one of {list(TORCH_PROFILER_ACTIVITIES)}."
)| for name in ("skip_first", "wait", "warmup", "repeat"): | ||
| value = getattr(self, name) | ||
| if value < 0: | ||
| raise ValueError(f"`torch_profiler_config.{name}` must be >= 0, got {value}.") | ||
| if self.active < 1: | ||
| raise ValueError(f"`torch_profiler_config.active` must be >= 1, got {self.active}.") |
There was a problem hiding this comment.
Add checks to ensure that the numeric schedule fields are not None before performing comparison operations. This prevents a TypeError if any of these fields are set to null in the YAML configuration.
| for name in ("skip_first", "wait", "warmup", "repeat"): | |
| value = getattr(self, name) | |
| if value < 0: | |
| raise ValueError(f"`torch_profiler_config.{name}` must be >= 0, got {value}.") | |
| if self.active < 1: | |
| raise ValueError(f"`torch_profiler_config.active` must be >= 1, got {self.active}.") | |
| for name in ("skip_first", "wait", "warmup", "repeat"): | |
| value = getattr(self, name) | |
| if value is None: | |
| raise ValueError(f"`torch_profiler_config.{name}` cannot be None.") | |
| if value < 0: | |
| raise ValueError(f"`torch_profiler_config.{name}` must be >= 0, got {value}.") | |
| if self.active is None: | |
| raise ValueError("`torch_profiler_config.active` cannot be None.") | |
| if self.active < 1: | |
| raise ValueError(f"`torch_profiler_config.active` must be >= 1, got {self.active}.") |
| def test_dispatch_to_policy_when_enabled(self): | ||
| trainer, calls = self._trainer(enable=True) | ||
| trainer._profiler_start() | ||
| trainer._profiler_step() | ||
| trainer._profiler_stop() | ||
| assert calls == [("start", "policy"), ("step", "policy"), ("stop", "policy")] |
There was a problem hiding this comment.
Add a unit test to verify that calling stop() on a profiler that was never started is a safe no-op and does not raise any exceptions.
| def test_dispatch_to_policy_when_enabled(self): | |
| trainer, calls = self._trainer(enable=True) | |
| trainer._profiler_start() | |
| trainer._profiler_step() | |
| trainer._profiler_stop() | |
| assert calls == [("start", "policy"), ("step", "policy"), ("stop", "policy")] | |
| def test_dispatch_to_policy_when_enabled(self): | |
| trainer, calls = self._trainer(enable=True) | |
| trainer._profiler_start() | |
| trainer._profiler_step() | |
| trainer._profiler_stop() | |
| assert calls == [("start", "policy"), ("step", "policy"), ("stop", "policy")] | |
| def test_profiler_stop_without_start_is_noop(tmp_path): | |
| prof = Profiler(_ProfCfg(), default_save_path=str(tmp_path)) | |
| # Calling stop() before start() should not raise any error and should be a safe no-op. | |
| prof.stop() | |
| assert getattr(prof, "_running", False) is False |
What
SkyRL constructs a
Profilerobject on the Megatron policy worker but never drives it —.start()/.step()/.stop()are called nowhere in the repo, sotorch_profiler_configwas effectively dead code. This PR wirestorch.profilerup end to end for both Megatron and FSDP, both RL and SFT, with the fulltorch.profilersurface exposed as config (no hardcodedactive=1/ single step).After this, profiling reduces to setting a couple of config flags and reading the SkyRL-written traces — no worker subclass, no trainer overrides.
How it's driven
start_profile/profile_step/stop_profileRPCs on the sharedWorkerbase (worker.py) → on the Ray actor method table of bothPolicyWorkers automatically (same pattern asoptim_step,set_lr,save_memory_snapshot). No subclass, noray.remotere-wrap. Dispatched viapass_throughthin wrappers inWorkerDispatch.start_profilebefore, oneprofile_stepper global step,stop_profileafter (in afinally, so an open trace window is never leaked) — all gated ontorch_profiler_config.enableso non-profiling runs dispatch zero extra RPCs.sft_trainer.pytrain()/ dummy-train loop + oneprofile_stepintrain_step.trainer.pytrain()loop (and the async / fully-async trainers) + oneprofile_stepper global step, so a torchactivewindow spans the whole step (not a single minibatch).Config — full
torch.profilersurface, sane defaultsTorchProfilerConfig(hoisted toPolicyConfig, also wired through the SFT config bridge):skip_first, wait, warmup, active, repeat→torch.profiler.schedule. This is the "profile N steps, at an interval, repeating M times" knob (repeat=0= whole run).activities,record_shapes,profile_memory,with_stack,with_flops,with_modules,export_type.enable=false(default) = unchanged from before.tensorboard_trace_handleras HTA/Kineto-friendly*.pt.trace.json(one per active window per rank).save_pathdefaults to{ckpt_path}/profiler_traces.TorchProfilerConfig.validate()rejects unusable settings up front (called from both the RL and SFT entrypoint validators) so an enabled run fails fast instead of silently degrading.Scope
Profiles only the policy model's training step (forward/backward + optimizer). In an RL run it does not profile the critic or reference models, and it does not profile generation/inference — only policy training compute on the configured
ranks.Per-kernel summary data path
For downstream attribution tooling that wants per-kernel self-time without re-parsing the trace:
on_trace_readyalso stashes a pickle-safe per-kernel self-device-time summary for the just-closed window (exact — no cross-stream overlap double-counting).Profiler.get_kernel_summary()→Worker.dump_profiler_summary()RPC →WorkerDispatch.dump_profiler_summary(model)(returns a per-rank list). SkyRL's own trainers do not call this; the trace files remain the primary deliverable.Safety
All profiler paths are exception-isolated: a profiler fault disables profiling for the rest of the run rather than crashing it.
Cleanup
Removes the now-redundant Megatron-only
torch_profiler_config(it was dead; replaced by the backend-agnostic policy-level one).Testing
tests/backends/skyrl_train/utils/test_profiler.py): schedule-driven trace-file counts (single window,repeat,skip_firstdeferral), disabled/rank-not-selected no-ops,save_pathresolution, activities threading, exception isolation, the kernel-summary path, and the Worker / WorkerDispatch / trainer RPC plumbing.tests/train/test_config.pyandtests/train/test_sft_config.py:TorchProfilerConfig.validate()rejects bad configs on both the RL and SFT paths, andtorch_profiler_configbridges throughbuild_skyrl_config_for_sft.🤖 Generated with Claude Code