
Commit 5dfd676

disable no packing for legacy sampling and code cleaning
1 parent 3c0199f commit 5dfd676

5 files changed (+12, -22 lines changed)

fast_llm/data/dataset/gpt/memmap.py (+1, -6)

@@ -270,13 +270,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP
             spans = np.vstack(spans, dtype=np.int32)
         else:
             spans = np.array(spans, dtype=np.int32)
-        # if len(chosen_spans) > 0:
-        #     chosen_spans = np.vstack(chosen_spans, dtype=np.int32)
-        # else:
+
         chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2)
-        # if len(rejected_spans) > 0:
-        #     rejected_spans = np.vstack(rejected_spans, dtype=np.int32)
-        # else:
         rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2)

         # Write the index file (.idx)
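
The reshape above is what makes the commented-out length checks unnecessary: `np.array(..., dtype=np.int32).reshape(-1, 2)` yields a well-formed `(N, 2)` span array for empty and non-empty inputs alike, so the `if len(...) > 0` / `vstack` branches can go. A minimal standalone sketch (the span values are invented, not taken from the repository):

```python
import numpy as np

# With no spans collected, reshape(-1, 2) still gives a valid (0, 2) array.
chosen_spans = []
empty = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2)
print(empty.shape)   # (0, 2)

# With spans present, the result is the expected (N, 2) array of (start, end) pairs.
chosen_spans = [(3, 7), (12, 20)]
filled = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2)
print(filled.shape)  # (2, 2)
```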

fast_llm/data/dataset/gpt/sampled.py (+5, -1)

@@ -117,7 +117,7 @@ def __init__(
         )
         # TODO: Names are confusing

-        # contains document indexes/pointers in order of traversal (shuffled)
+        # contains shuffled document indices
         self._document_shuffling = MemmapArray(base_path.with_name(base_path.name + "_shuffling.npy"))

         # contains cumulative sum of document sizes grouped by TOKEN_CUMSUM_RATE in shuffled order
@@ -521,6 +521,10 @@ def __init__(
         self._indexed_dataset = indexed_dataset
         self._num_samples = sampling.num_samples
         self._sequence_length = sampling.sequence_length
+        if not sampling.config.enable_packing:
+            raise NotImplementedError(
+                "Legacy sampling only supports document packing. Please use the latest dataset format."
+            )
         if not sampling.truncate_documents:
             raise NotImplementedError(
                 "Legacy sampling only supports document truncation. Please use the latest dataset format."

fast_llm/engine/schedule/schedule.py (+3, -4)

@@ -135,7 +135,7 @@ def __init__(
         if self._batch_config.num_inputs < self._distributed.pipeline_parallel:
             warnings.warn("Not enough input to achieve true pipeline parallelism.")

-        # Setup the activation metas. (metadata for sequence parallel)
+        # Setup the activation metas.
         self._preprocessed_meta = self._multi_stage.base_model.preprocess_meta(
             self._batch_config,
             phase=self._phase,
@@ -191,8 +191,8 @@ def get_step(
         return self._step_map[(type_, stage, data_index)]

     def _create_index(self) -> None:
-        self._device_steps: list[list[Step]] = [[] for _ in range(self._distributed.pipeline_parallel)]  # steps for each device
-        self._step_map = {}  # map index (type, stage, data index) => step
+        self._device_steps: list[list[Step]] = [[] for _ in range(self._distributed.pipeline_parallel)]
+        self._step_map = {}
         for i, step in enumerate(self._steps):
             Assert.in_range(step.stage, 0, self._num_stages)
             Assert.in_range(
@@ -204,7 +204,6 @@ def _create_index(self) -> None:
             step.global_index = i
             # TODO: More configurable placement?

-            # perform looping here
             step.pipeline_rank = step.stage % self._distributed.pipeline_parallel
             step.local_index = len(self._device_steps[step.pipeline_rank])
             self._device_steps[step.pipeline_rank].append(step)
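
For context on the placement line kept above, `step.stage % pipeline_parallel` distributes stages over pipeline ranks round-robin, which is presumably the looping the removed `# perform looping here` comment referred to. A small illustration with made-up stage and device counts rather than the Schedule/Step machinery:

```python
# Round-robin placement of pipeline stages onto pipeline-parallel ranks.
num_stages = 8
pipeline_parallel = 4

device_steps: list[list[int]] = [[] for _ in range(pipeline_parallel)]
for stage in range(num_stages):
    rank = stage % pipeline_parallel  # same modulo rule as step.pipeline_rank
    device_steps[rank].append(stage)

print(device_steps)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```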

fast_llm/functional/dpo.py (-4)

@@ -12,7 +12,6 @@ def compute_logps_for_spans(
     log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

     # gather log probabilities corresponding to the target tokens
-    # selected_log_probs = log_probs[torch.arange(logits.shape[0] - 1), targets]
     selected_log_probs = log_probs[:-1].gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)

     # apply chosen mask
@@ -25,9 +24,6 @@
     rejected_mask[rejected_span[:, 0]: rejected_span[:, 1] + 1] = 1
     rejected_logp = (selected_log_probs * rejected_mask).sum()

-    # chosen_logp = selected_log_probs[chosen_span[:, 0]: chosen_span[:, 1] + 1].sum()
-    # rejected_logp = selected_log_probs[rejected_span[:, 0]: rejected_span[:, 1] + 1].sum()
-
     return chosen_logp, rejected_logp

 def compute_simplified_dpo_loss(
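
For context on the cleaned-up function: it gathers the log-probability of each target token, then sums a contiguous span using a 0/1 mask, which gives the same result as the direct slicing in the removed comments. A self-contained sketch of that pattern with invented shapes and span boundaries (not the fast_llm function itself):

```python
import torch

torch.manual_seed(0)
seq_len, vocab = 6, 11
logits = torch.randn(seq_len, vocab)
targets = torch.randint(vocab, (seq_len - 1,))  # next-token targets for positions 0..seq_len-2

log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
# log_probs[:-1][t, targets[t]] is the log-probability the model assigns to the
# token that actually follows position t.
selected = log_probs[:-1].gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)

span = (1, 3)  # inclusive (start, end), hypothetical
mask = torch.zeros_like(selected)
mask[span[0]: span[1] + 1] = 1
masked_sum = (selected * mask).sum()

# The masked sum matches summing the slice directly, as the removed comments did.
assert torch.allclose(masked_sum, selected[span[0]: span[1] + 1].sum())
```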

fast_llm/models/gpt/model.py (+3, -7)

@@ -120,7 +120,6 @@ def setup(self, distributed: Distributed) -> None:
         self._is_setup = True


-    # perform preprocessing for sequence parallel
     def preprocess_meta(
         self, batch_meta: BatchConfig | torch.Tensor, phase: PhaseType
     ) -> list[tuple[TensorMeta, dict]]:
@@ -166,15 +165,13 @@ def preprocess_meta(
             else sequence_q_dim
         )

-        # determins if batch dim or sequence dim is first
         need_sequence_first = hidden_sequence_q_dim.size != sequence_length
         if self._config.sequence_first is None:
             sequence_first = need_sequence_first
         else:
             sequence_first = self._config.sequence_first
             assert not (need_sequence_first and not sequence_first)

-        # hidden dim is model hidden size
         hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)
         hidden_dims = (
             (hidden_sequence_q_dim, batch_dim, hidden_dim)
@@ -199,7 +196,6 @@
         sequence_k = sequence_k_past + sequence_q_dim.size
         sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k)

-        # sequence_k_past is start and sequence_k is end of sequence
         tokens = TensorMeta.from_dims(
             hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64
         )
@@ -294,7 +290,7 @@ def preprocess(
         for i, spans in enumerate(batch.loss_masking_spans):
             if not spans.numel():
                 continue
-            # filter spans within the sequence or partially within the sequence
+            # only keep spans within the sequence or partially within the sequence
             valid_spans = spans[(spans[:, 0] <= sequence_k) & (spans[:, 1] >= sequence_offset)]
             if valid_spans.numel():
                 # if span is partially within the sequence, truncate parts of spans that are outside of the sequence
@@ -310,7 +306,7 @@
         for i, spans in enumerate(batch.chosen_loss_masking_spans):
             if not spans.numel():
                 continue
-            # filter spans within the sequence or partially within the sequence
+            # only keep spans within the sequence or partially within the sequence
             valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)]
             if valid_spans.numel():
                 # if span is partially within the sequence, truncate parts of spans that are outside of the sequence
@@ -322,7 +318,7 @@
         for i, spans in enumerate(batch.rejected_loss_masking_spans):
             if not spans.numel():
                 continue
-            # filter spans within the sequence or partially within the sequence
+            # only keep spans within the sequence or partially within the sequence
             valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)]
             if valid_spans.numel():
                 # if span is partially within the sequence, truncate parts of spans that are outside of the sequence
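
The reworded comments all describe the same step: keep only spans that overlap the current sequence window at least partially, then truncate whatever sticks out past the window. A small sketch of that filter-and-truncate logic with made-up values (`clamp` stands in for the truncation the surrounding code performs):

```python
import torch

sequence_offset, sequence_k = 4, 10  # hypothetical window bounds

spans = torch.tensor([[0, 2],    # entirely before the window -> dropped
                      [3, 6],    # partially inside -> truncated to [4, 6]
                      [5, 8],    # fully inside -> kept as is
                      [9, 14]])  # partially inside -> truncated to [9, 10]

# Keep spans that overlap [sequence_offset, sequence_k] at least partially.
valid_spans = spans[(spans[:, 0] <= sequence_k) & (spans[:, 1] >= sequence_offset)]
# Truncate the parts that fall outside the window.
valid_spans = valid_spans.clamp(min=sequence_offset, max=sequence_k)
print(valid_spans)  # [[4, 6], [5, 8], [9, 10]]
```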
