@@ -120,7 +120,6 @@ def setup(self, distributed: Distributed) -> None:
120
120
self ._is_setup = True
121
121
122
122
123
- # perform preprocessing for sequence parallel
124
123
def preprocess_meta (
125
124
self , batch_meta : BatchConfig | torch .Tensor , phase : PhaseType
126
125
) -> list [tuple [TensorMeta , dict ]]:
@@ -166,15 +165,13 @@ def preprocess_meta(
166
165
else sequence_q_dim
167
166
)
168
167
169
- # determins if batch dim or sequence dim is first
170
168
need_sequence_first = hidden_sequence_q_dim .size != sequence_length
171
169
if self ._config .sequence_first is None :
172
170
sequence_first = need_sequence_first
173
171
else :
174
172
sequence_first = self ._config .sequence_first
175
173
assert not (need_sequence_first and not sequence_first )
176
174
177
- # hidden dim is model hidden size
178
175
hidden_dim = self ._tensor_space .get_tensor_dim (TransformerDimNames .hidden )
179
176
hidden_dims = (
180
177
(hidden_sequence_q_dim , batch_dim , hidden_dim )
@@ -199,7 +196,6 @@ def preprocess_meta(
199
196
sequence_k = sequence_k_past + sequence_q_dim .size
200
197
sequence_k_dim = TensorDim (TransformerDimNames .sequence_k , sequence_k )
201
198
202
- # sequence_k_past is start and sequence_k is end of sequence
203
199
tokens = TensorMeta .from_dims (
204
200
hidden_dims [:2 ], tensor_name = f"tokens_{ sequence_k_past } _to_{ sequence_k - 1 } " , dtype = torch .int64
205
201
)
@@ -294,7 +290,7 @@ def preprocess(
294
290
for i , spans in enumerate (batch .loss_masking_spans ):
295
291
if not spans .numel ():
296
292
continue
297
- # filter spans within the sequence or partially within the sequence
293
+ # only keep spans within the sequence or partially within the sequence
298
294
valid_spans = spans [(spans [:, 0 ] <= sequence_k ) & (spans [:, 1 ] >= sequence_offset )]
299
295
if valid_spans .numel ():
300
296
# if span is partially within the sequence, truncate parts of spans that are outside of the sequence
@@ -310,7 +306,7 @@ def preprocess(
310
306
for i , spans in enumerate (batch .chosen_loss_masking_spans ):
311
307
if not spans .numel ():
312
308
continue
313
- # filter spans within the sequence or partially within the sequence
309
+ # only keep spans within the sequence or partially within the sequence
314
310
valid_spans = spans [(spans [0 ] <= sequence_k ) & (spans [1 ] >= sequence_offset )]
315
311
if valid_spans .numel ():
316
312
# if span is partially within the sequence, truncate parts of spans that are outside of the sequence
@@ -322,7 +318,7 @@ def preprocess(
322
318
for i , spans in enumerate (batch .rejected_loss_masking_spans ):
323
319
if not spans .numel ():
324
320
continue
325
- # filter spans within the sequence or partially within the sequence
321
+ # only keep spans within the sequence or partially within the sequence
326
322
valid_spans = spans [(spans [0 ] <= sequence_k ) & (spans [1 ] >= sequence_offset )]
327
323
if valid_spans .numel ():
328
324
# if span is partially within the sequence, truncate parts of spans that are outside of the sequence
0 commit comments