@@ -50,7 +50,9 @@ def __init__(
         self._group_size = tensor_space.distributed_config.tensor_parallel
         self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel
         self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings
-        self._sequence_parallel_logits = self._sequence_parallel and not self._parallel_embeddings
+        self._sequence_parallel_logits = (
+            tensor_space.distributed_config.sequence_tensor_parallel and not config.parallel_embeddings
+        )
         self._cross_entropy_splits = config.cross_entropy_splits
         if self._cross_entropy_splits is not None and self._sequence_parallel:
             assert not self._parallel_embeddings
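The new expression reads the flags straight from the config instead of the derived `self._parallel_embeddings`, so sequence-parallel logits are enabled exactly when sequence tensor-parallelism is on and embeddings are not parallelized. A minimal standalone sketch of that predicate (plain booleans standing in for the distributed config; names here are illustrative, not the module's API):

```python
def sequence_parallel_logits(sequence_tensor_parallel: bool, parallel_embeddings: bool) -> bool:
    # Mirrors the added line: logits are computed sequence-parallel whenever
    # sequence tensor-parallelism is on and embeddings are *not* tensor-parallel.
    return sequence_tensor_parallel and not parallel_embeddings


assert sequence_parallel_logits(True, False) is True
assert sequence_parallel_logits(True, True) is False
assert sequence_parallel_logits(False, False) is False
```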
@@ -67,7 +69,7 @@ def __init__(
         # >0: multi-token prediction (MTP)
         Assert.geq(prediction_distance, 0)
         self._prediction_distance = prediction_distance
-        self.is_last_head = self._prediction_distance == config.prediction_heads - 1
+        self._is_last_head = self._prediction_distance == config.prediction_heads - 1

         self._init_output_weights(hidden_dim, config)

@@ -114,7 +116,7 @@ def forward(
                 tensor_name="Loss",
                 reductions=((DistributedDimNames.data, ReduceOp.AVG),),  # noqa
             )
-        if not self.is_last_head:
+        if not self._is_last_head:
             # MTP: split the stacked input
             shared_hidden, input_ = torch.unbind(input_, dim=0)
         # TODO: Pytorch copies the grads in backward for no reason (not sure if still the case)
@@ -123,10 +125,10 @@ def forward(
         # TODO: Drop autograd entirely.
         # TODO: Skip cross-entropy backward if not needed.
         language_model_loss = self._forward(input_, kwargs, losses)
-        if language_model_loss is not None:
+        if losses is not None and language_model_loss is not None:
             losses[self._loss_name].append(language_model_loss)
         # TODO: Return the model output when needed.
-        if self.is_last_head:
+        if self._is_last_head:
             # Last head should return the loss for backward.
             return language_model_loss
         else:
@@ -147,14 +149,13 @@ def _forward_backward(
         if target is not None:
             if self._config.distillation_model is None:
                 # MTP: Shift the labels
-                target = (
-                    target[self._prediction_distance : self._prediction_distance + input_.size(0),]
-                    if kwargs[TransformerKwargs.sequence_first]
-                    else target[
-                        :,
-                        self._prediction_distance : self._prediction_distance + input_.size(1),
-                    ]
+                target_sequence_length = (
+                    target.size(1 - kwargs[TransformerKwargs.sequence_first]) + 1 - self._config.prediction_heads
                 )
+                if TransformerKwargs.sequence_q_dim in kwargs:
+                    Assert.eq(target_sequence_length, kwargs[TransformerKwargs.sequence_q_dim].size)
+                target_slice = slice(self._prediction_distance, self._prediction_distance + target_sequence_length)
+                target = target[target_slice] if kwargs[TransformerKwargs.sequence_first] else target[:, target_slice]
                 target = target.flatten()
             else:
                 # Target is reference model logits.
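For the reworked label shift, the labels tensor is assumed to carry `prediction_heads - 1` extra future tokens beyond the query sequence, so the head at `prediction_distance = k` takes the window `[k, k + sequence_length)` along the sequence dimension. A minimal standalone sketch of that arithmetic (shapes and names here are illustrative assumptions, not the module's API):

```python
import torch

# Assumption: labels hold (sequence_length + prediction_heads - 1) tokens per sample.
prediction_heads = 3
sequence_length = 5
batch_size = 2
sequence_first = False  # labels shaped (batch, sequence) in this example

labels = torch.arange(batch_size * (sequence_length + prediction_heads - 1)).reshape(
    batch_size, sequence_length + prediction_heads - 1
)

for prediction_distance in range(prediction_heads):
    # Same arithmetic as the diff: infer the query length from the label length
    # (dim 0 if sequence-first, else dim 1), then take a window of that length
    # starting at this head's offset.
    target_sequence_length = labels.size(1 - sequence_first) + 1 - prediction_heads
    assert target_sequence_length == sequence_length
    target_slice = slice(prediction_distance, prediction_distance + target_sequence_length)
    target = labels[target_slice] if sequence_first else labels[:, target_slice]
    # Each head's window is shifted prediction_distance positions to the right,
    # matching the extra look-ahead of that head.
    print(prediction_distance, tuple(target.shape), target[0].tolist())
```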