
Commit 068138f
Commit message: minor
1 parent 0b469fb

File tree

  • fast_llm/layers/language_model

1 file changed: +61 −60 lines changed

fast_llm/layers/language_model/head.py

Lines changed: 61 additions & 60 deletions
@@ -402,67 +402,68 @@ def _logits_cross_entropy_forward_backward(
             sequence_parallel=self._sequence_parallel and self._parallel_embeddings,
         )

-        if self.config.transformer.diffusion == DiffusionStyle.masked:
-            masked_indices = kwargs[LanguageModelKwargs.mask_indexes]
-            p_mask = kwargs[LanguageModelKwargs.mask_probabilities]
-            # index [0, 1, 2, 3, 4, 5] ->
-            # The labels are already left-shifted: x = [A, B, C, D, E, F] ->
-            # embd  = [A, B, C, D, E]
-            # label = [B, C, D, E, F]
-
-            # Question (Pier): if token 2 is masked, the model needs to predict 3; but 3 is already seen by the model,
-            # so can it just learn to copy 3, i.e. copy the next token into the masked position?
-            # Yes. We need to drop those positions from the loss if the next token is not masked.
-            # We could include corruption to further enhance this, but the gain seems not too big judging from other CPT (diffuLlama).
-
-            last_weight = 0
-            B = logits.shape[0]
-
-            loss_weight = torch.cat(
-                (
-                    # ar_weight * in_context[:, 1:] +  # not implemented yet
-                    masked_indices[:, 1:] / p_mask[:, None],
-                    # + un_weight * ((1 - epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None])  # not implemented yet
-                    (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1),
-                    # This may need some weighting in terms of masking. Use last_weight=0 for now. TODO: Decide later.
-                ),
-                dim=1,
-            ).to(logits.dtype)
-
-            # print(f"Loss weight: {loss_weight}")
-
-            loss, grad = cross_entropy_forward_backward(
-                logits=logits.flatten(0, -2),
-                target=target,
-                loss_mask=None,
-                grad_output=grad_output,
-                group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
-                implementation=self._cross_entropy_impl,
-                logits_scale_factor=self._logits_scale_factor,
-                loss_weight=loss_weight,
-            )
-
-        elif self.config.transformer.diffusion == DiffusionStyle.ar_masked:
-
-            loss_weights = kwargs[LanguageModelKwargs.loss_weights]
-            context_index = kwargs[LanguageModelKwargs.in_context]
-            masked_index = kwargs[LanguageModelKwargs.mask_indexes]
-            B = loss_weights.shape[0]
-            masked_index = torch.cat([masked_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
-            context_index = torch.cat([context_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
-
-            loss, grad, per_token_loss_b4_weight = cross_entropy_forward_backward(
-                logits.flatten(0, -2),
-                target=target,
-                group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
-                grad_output=grad_output,
-                implementation=self._cross_entropy_impl,
-                logits_scale_factor=self._logits_scale_factor,
-                loss_weight=loss_weights,
-            )
+        if self.config.transformer.diffusion is not None:
+            if self.config.transformer.diffusion == DiffusionStyle.masked:
+                masked_indices = kwargs[LanguageModelKwargs.mask_indexes]
+                p_mask = kwargs[LanguageModelKwargs.mask_probabilities]
+                # index [0, 1, 2, 3, 4, 5] ->
+                # The labels are already left-shifted: x = [A, B, C, D, E, F] ->
+                # embd  = [A, B, C, D, E]
+                # label = [B, C, D, E, F]
+
+                # Question (Pier): if token 2 is masked, the model needs to predict 3; but 3 is already seen by the model,
+                # so can it just learn to copy 3, i.e. copy the next token into the masked position?
+                # Yes. We need to drop those positions from the loss if the next token is not masked.
+                # We could include corruption to further enhance this, but the gain seems not too big judging from other CPT (diffuLlama).
+
+                last_weight = 0
+                B = logits.shape[0]
+
+                loss_weight = torch.cat(
+                    (
+                        # ar_weight * in_context[:, 1:] +  # not implemented yet
+                        masked_indices[:, 1:] / p_mask[:, None],
+                        # + un_weight * ((1 - epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None])  # not implemented yet
+                        (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1),
+                        # This may need some weighting in terms of masking. Use last_weight=0 for now. TODO: Decide later.
+                    ),
+                    dim=1,
+                ).to(logits.dtype)
+
+                # print(f"Loss weight: {loss_weight}")
+
+                loss, grad = cross_entropy_forward_backward(
+                    logits=logits.flatten(0, -2),
+                    target=target,
+                    loss_mask=None,
+                    grad_output=grad_output,
+                    group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
+                    implementation=self._cross_entropy_impl,
+                    logits_scale_factor=self._logits_scale_factor,
+                    loss_weight=loss_weight,
+                )

-            losses["loss_mask_tokens"].append((per_token_loss_b4_weight * masked_index).mean())
-            losses["loss_in_context_tokens"].append((per_token_loss_b4_weight * context_index).mean())
+            elif self.config.transformer.diffusion == DiffusionStyle.ar_masked:
+
+                loss_weights = kwargs[LanguageModelKwargs.loss_weights]
+                context_index = kwargs[LanguageModelKwargs.in_context]
+                masked_index = kwargs[LanguageModelKwargs.mask_indexes]
+                B = loss_weights.shape[0]
+                masked_index = torch.cat([masked_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
+                context_index = torch.cat([context_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
+
+                loss, grad, per_token_loss_b4_weight = cross_entropy_forward_backward(
+                    logits.flatten(0, -2),
+                    target=target,
+                    group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
+                    grad_output=grad_output,
+                    implementation=self._cross_entropy_impl,
+                    logits_scale_factor=self._logits_scale_factor,
+                    loss_weight=loss_weights,
+                )
+                # Add these before weighting so they are displayed separately
+                losses["loss_mask_tokens"].append((per_token_loss_b4_weight * masked_index).mean())
+                losses["loss_in_context_tokens"].append((per_token_loss_b4_weight * context_index).mean())

         # This happens with the loss_weight.
         # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L274
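For context on the masked branch above: each masked position is up-weighted by 1 / p_mask (the importance weighting used in MDM/SMDM, linked above), and the final position lost to the label shift gets last_weight = 0. Below is a minimal, self-contained sketch of that weight construction; the helper name and the toy tensors are illustrative only, not part of the Fast-LLM API.

```python
import torch


def masked_diffusion_loss_weight(
    masked_indices: torch.Tensor,  # (batch, seq_len), 1.0 where the input token was masked
    p_mask: torch.Tensor,  # (batch,), masking probability used for each sequence
    last_weight: float = 0.0,  # weight of the final shifted-out position (0 in the diff)
) -> torch.Tensor:
    # Drop the first column to line up with the left-shifted labels, then
    # importance-weight masked positions by 1 / p_mask.
    batch_size = masked_indices.shape[0]
    return torch.cat(
        (
            masked_indices[:, 1:].float() / p_mask[:, None],
            last_weight * torch.ones(batch_size, 1, device=masked_indices.device),
        ),
        dim=1,
    )


# Toy check: 6 positions, positions 2 and 4 masked, p_mask = 0.5.
masked = torch.tensor([[0.0, 0.0, 1.0, 0.0, 1.0, 0.0]])
p = torch.tensor([0.5])
print(masked_diffusion_loss_weight(masked, p))  # tensor([[0., 2., 0., 2., 0., 0.]])
```

Positions whose next token is unmasked end up with weight 0, which is exactly the "drop those positions from the loss" behaviour discussed in the diff comments.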

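The ar_masked branch additionally logs the per-token loss before loss_weights is applied, split into masked and in-context positions. cross_entropy_forward_backward is Fast-LLM-internal; the sketch below only emulates that logging split with plain torch.nn.functional.cross_entropy(reduction="none"), so the function name and tensor layout here are assumptions for illustration, not the library's API.

```python
import torch
import torch.nn.functional as F


def split_per_token_loss(
    logits: torch.Tensor,  # (batch, seq_len, vocab)
    target: torch.Tensor,  # (batch, seq_len), already left-shifted labels (long dtype)
    masked_index: torch.Tensor,  # (batch, seq_len), 1.0 where the input token was masked
    context_index: torch.Tensor,  # (batch, seq_len), 1.0 where the token is clean context
) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-token cross-entropy, before any loss weighting.
    per_token_loss = F.cross_entropy(
        logits.flatten(0, -2), target.flatten(), reduction="none"
    ).view_as(target)
    # Shift the index masks the same way the diff does: the prediction at position t
    # is scored against token t + 1, so drop the first column and pad the end with 0.
    batch_size = masked_index.shape[0]
    pad = torch.zeros(batch_size, 1, device=masked_index.device)
    masked_index = torch.cat([masked_index[:, 1:], pad], dim=1)
    context_index = torch.cat([context_index[:, 1:], pad], dim=1)
    loss_mask_tokens = (per_token_loss * masked_index).mean()
    loss_in_context_tokens = (per_token_loss * context_index).mean()
    return loss_mask_tokens, loss_in_context_tokens
```

In the head itself these two scalars are appended to losses["loss_mask_tokens"] and losses["loss_in_context_tokens"], as shown in the added lines of the diff, so the two populations can be monitored separately while the weighted loss drives the gradient.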