Skip to content

Commit 7bc055e

Browse files
committed
fix test: disable nvtx by default
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 89238c9 commit 7bc055e

File tree

4 files changed: +64 −59 lines

modelopt/torch/speculative/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,8 @@ class EagleConfig(ModeloptBaseConfig):
115115
default=True,
116116
description="Whether to use torch.compile on eagle forward/loss methods for faster training.",
117117
)
118+
119+
eagle_enable_nvtx: bool = ModeloptField(
120+
default=False,
121+
description="Whether to enable NVTX ranges for profiling eagle forward/loss methods.",
122+
)

modelopt/torch/speculative/eagle/eagle_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@ def modify(
4040
self.eagle_ttt_steps = config.eagle_ttt_steps
4141
self.eagle_mix_hidden_states = config.eagle_mix_hidden_states
4242
self.eagle_use_torch_compile = config.eagle_use_torch_compile
43+
self.eagle_enable_nvtx = config.eagle_enable_nvtx

modelopt/torch/speculative/eagle/utils.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@
3535

3636
"""Eagle model utils."""
3737

38-
from contextlib import nullcontext
39-
4038
import torch
4139

4240

@@ -72,17 +70,3 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = No
7270
inverted_mask = 1.0 - expanded_mask
7371

7472
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
75-
76-
77-
def maybe_nvtx_range(*args, **kwargs):
78-
"""Helper function to create NVTX ranges if NVTX is available."""
79-
try:
80-
import torch.cuda.nvtx as nvtx
81-
82-
nvtx.range_push("nvtx init")
83-
nvtx.range_pop()
84-
85-
return nvtx.range(*args, **kwargs)
86-
except Exception:
87-
# If NVTX is not available, return a no-op context manager
88-
return nullcontext()

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 58 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
)
5858
from ..eagle.conversion import EagleDMRegistry
5959
from ..eagle.eagle_model import EagleModel
60-
from ..eagle.utils import expand_mask, make_causal_mask, maybe_nvtx_range
60+
from ..eagle.utils import expand_mask, make_causal_mask
6161
from ..medusa.conversion import MedusaDMRegistry
6262
from ..medusa.medusa_model import MedusaModel
6363
from ..utils import (
@@ -453,6 +453,23 @@ def _draft_model_config(self):
453453
"""Return the llm config for the draft model."""
454454
return self.eagle_config
455455

456+
def _enable_cp_ttt(self):
457+
if self.training and not self.eagle_mix_hidden_states:
458+
return enable_cp_ttt_patch()
459+
return contextlib.nullcontext()
460+
461+
def _nvtx_range(self, name):
462+
"""Optionally create an NVTX range for the given name when config.eagle_enable_nvtx is set."""
463+
if not self.eagle_enable_nvtx:
464+
return contextlib.nullcontext()
465+
try:
466+
import torch.cuda.nvtx as nvtx
467+
468+
return nvtx.range(name)
469+
except Exception as e:
470+
print(f"Failed to create NVTX range {name}: {e}")
471+
return contextlib.nullcontext()
472+
456473
def get_exporter(self) -> SpeculativeDecodingExporter:
457474
"""Get the exporter for the draft model."""
458475
exporter_cls = (
@@ -682,7 +699,6 @@ def _prepare_decoder_attention_mask(
682699

683700
return combined_attention_mask
684701

685-
@maybe_nvtx_range("prepare_eagle_inputs")
686702
def _prepare_eagle_inputs(
687703
self,
688704
input_ids,
@@ -785,7 +801,6 @@ def _compute_ttt_attention_mask(
785801
tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1)
786802
return tensor_mask
787803

788-
@maybe_nvtx_range("base_model_forward")
789804
def _base_model_forward(
790805
self,
791806
input_ids,
@@ -834,7 +849,6 @@ def _map_logits_to_draft_vocab(self, full_logits):
834849
)
835850
return full_logits[:, :, reverse_mapping]
836851

837-
@maybe_nvtx_range("eagle_forward")
838852
def _eagle_forward(
839853
self,
840854
eagle_input_hidden_states,
@@ -913,15 +927,16 @@ def forward(
913927
base_outputs.logits = self.lm_head(base_outputs.out_hiddens)
914928
past_key_values = None
915929
else:
916-
base_outputs, past_key_values = self._base_model_forward(
917-
input_ids,
918-
attention_mask,
919-
position_ids,
920-
past_key_values,
921-
self.eagle_freeze_base_model,
922-
labels,
923-
**kwargs,
924-
)
930+
with self._nvtx_range("base_model_forward"):
931+
base_outputs, past_key_values = self._base_model_forward(
932+
input_ids,
933+
attention_mask,
934+
position_ids,
935+
past_key_values,
936+
self.eagle_freeze_base_model,
937+
labels,
938+
**kwargs,
939+
)
925940

926941
if not isinstance(past_key_values, Cache):
927942
past_key_values = _get_empty_cache(self._base_llm_config)
@@ -935,20 +950,21 @@ def forward(
935950
num_ttt = self.eagle_ttt_steps
936951
train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device)
937952
b, seq_length, _ = base_outputs.out_hiddens.shape
938-
(
939-
eagle_input_embeds,
940-
eagle_input_hiddens,
941-
eagle_attn_mask_0,
942-
eagle_position_ids,
943-
base_output_predict_tok,
944-
base_output_softmax_logits,
945-
) = self._prepare_eagle_inputs(
946-
input_ids,
947-
attention_mask,
948-
position_ids,
949-
eagle_cache,
950-
base_outputs,
951-
)
953+
with self._nvtx_range("prepare_eagle_inputs"):
954+
(
955+
eagle_input_embeds,
956+
eagle_input_hiddens,
957+
eagle_attn_mask_0,
958+
eagle_position_ids,
959+
base_output_predict_tok,
960+
base_output_softmax_logits,
961+
) = self._prepare_eagle_inputs(
962+
input_ids,
963+
attention_mask,
964+
position_ids,
965+
eagle_cache,
966+
base_outputs,
967+
)
952968

953969
self.eagle_module._maybe_init_rope()
954970

@@ -960,11 +976,7 @@ def forward(
960976
if self.eagle_mix_hidden_states or ttt_step == 0
961977
else self._get_ttt_attention_mask(b, seq_length, ttt_step)
962978
)
963-
with (
964-
enable_cp_ttt_patch()
965-
if self.training and not self.eagle_mix_hidden_states
966-
else contextlib.nullcontext()
967-
):
979+
with self._enable_cp_ttt(), self._nvtx_range("eagle_forward"):
968980
_, eagle_output_hiddens, eagle_logits, eagle_cache = self._eagle_forward(
969981
eagle_input_hiddens,
970982
eagle_input_embeds,
@@ -992,15 +1004,16 @@ def forward(
9921004

9931005
for i in range(self.eagle_config.parallel_draft_step):
9941006
eagle_logit = eagle_logits[i]
995-
classification_loss, acc = self._eagle_loss(
996-
# base model predict +1 tok, while eagle predict +2
997-
# so we shift base model outputs compared to eagle outputs
998-
# additionally, we mask the first n tok of eagle outputs at nth TTT step
999-
base_output_softmax_logits[:, 1 + i + ttt_step :],
1000-
base_output_predict_tok[:, 1 + i + ttt_step :],
1001-
eagle_logit[:, ttt_step : -(1 + i)],
1002-
loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i],
1003-
)
1007+
with self._nvtx_range("eagle_loss"):
1008+
classification_loss, acc = self._eagle_loss(
1009+
# base model predict +1 tok, while eagle predict +2
1010+
# so we shift base model outputs compared to eagle outputs
1011+
# additionally, we mask the first n tok of eagle outputs at nth TTT step
1012+
base_output_softmax_logits[:, 1 + i + ttt_step :],
1013+
base_output_predict_tok[:, 1 + i + ttt_step :],
1014+
eagle_logit[:, ttt_step : -(1 + i)],
1015+
loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i],
1016+
)
10041017
# Apply loss decay factor to focus on early steps
10051018
classification_loss *= self.eagle_loss_decay_factor ** (ttt_step + i)
10061019
eagle_loss = (
@@ -1028,7 +1041,6 @@ def forward(
10281041
train_acc=train_accs,
10291042
)
10301043

1031-
@maybe_nvtx_range("eagle_loss")
10321044
def _eagle_loss(
10331045
self,
10341046
base_output_softmax_logits,
@@ -1100,7 +1112,10 @@ def pseudo_speculative_generate(
11001112
)
11011113

11021114
# Use SDPA attention during generation for both stability and performance
1103-
with temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"):
1115+
with (
1116+
temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"),
1117+
self._nvtx_range("eagle_forward"),
1118+
):
11041119
_, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward(
11051120
eagle_input_hidden_states,
11061121
self._base_model_embeddings(eagle_ids),

0 commit comments

Comments (0)