@@ -402,67 +402,68 @@ def _logits_cross_entropy_forward_backward(
             sequence_parallel=self._sequence_parallel and self._parallel_embeddings,
         )

-        if self.config.transformer.diffusion == DiffusionStyle.masked:
-            masked_indices = kwargs[LanguageModelKwargs.mask_indexes]
-            p_mask = kwargs[LanguageModelKwargs.mask_probabilities]
-            # index [0, 1, 2, 3, 4, 5] ->
-            # The labels are already left-shifted: x = [A, B, C, D, E, F] ->
-            # embd  = [A, B, C, D, E]
-            # label = [B, C, D, E, F]
-
-            # Question (Pier): if token 2 is masked, the model needs to predict token 3; but token 3 is already seen by the model,
-            # so can it just learn to copy it, i.e. copy the next token into the masked position?
-            # Yes. We need to drop those positions from the loss when the next token is not masked.
-            # We could add corruption to strengthen this further, but it does not seem to matter much in other CPT work (DiffuLLaMA).
-
-            last_weight = 0
-            B = logits.shape[0]
-
-            loss_weight = torch.cat(
-                (
-                    # ar_weight * in_context[:, 1:] +  # not implemented yet
-                    masked_indices[:, 1:] / p_mask[:, None],
-                    # + un_weight * ((1 - epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None])  # not implemented yet
-                    (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1),
-                    # This may need some weighting in terms of masking. Use last_weight=0 for now. TODO: decide later.
-                ),
-                dim=1,
-            ).to(logits.dtype)
-
-            # print(f"Loss weight: {loss_weight}")
-
-            loss, grad = cross_entropy_forward_backward(
-                logits=logits.flatten(0, -2),
-                target=target,
-                loss_mask=None,
-                grad_output=grad_output,
-                group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
-                implementation=self._cross_entropy_impl,
-                logits_scale_factor=self._logits_scale_factor,
-                loss_weight=loss_weight,
-            )
-
-        elif self.config.transformer.diffusion == DiffusionStyle.ar_masked:
-
-            loss_weights = kwargs[LanguageModelKwargs.loss_weights]
-            context_index = kwargs[LanguageModelKwargs.in_context]
-            masked_index = kwargs[LanguageModelKwargs.mask_indexes]
-            B = loss_weights.shape[0]
-            masked_index = torch.cat([masked_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
-            context_index = torch.cat([context_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
-
-            loss, grad, per_token_loss_b4_weight = cross_entropy_forward_backward(
-                logits.flatten(0, -2),
-                target=target,
-                group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
-                grad_output=grad_output,
-                implementation=self._cross_entropy_impl,
-                logits_scale_factor=self._logits_scale_factor,
-                loss_weight=loss_weights,
-            )
+        if self.config.transformer.diffusion is not None:
+            if self.config.transformer.diffusion == DiffusionStyle.masked:
+                masked_indices = kwargs[LanguageModelKwargs.mask_indexes]
+                p_mask = kwargs[LanguageModelKwargs.mask_probabilities]
+                # index [0, 1, 2, 3, 4, 5] ->
+                # The labels are already left-shifted: x = [A, B, C, D, E, F] ->
+                # embd  = [A, B, C, D, E]
+                # label = [B, C, D, E, F]
+
+                # Question (Pier): if token 2 is masked, the model needs to predict token 3; but token 3 is already seen by the model,
+                # so can it just learn to copy it, i.e. copy the next token into the masked position?
+                # Yes. We need to drop those positions from the loss when the next token is not masked.
+                # We could add corruption to strengthen this further, but it does not seem to matter much in other CPT work (DiffuLLaMA).
+
+                last_weight = 0
+                B = logits.shape[0]
+
+                loss_weight = torch.cat(
+                    (
+                        # ar_weight * in_context[:, 1:] +  # not implemented yet
+                        masked_indices[:, 1:] / p_mask[:, None],
+                        # + un_weight * ((1 - epsilon) * in_shuffled[:, 1:] + epsilon * in_clean[:, 1:]) / (1 - p_mask[:, None])  # not implemented yet
+                        (last_weight * torch.ones(B, device=logits.device)).unsqueeze(1),
+                        # This may need some weighting in terms of masking. Use last_weight=0 for now. TODO: decide later.
+                    ),
+                    dim=1,
+                ).to(logits.dtype)
+
+                # print(f"Loss weight: {loss_weight}")
+
+                loss, grad = cross_entropy_forward_backward(
+                    logits=logits.flatten(0, -2),
+                    target=target,
+                    loss_mask=None,
+                    grad_output=grad_output,
+                    group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
+                    implementation=self._cross_entropy_impl,
+                    logits_scale_factor=self._logits_scale_factor,
+                    loss_weight=loss_weight,
+                )
 
-            losses["loss_mask_tokens"].append((per_token_loss_b4_weight * masked_index).mean())
-            losses["loss_in_context_tokens"].append((per_token_loss_b4_weight * context_index).mean())
+            elif self.config.transformer.diffusion == DiffusionStyle.ar_masked:
+
+                loss_weights = kwargs[LanguageModelKwargs.loss_weights]
+                context_index = kwargs[LanguageModelKwargs.in_context]
+                masked_index = kwargs[LanguageModelKwargs.mask_indexes]
+                B = loss_weights.shape[0]
+                masked_index = torch.cat([masked_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
+                context_index = torch.cat([context_index[:, 1:], torch.zeros(B, 1, device=loss_weights.device)], dim=1)
+
+                loss, grad, per_token_loss_b4_weight = cross_entropy_forward_backward(
+                    logits.flatten(0, -2),
+                    target=target,
+                    group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
+                    grad_output=grad_output,
+                    implementation=self._cross_entropy_impl,
+                    logits_scale_factor=self._logits_scale_factor,
+                    loss_weight=loss_weights,
+                )
+                # Add these before weighting so they can be reported separately.
+                losses["loss_mask_tokens"].append((per_token_loss_b4_weight * masked_index).mean())
+                losses["loss_in_context_tokens"].append((per_token_loss_b4_weight * context_index).mean())
 
         # This happens with the loss_weight.
         # MDM https://github.com/ML-GSAI/SMDM/blob/583aa4716d17728dbb825aec6c24a121164d616a/pretrain/train_mdm.py#L274
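
Note on the masked branch: the loss weight follows the SMDM/MDM recipe linked above. Positions whose (left-shifted) next token was masked are kept and scaled by 1 / p_mask, positions whose next token is already visible are dropped, and the final position gets a fixed last_weight of 0. Below is a minimal standalone sketch of that weighting, assuming a (B, S) 0/1 mask indicator and a per-sequence masking probability; the helper names are illustrative and not part of the repository.

```python
import torch


def masked_diffusion_loss_weight(masked_indices: torch.Tensor, p_mask: torch.Tensor, last_weight: float = 0.0) -> torch.Tensor:
    """Per-position loss weight for masked-diffusion training.

    masked_indices: (B, S) 0/1 indicator of which input tokens were masked.
    p_mask: (B,) masking probability drawn for each sequence.
    """
    B = masked_indices.shape[0]
    return torch.cat(
        (
            # Position i predicts token i + 1, so shift the indicator and reweight by 1 / p_mask.
            masked_indices[:, 1:] / p_mask[:, None],
            # The last position has no next token inside the sequence; give it a fixed weight.
            last_weight * torch.ones(B, 1, dtype=masked_indices.dtype),
        ),
        dim=1,
    )


def weighted_cross_entropy(logits: torch.Tensor, target: torch.Tensor, loss_weight: torch.Tensor) -> torch.Tensor:
    # Plain reference for the fused weighted cross-entropy: per-token CE times the weight, then a mean.
    per_token = torch.nn.functional.cross_entropy(
        logits.flatten(0, -2), target.flatten(), reduction="none"
    ).view_as(target)
    return (per_token * loss_weight).mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    B, S, V = 2, 6, 11
    p_mask = torch.rand(B).clamp(0.05, 0.95)                 # per-sequence mask ratio
    masked = (torch.rand(B, S) < p_mask[:, None]).float()    # 0/1 mask indicator
    logits, target = torch.randn(B, S, V), torch.randint(V, (B, S))
    weight = masked_diffusion_loss_weight(masked, p_mask)
    print(weighted_cross_entropy(logits, target, weight))
```

The 1 / p_mask factor acts as an importance weight: sequences with a low masking ratio supervise fewer positions, so each surviving position counts proportionally more, keeping the objective comparable across masking ratios.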
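
For the ar_masked branch, the per-token loss returned before weighting is additionally averaged over masked and in-context positions for separate logging; the indicators are shifted by one to line up with the next-token targets and padded with a zero column at the end. A hedged sketch of that bookkeeping, again standalone and with illustrative names:

```python
import torch


def split_diffusion_losses(
    per_token_loss: torch.Tensor,
    masked_index: torch.Tensor,
    context_index: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Average the unweighted per-token loss over masked vs. in-context positions.

    All inputs are (B, S); the index tensors are 0/1 indicators over input tokens.
    """
    B = per_token_loss.shape[0]
    pad = torch.zeros(B, 1, device=per_token_loss.device, dtype=per_token_loss.dtype)
    # Shift the indicators left by one so they refer to the predicted (next) token,
    # padding the last column with zeros, mirroring the torch.cat calls in the diff.
    masked_index = torch.cat([masked_index[:, 1:].to(per_token_loss.dtype), pad], dim=1)
    context_index = torch.cat([context_index[:, 1:].to(per_token_loss.dtype), pad], dim=1)
    loss_mask_tokens = (per_token_loss * masked_index).mean()
    loss_in_context_tokens = (per_token_loss * context_index).mean()
    return loss_mask_tokens, loss_in_context_tokens
```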