feat(profiler): drive torch.profiler around the training loop #1750

Open
dyurk-lila wants to merge 1 commit into NovaSky-AI:main from dyurk-lila:upstream-profiler-driving

What

SkyRL constructs a Profiler object on the Megatron policy worker but never drives it: start()/step()/stop() are called nowhere in the repo, so torch_profiler_config was effectively dead code. This PR wires torch.profiler up end to end for both Megatron and FSDP, both RL and SFT, with the full torch.profiler surface exposed as config (no hardcoded active=1 / single step).

After this, profiling reduces to setting a couple of config flags and reading the SkyRL-written traces — no worker subclass, no trainer overrides.

How it's driven

  • start_profile / profile_step / stop_profile RPCs on the shared Worker base (worker.py) → on the Ray actor method table of both PolicyWorkers automatically (same pattern as optim_step, set_lr, save_memory_snapshot). No subclass, no ray.remote re-wrap. Dispatched via pass_through thin wrappers in WorkerDispatch.
  • Trainers bracket the loop: start_profile before, one profile_step per global step, stop_profile after (in a finally, so an open trace window is never leaked) — all gated on torch_profiler_config.enable so non-profiling runs dispatch zero extra RPCs.
    • SFT: sft_trainer.py train() / dummy-train loop + one profile_step in train_step.
    • RL: trainer.py train() loop (and the async / fully-async trainers) + one profile_step per global step, so a torch active window spans the whole step (not a single minibatch).
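
The bracketing described above can be sketched as follows. This is a minimal illustration, not the PR's trainer code: `FakeDispatch` is a hypothetical stand-in for WorkerDispatch that records calls instead of issuing Ray RPCs, and the single `model` argument on each method is an assumption.

```python
from dataclasses import dataclass, field


@dataclass
class FakeDispatch:
    """Hypothetical stand-in for WorkerDispatch; records calls, sends no RPCs."""

    calls: list = field(default_factory=list)

    def start_profile(self, model: str) -> None:
        self.calls.append(("start", model))

    def profile_step(self, model: str) -> None:
        self.calls.append(("step", model))

    def stop_profile(self, model: str) -> None:
        self.calls.append(("stop", model))


def train(dispatch, num_steps, profiler_enabled, train_step=lambda step: None):
    """Bracket the loop; dispatch nothing at all when profiling is disabled."""
    if profiler_enabled:
        dispatch.start_profile("policy")
    try:
        for step in range(num_steps):
            train_step(step)
            if profiler_enabled:
                # One profiler step per global step, not per minibatch.
                dispatch.profile_step("policy")
    finally:
        if profiler_enabled:
            # Runs even if train_step raises, so the trace window is closed.
            dispatch.stop_profile("policy")
```

Putting the stop call in `finally` is what keeps a crashed step from leaving a trace window open; gating every call on the flag is what keeps non-profiling runs free of extra dispatches.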

Config — full torch.profiler surface, sane defaults

TorchProfilerConfig (hoisted to PolicyConfig, also wired through the SFT config bridge):

  • Schedule: skip_first, wait, warmup, active, repeat → torch.profiler.schedule. This is the "profile N steps, at an interval, repeating M times" knob (repeat=0 = whole run).
  • Capture: activities, record_shapes, profile_memory, with_stack, with_flops, with_modules, export_type.
  • Defaults reproduce the prior effective behavior (CPU+CUDA, shapes+stack). enable=false (default) = unchanged from before.
  • Traces written by tensorboard_trace_handler as HTA/Kineto-friendly *.pt.trace.json (one per active window per rank). save_path defaults to {ckpt_path}/profiler_traces.
  • TorchProfilerConfig.validate() rejects unusable settings up front (called from both the RL and SFT entrypoint validators) so an enabled run fails fast instead of silently degrading.
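
To reason about how many trace files a given schedule should produce, a pure-Python model of the schedule phases can help. The sketch below reflects my understanding of how torch.profiler.schedule cycles through wait, warmup, and active steps; it is not torch's implementation, and the names `schedule_action` and `count_windows` are hypothetical.

```python
def schedule_action(step, *, skip_first=0, wait=0, warmup=0, active=1, repeat=0):
    """Return the phase for a 0-based step: none, warmup, record, or record_and_save."""
    if step < skip_first:
        return "none"
    step -= skip_first
    cycle = wait + warmup + active
    # repeat=0 means the cycle keeps repeating for the whole run.
    if repeat > 0 and step // cycle >= repeat:
        return "none"
    pos = step % cycle
    if pos < wait:
        return "none"
    if pos < wait + warmup:
        return "warmup"
    # The last active step of a cycle closes the window and triggers the trace handler.
    return "record_and_save" if pos == cycle - 1 else "record"


def count_windows(total_steps, **schedule_kwargs):
    """Count closed active windows, i.e. expected trace files per profiled rank."""
    return sum(
        schedule_action(step, **schedule_kwargs) == "record_and_save"
        for step in range(total_steps)
    )
```

For example, `count_windows(20, skip_first=2, wait=1, warmup=1, active=2, repeat=2)` gives 2, while the same schedule with `repeat=0` gives 4 over those 20 steps.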

Scope

Profiles only the policy model's training step (forward/backward + optimizer). In an RL run it does not profile the critic or reference models, and it does not profile generation/inference — only policy training compute on the configured ranks.

Per-kernel summary data path

For downstream attribution tooling that wants per-kernel self-time without re-parsing the trace:

  • on_trace_ready also stashes a pickle-safe per-kernel self-device-time summary for the just-closed window (exact — no cross-stream overlap double-counting).
  • Exposed via Profiler.get_kernel_summary() → Worker.dump_profiler_summary() RPC → WorkerDispatch.dump_profiler_summary(model) (returns a per-rank list). SkyRL's own trainers do not call this; the trace files remain the primary deliverable.
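
A downstream consumer might fold the per-rank list into one ranking, roughly as below. The shape of each rank's summary (a list of `(kernel_name, self_time_us)` pairs, or `None` for a rank that was not profiled) is an assumption based on the pairs the profiler stashes, and `merge_rank_summaries` is a hypothetical helper that is not part of the PR.

```python
from collections import defaultdict


def merge_rank_summaries(per_rank):
    """Sum self device time per kernel across ranks, largest first."""
    totals = defaultdict(float)
    for rank_pairs in per_rank:
        # Assumed: a rank without a summary reports None or an empty list.
        for name, self_time_us in rank_pairs or []:
            totals[name] += self_time_us
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)
```

Plain tuples of `str` and `float` survive pickling across the actor boundary, which is presumably why the summary is stored in that form rather than as profiler event objects.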

Safety

All profiler paths are exception-isolated: a profiler fault disables profiling for the rest of the run rather than crashing it.
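
The isolation pattern might look like the sketch below. `IsolatedProfiler` is a hypothetical wrapper written for illustration; only the `_disable(where, exc)` call shape is taken from the diff under review, and its body here is assumed.

```python
import logging

logger = logging.getLogger("profiler_sketch")


class IsolatedProfiler:
    """Wraps a profiler-like object so its failures never reach the trainer."""

    def __init__(self, prof):
        self.prof = prof
        self.enabled = prof is not None

    def _disable(self, where, exc):
        # Log the failure and turn every later call into a no-op.
        logger.warning("[Profiler] disabled after failure in %s: %r", where, exc)
        self.enabled = False

    def step(self):
        if not self.enabled:
            return
        try:
            self.prof.step()
        except Exception as exc:
            self._disable("step", exc)
```

After the first failure the wrapped profiler is never touched again, so a broken profiler costs at most one logged warning instead of a training crash.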

Cleanup

Removes the now-redundant Megatron-only torch_profiler_config (it was dead; replaced by the backend-agnostic policy-level one).

Testing

  • New CPU unit tests (tests/backends/skyrl_train/utils/test_profiler.py): schedule-driven trace-file counts (single window, repeat, skip_first deferral), disabled/rank-not-selected no-ops, save_path resolution, activities threading, exception isolation, the kernel-summary path, and the Worker / WorkerDispatch / trainer RPC plumbing.
  • tests/train/test_config.py and tests/train/test_sft_config.py: TorchProfilerConfig.validate() rejects bad configs on both the RL and SFT paths, and torch_profiler_config bridges through build_skyrl_config_for_sft.
  • Locally: ruff/black clean on all touched files; the new profiler tests pass on CPU. GPU/e2e paths left to CI.

🤖 Generated with Claude Code

SkyRL constructed a Profiler object on the Megatron policy worker but never
drove it -- start()/step()/stop() were called nowhere, so torch_profiler_config
was dead code. This wires torch.profiler end to end for both Megatron and FSDP,
both RL and SFT, with the full torch.profiler surface exposed as config.

- start_profile / profile_step / stop_profile RPCs on the shared Worker base,
  dispatched via pass_through thin wrappers in WorkerDispatch.
- Trainers bracket the loop: start before, one profile_step per global step,
  stop after -- all gated on torch_profiler_config.enable so non-profiling runs
  dispatch zero extra RPCs.
- TorchProfilerConfig hoisted to PolicyConfig (also wired through the SFT config
  bridge), exposing schedule (skip_first/wait/warmup/active/repeat) and capture
  (activities/record_shapes/profile_memory/with_stack/with_flops/with_modules/
  export_type) knobs. Defaults reproduce prior effective behavior; enable=false
  by default. Traces written by tensorboard_trace_handler as HTA/Kineto-friendly
  *.pt.trace.json under {ckpt_path}/profiler_traces by default.
- Profiles only the policy model's training step (fwd/bwd + optimizer), not the
  critic/ref models and not generation/inference.
- All profiler paths are exception-isolated: a fault disables profiling for the
  rest of the run rather than crashing it.
- Removes the redundant Megatron-only torch_profiler_config.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a unified, backend-agnostic TorchProfilerConfig and Profiler class to replace the previous Megatron-specific profiler configuration, enabling profiling across both FSDP and Megatron backends. The profiler is now driven by the trainer via start, step, and stop RPCs dispatched to workers, supported by comprehensive validation and unit tests. The review feedback recommends several robustness enhancements: tracking the profiler's running state to prevent PyTorch RuntimeErrors during stop operations, using getattr with fallbacks for backward compatibility with older PyTorch versions, adding null-checks to configuration validation to prevent TypeErrors from YAML inputs, and adding a unit test to verify safe stop behavior.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +58 to 61
        self._last_pairs: list = []
        self._window_count: int = 0
        if not config.enable:
            return

medium

Initialize a _running state variable to track whether the profiler has been successfully started. This ensures we can safely guard calls to stop() and prevent PyTorch from raising a RuntimeError if stop() is called on a non-running profiler.

Suggested change
         self._last_pairs: list = []
         self._window_count: int = 0
+        self._running = False
         if not config.enable:
             return

        try:
            # ``self_device_time_total`` is torch 2.11's field (the older
            # ``self_cuda_time_total`` was removed). Microseconds, self time.
            self._last_pairs = [(str(e.key), float(e.self_device_time_total)) for e in prof.key_averages()]

medium

Using e.self_device_time_total directly can raise an AttributeError on older PyTorch versions (where it was named self_cuda_time_total or self_cpu_time_total). Using getattr with fallbacks makes the kernel summary extraction robust across different PyTorch versions and profiling activities.

            self._last_pairs = [
                (
                    str(e.key),
                    float(
                        getattr(
                            e,
                            "self_device_time_total",
                            getattr(e, "self_cuda_time_total", getattr(e, "self_cpu_time_total", 0.0)),
                        )
                    ),
                )
                for e in prof.key_averages()
            ]
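
A runnable variant of the same fallback idea is sketched below, using stand-in event objects rather than real profiler events. `self_time_us` and `summarize` are hypothetical helper names. Unlike the nested `getattr` above, this version also skips an attribute that exists but is `None`.

```python
from types import SimpleNamespace

# Field names tried in order, newest first.
_SELF_TIME_FIELDS = (
    "self_device_time_total",
    "self_cuda_time_total",
    "self_cpu_time_total",
)


def self_time_us(event):
    """Return the first available self-time field as a float, else 0.0."""
    for name in _SELF_TIME_FIELDS:
        value = getattr(event, name, None)
        if value is not None:
            return float(value)
    return 0.0


def summarize(events):
    """Build plain (name, self_time) pairs from profiler-like events."""
    return [(str(event.key), self_time_us(event)) for event in events]


# Stand-in events; real ones would come from prof.key_averages().
events = [
    SimpleNamespace(key="gemm", self_device_time_total=12),
    SimpleNamespace(key="legacy", self_cuda_time_total=3.5),
    SimpleNamespace(key="untimed"),
]
pairs = summarize(events)
```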

Comment on lines +155 to +161
+    def start(self) -> None:
+        if self.check():
-            logger.info(f"[Profiler] stopped for rank {self.rank}")
-            self.prof.stop()
-
-    def save(self):
-        if self.prof is not None and not self.saved:
-            if not os.path.exists(self.save_path):
-                os.makedirs(self.save_path)
-            save_file_name = f"/prof_rank_{self.rank}.json"
-            logger.info(f"[Profiler] Saving trace to {self.save_path + save_file_name}")
-            self.prof.export_chrome_trace(self.save_path + save_file_name)
-            self.enable = False
-            self.saved = True
+            try:
+                logger.info(f"[Profiler] started for rank {self.rank}")
+                self.prof.start()
+            except Exception as e:
+                self._disable("start", e)

medium

Set self._running = True upon successful start of the profiler.

Suggested change
     def start(self) -> None:
         if self.check():
             try:
                 logger.info(f"[Profiler] started for rank {self.rank}")
                 self.prof.start()
+                self._running = True
             except Exception as e:
                 self._disable("start", e)

Comment on lines +170 to +176
+    def stop(self) -> None:
+        if self.check():
-            logger.info(f"[Profiler] Trace stopped for rank {self.rank}")
-            self.enable = False
+            try:
+                logger.info(f"[Profiler] stopped for rank {self.rank}")
+                self.prof.stop()
+            except Exception as e:
+                self._disable("stop", e)

medium

Only call self.prof.stop() if the profiler is currently running (self._running is True), and reset the running state in a finally block. This prevents PyTorch from raising a RuntimeError if stop() is called on a non-running profiler (e.g., if start() failed or was never called).

Suggested change
     def stop(self) -> None:
-        if self.check():
+        if self.check() and getattr(self, "_running", False):
             try:
                 logger.info(f"[Profiler] stopped for rank {self.rank}")
                 self.prof.stop()
             except Exception as e:
                 self._disable("stop", e)
+            finally:
+                self._running = False

Comment on lines +191 to +196
        bad_activities = [a for a in self.activities if a.lower() not in TORCH_PROFILER_ACTIVITIES]
        if bad_activities:
            raise ValueError(
                f"invalid `torch_profiler_config.activities` entries {bad_activities}. "
                f"Each must be one of {list(TORCH_PROFILER_ACTIVITIES)}."
            )

medium

Add a check to ensure self.activities is not None before iterating over it. If a user sets activities: null in their YAML configuration, this prevents a TypeError during validation.

        if self.activities is None:
            raise ValueError("`torch_profiler_config.activities` cannot be None.")
        bad_activities = [a for a in self.activities if a.lower() not in TORCH_PROFILER_ACTIVITIES]
        if bad_activities:
            raise ValueError(
                f"invalid `torch_profiler_config.activities` entries {bad_activities}. "
                f"Each must be one of {list(TORCH_PROFILER_ACTIVITIES)}."
            )

Comment on lines +207 to +212
        for name in ("skip_first", "wait", "warmup", "repeat"):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"`torch_profiler_config.{name}` must be >= 0, got {value}.")
        if self.active < 1:
            raise ValueError(f"`torch_profiler_config.active` must be >= 1, got {self.active}.")

medium

Add checks to ensure that the numeric schedule fields are not None before performing comparison operations. This prevents a TypeError if any of these fields are set to null in the YAML configuration.

Suggested change
         for name in ("skip_first", "wait", "warmup", "repeat"):
             value = getattr(self, name)
+            if value is None:
+                raise ValueError(f"`torch_profiler_config.{name}` cannot be None.")
             if value < 0:
                 raise ValueError(f"`torch_profiler_config.{name}` must be >= 0, got {value}.")
+        if self.active is None:
+            raise ValueError("`torch_profiler_config.active` cannot be None.")
         if self.active < 1:
             raise ValueError(f"`torch_profiler_config.active` must be >= 1, got {self.active}.")
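
Taken together, the two validation suggestions could be exercised with a self-contained stand-in like the one below. `ProfilerConfigSketch` is hypothetical and is not the PR's TorchProfilerConfig; the allowed activity names and the default values are assumptions chosen only to make the sketch runnable.

```python
from dataclasses import dataclass, field

# Assumed set of valid activity names; the real constant may differ.
ALLOWED_ACTIVITIES = ("cpu", "cuda")


@dataclass
class ProfilerConfigSketch:
    enable: bool = False
    activities: list = field(default_factory=lambda: ["cpu", "cuda"])
    skip_first: int = 0
    wait: int = 0
    warmup: int = 0
    active: int = 1
    repeat: int = 0

    def validate(self) -> None:
        """Reject unusable settings, including explicit nulls from YAML."""
        if self.activities is None:
            raise ValueError("`activities` cannot be None.")
        bad = [a for a in self.activities if str(a).lower() not in ALLOWED_ACTIVITIES]
        if bad:
            raise ValueError(f"invalid `activities` entries {bad}.")
        for name in ("skip_first", "wait", "warmup", "repeat"):
            value = getattr(self, name)
            if value is None:
                raise ValueError(f"`{name}` cannot be None.")
            if value < 0:
                raise ValueError(f"`{name}` must be >= 0, got {value}.")
        if self.active is None:
            raise ValueError("`active` cannot be None.")
        if self.active < 1:
            raise ValueError(f"`active` must be >= 1, got {self.active}.")
```

Checking for `None` before comparing turns a confusing `TypeError` from `None < 0` into a `ValueError` that names the offending field.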

Comment on lines +389 to +394
    def test_dispatch_to_policy_when_enabled(self):
        trainer, calls = self._trainer(enable=True)
        trainer._profiler_start()
        trainer._profiler_step()
        trainer._profiler_stop()
        assert calls == [("start", "policy"), ("step", "policy"), ("stop", "policy")]

medium

Add a unit test to verify that calling stop() on a profiler that was never started is a safe no-op and does not raise any exceptions.

Suggested change
     def test_dispatch_to_policy_when_enabled(self):
         trainer, calls = self._trainer(enable=True)
         trainer._profiler_start()
         trainer._profiler_step()
         trainer._profiler_stop()
         assert calls == [("start", "policy"), ("step", "policy"), ("stop", "policy")]
+def test_profiler_stop_without_start_is_noop(tmp_path):
+    prof = Profiler(_ProfCfg(), default_save_path=str(tmp_path))
+    # Calling stop() before start() should not raise any error and should be a safe no-op.
+    prof.stop()
+    assert getattr(prof, "_running", False) is False

@SumanthRH SumanthRH self-assigned this Jun 4, 2026