5757)
5858from ..eagle .conversion import EagleDMRegistry
5959from ..eagle .eagle_model import EagleModel
60- from ..eagle .utils import expand_mask , make_causal_mask , maybe_nvtx_range
60+ from ..eagle .utils import expand_mask , make_causal_mask
6161from ..medusa .conversion import MedusaDMRegistry
6262from ..medusa .medusa_model import MedusaModel
6363from ..utils import (
@@ -453,6 +453,23 @@ def _draft_model_config(self):
453453 """Return the llm config for the draft model."""
454454 return self .eagle_config
455455
def _enable_cp_ttt(self):
    """Pick the context manager to wrap an eagle forward step with.

    Returns the result of ``enable_cp_ttt_patch()`` only when the module is
    in training mode and ``eagle_mix_hidden_states`` is falsy; in every
    other case a no-op ``contextlib.nullcontext()`` is returned.
    """
    # Guard clause: anything other than "training without mixed hidden
    # states" gets the no-op context.
    if not self.training or self.eagle_mix_hidden_states:
        return contextlib.nullcontext()
    return enable_cp_ttt_patch()
460+
def _nvtx_range(self, name):
    """Return a context manager that optionally wraps its body in an NVTX range.

    Args:
        name: Label passed to ``torch.cuda.nvtx.range``.

    Returns:
        ``contextlib.nullcontext()`` when ``self.eagle_enable_nvtx`` is falsy.
        Otherwise a context manager that tries to open an NVTX range on
        entry and closes it on exit. NVTX is treated as best-effort: if the
        import or the range push fails, the failure is printed and the
        wrapped body still runs, just without the annotation.
    """
    if not self.eagle_enable_nvtx:
        return contextlib.nullcontext()

    @contextlib.contextmanager
    def _guarded_range():
        # ``nvtx.range(name)`` is lazy: the actual push happens when the
        # context is entered. Entering it here, inside the try, is what lets
        # a push failure (e.g. a PyTorch build without NVTX support) be
        # handled instead of escaping from the caller's ``with`` statement.
        with contextlib.ExitStack() as stack:
            try:
                import torch.cuda.nvtx as nvtx

                stack.enter_context(nvtx.range(name))
            except Exception as e:  # deliberate best-effort: profiling must not break the run
                print(f"Failed to create NVTX range { name } : { e } ")
            # Exceptions raised by the wrapped body propagate untouched; the
            # ExitStack closes the range (if one was opened) on the way out.
            yield

    return _guarded_range()
472+
456473 def get_exporter (self ) -> SpeculativeDecodingExporter :
457474 """Get the exporter for the draft model."""
458475 exporter_cls = (
@@ -682,7 +699,6 @@ def _prepare_decoder_attention_mask(
682699
683700 return combined_attention_mask
684701
685- @maybe_nvtx_range ("prepare_eagle_inputs" )
686702 def _prepare_eagle_inputs (
687703 self ,
688704 input_ids ,
@@ -785,7 +801,6 @@ def _compute_ttt_attention_mask(
785801 tensor_mask = tensor_mask .repeat (batch_size , 1 , 1 , 1 )
786802 return tensor_mask
787803
788- @maybe_nvtx_range ("base_model_forward" )
789804 def _base_model_forward (
790805 self ,
791806 input_ids ,
@@ -834,7 +849,6 @@ def _map_logits_to_draft_vocab(self, full_logits):
834849 )
835850 return full_logits [:, :, reverse_mapping ]
836851
837- @maybe_nvtx_range ("eagle_forward" )
838852 def _eagle_forward (
839853 self ,
840854 eagle_input_hidden_states ,
@@ -913,15 +927,16 @@ def forward(
913927 base_outputs .logits = self .lm_head (base_outputs .out_hiddens )
914928 past_key_values = None
915929 else :
916- base_outputs , past_key_values = self ._base_model_forward (
917- input_ids ,
918- attention_mask ,
919- position_ids ,
920- past_key_values ,
921- self .eagle_freeze_base_model ,
922- labels ,
923- ** kwargs ,
924- )
930+ with self ._nvtx_range ("base_model_forward" ):
931+ base_outputs , past_key_values = self ._base_model_forward (
932+ input_ids ,
933+ attention_mask ,
934+ position_ids ,
935+ past_key_values ,
936+ self .eagle_freeze_base_model ,
937+ labels ,
938+ ** kwargs ,
939+ )
925940
926941 if not isinstance (past_key_values , Cache ):
927942 past_key_values = _get_empty_cache (self ._base_llm_config )
@@ -935,20 +950,21 @@ def forward(
935950 num_ttt = self .eagle_ttt_steps
936951 train_accs = torch .zeros (num_parallel , num_ttt , device = input_ids .device )
937952 b , seq_length , _ = base_outputs .out_hiddens .shape
938- (
939- eagle_input_embeds ,
940- eagle_input_hiddens ,
941- eagle_attn_mask_0 ,
942- eagle_position_ids ,
943- base_output_predict_tok ,
944- base_output_softmax_logits ,
945- ) = self ._prepare_eagle_inputs (
946- input_ids ,
947- attention_mask ,
948- position_ids ,
949- eagle_cache ,
950- base_outputs ,
951- )
953+ with self ._nvtx_range ("prepare_eagle_inputs" ):
954+ (
955+ eagle_input_embeds ,
956+ eagle_input_hiddens ,
957+ eagle_attn_mask_0 ,
958+ eagle_position_ids ,
959+ base_output_predict_tok ,
960+ base_output_softmax_logits ,
961+ ) = self ._prepare_eagle_inputs (
962+ input_ids ,
963+ attention_mask ,
964+ position_ids ,
965+ eagle_cache ,
966+ base_outputs ,
967+ )
952968
953969 self .eagle_module ._maybe_init_rope ()
954970
@@ -960,11 +976,7 @@ def forward(
960976 if self .eagle_mix_hidden_states or ttt_step == 0
961977 else self ._get_ttt_attention_mask (b , seq_length , ttt_step )
962978 )
963- with (
964- enable_cp_ttt_patch ()
965- if self .training and not self .eagle_mix_hidden_states
966- else contextlib .nullcontext ()
967- ):
979+ with self ._enable_cp_ttt (), self ._nvtx_range ("eagle_forward" ):
968980 _ , eagle_output_hiddens , eagle_logits , eagle_cache = self ._eagle_forward (
969981 eagle_input_hiddens ,
970982 eagle_input_embeds ,
@@ -992,15 +1004,16 @@ def forward(
9921004
9931005 for i in range (self .eagle_config .parallel_draft_step ):
9941006 eagle_logit = eagle_logits [i ]
995- classification_loss , acc = self ._eagle_loss (
996- # base model predict +1 tok, while eagle predict +2
997- # so we shift base model outputs compared to eagle outputs
998- # additionally, we mask the first n tok of eagle outputs at nth TTT step
999- base_output_softmax_logits [:, 1 + i + ttt_step :],
1000- base_output_predict_tok [:, 1 + i + ttt_step :],
1001- eagle_logit [:, ttt_step : - (1 + i )],
1002- loss_mask [:, 1 + ttt_step :] if i == 0 else loss_mask [:, 1 + ttt_step : - i ],
1003- )
1007+ with self ._nvtx_range ("eagle_loss" ):
1008+ classification_loss , acc = self ._eagle_loss (
1009+ # base model predict +1 tok, while eagle predict +2
1010+ # so we shift base model outputs compared to eagle outputs
1011+ # additionally, we mask the first n tok of eagle outputs at nth TTT step
1012+ base_output_softmax_logits [:, 1 + i + ttt_step :],
1013+ base_output_predict_tok [:, 1 + i + ttt_step :],
1014+ eagle_logit [:, ttt_step : - (1 + i )],
1015+ loss_mask [:, 1 + ttt_step :] if i == 0 else loss_mask [:, 1 + ttt_step : - i ],
1016+ )
10041017 # Apply loss decay factor to focus on early steps
10051018 classification_loss *= self .eagle_loss_decay_factor ** (ttt_step + i )
10061019 eagle_loss = (
@@ -1028,7 +1041,6 @@ def forward(
10281041 train_acc = train_accs ,
10291042 )
10301043
1031- @maybe_nvtx_range ("eagle_loss" )
10321044 def _eagle_loss (
10331045 self ,
10341046 base_output_softmax_logits ,
@@ -1100,7 +1112,10 @@ def pseudo_speculative_generate(
11001112 )
11011113
11021114 # Use SDPA attention during generation for both stability and performance
1103- with temporary_set_config_value (self .eagle_config , "_attn_implementation" , "sdpa" ):
1115+ with (
1116+ temporary_set_config_value (self .eagle_config , "_attn_implementation" , "sdpa" ),
1117+ self ._nvtx_range ("eagle_forward" ),
1118+ ):
11041119 _ , eagle_prenorm_h , eagle_logits , _ = self ._eagle_forward (
11051120 eagle_input_hidden_states ,
11061121 self ._base_model_embeddings (eagle_ids ),
0 commit comments