Loss masking for distillation #250
Conversation
In the documentation, here, the parameter
@oleksost Yes, looks like we forgot to update
```diff
 if group:
     Assert.eq(implementation, CrossEntropyImpl.fused)
-    return fused_cross_entropy_forward_backward(
+    return _fused_cross_entropy_forward_backward(
         logits, target, grad_output, logits_scale_factor, target_format, group
     )
 else:
```
In `return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]` we also need to pass the loss mask?
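A minimal sketch of the dispatch pattern this comment is pointing at: if every entry in the implementation registry accepts a `loss_mask` argument, the lookup call can simply forward it. The names `_CROSS_ENTROPY_IMPLEMENTATIONS` and `loss_mask` mirror the review comment; the body below is a placeholder, not Fast-LLM's actual code.

```python
# Hypothetical sketch: each registered implementation takes loss_mask,
# so the dispatching wrapper can forward it uniformly.

def _torch_cross_entropy(logits, target, loss_mask):
    # Placeholder body: counts unmasked positions where logits match target.
    # A real implementation would compute a masked cross-entropy loss here.
    return sum(1 for l, t, m in zip(logits, target, loss_mask) if m and l == t)

_CROSS_ENTROPY_IMPLEMENTATIONS = {"torch": _torch_cross_entropy}

def cross_entropy(implementation, logits, target, loss_mask):
    # Forward loss_mask to whichever implementation was selected.
    return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation](logits, target, loss_mask)
```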
```diff
 if group:
     Assert.eq(implementation, CrossEntropyImpl.fused)
-    return fused_cross_entropy_forward_backward(
+    return _fused_cross_entropy_forward_backward(
```
Need to pass `loss_mask`?
This assert should be removed, no?
Works well for me.
✨ Description
Pass a loss mask to kwargs so it can be used for the distillation loss, i.e. cross-entropy from logits.
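To illustrate what the loss mask does here, below is a minimal stdlib-only sketch of cross-entropy from logits where masked-out positions are excluded from the mean. The function name and the list-based representation are assumptions for the example; they are not Fast-LLM's actual API.

```python
import math

def masked_cross_entropy(logits, targets, loss_mask):
    """Hypothetical sketch: mean cross-entropy over unmasked positions only.

    logits: list of per-token logit vectors, targets: list of class indices,
    loss_mask: list of 0/1 flags (1 = include this token in the loss).
    """
    total, count = 0.0, 0
    for row, target, keep in zip(logits, targets, loss_mask):
        if not keep:
            continue  # masked tokens contribute nothing to the loss
        # Numerically stable log-sum-exp for the log-partition term.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]  # -log p(target)
        count += 1
    return total / max(count, 1)
```

With two classes and uniform logits, an unmasked token contributes `log 2`; a token with `loss_mask == 0` is ignored regardless of how wrong its prediction is.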
Type of change
Select all that apply: