
Loss masking for distillation #250


Merged
jlamypoirier merged 74 commits into main from distillation_loss_mask on May 7, 2025

Conversation

jlamypoirier
Collaborator

✨ Description

Pass a loss mask through kwargs so it can be used for the distillation loss, i.e. cross-entropy from logits.
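
For context, a minimal sketch of what a loss-masked distillation cross-entropy over teacher logits could look like (illustrative only; the function name and tensor shapes are assumptions, not the code added in this PR):

```python
import torch

def masked_distillation_loss(student_logits, teacher_logits, loss_mask):
    # Soft cross-entropy between the teacher and student token distributions.
    teacher_probs = torch.softmax(teacher_logits, dim=-1)
    student_log_probs = torch.log_softmax(student_logits, dim=-1)
    per_token_loss = -(teacher_probs * student_log_probs).sum(dim=-1)
    # Zero out masked positions and average over the unmasked tokens only.
    per_token_loss = per_token_loss * loss_mask
    return per_token_loss.sum() / loss_mask.sum().clamp(min=1)
```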

πŸ” Type of change

Select all that apply:

  • πŸ› Bug fix (non-breaking change that addresses a specific issue)
  • πŸš€ New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • πŸ“ˆ Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • πŸ› οΈ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • πŸ“¦ Dependency bump (updates dependencies, including Dockerfile or package changes)
  • πŸ“ Documentation change (updates documentation, including new content or typo fixes)
  • πŸ”§ Infrastructure/Build change (affects build process, CI/CD, or dependencies)

@oleksost
Contributor

oleksost commented May 2, 2025

In the documentation, here, the parameter use_loss_masking_spans should be under batch and not under sampling, right?

@jlamypoirier
Collaborator Author

@oleksost Yes, it looks like we forgot to update the documentation.

```diff
 if group:
     Assert.eq(implementation, CrossEntropyImpl.fused)
-    return fused_cross_entropy_forward_backward(
+    return _fused_cross_entropy_forward_backward(
         logits, target, grad_output, logits_scale_factor, target_format, group
     )
 else:
```
Contributor

@oleksost oleksost May 2, 2025


In return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation] we also need to pass the loss mask?
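
(For illustration, forwarding the mask through the dispatch might look roughly like the sketch below; the loss_mask keyword and the exact call signature are assumptions, not the repository's actual API.)

```python
# Hypothetical: forward the optional loss mask to whichever implementation was selected.
return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation](
    logits,
    target,
    grad_output,
    logits_scale_factor,
    target_format,
    loss_mask=loss_mask,  # assumed new keyword argument; None would mean "no masking"
)
```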

```diff
 if group:
     Assert.eq(implementation, CrossEntropyImpl.fused)
-    return fused_cross_entropy_forward_backward(
+    return _fused_cross_entropy_forward_backward(
```
Contributor


Need to pass loss_mask?
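
For reference, a minimal sketch of how a loss mask can be applied inside a fused cross-entropy forward/backward pass (this is not the repository's _fused_cross_entropy_forward_backward; shapes, dtypes, and the grad_output convention are assumptions):

```python
import torch

def masked_cross_entropy_fwd_bwd(logits, target, loss_mask, grad_output):
    # Forward: per-token cross-entropy with hard targets, zeroed where loss_mask is 0.
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    per_token = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    per_token = per_token * loss_mask
    denominator = loss_mask.sum().clamp(min=1)
    loss = per_token.sum() / denominator

    # Backward: d(loss)/d(logits) = softmax(logits) - one_hot(target), masked and scaled.
    grad = torch.softmax(logits.float(), dim=-1)
    grad.scatter_add_(
        -1, target.unsqueeze(-1), -torch.ones_like(target, dtype=grad.dtype).unsqueeze(-1)
    )
    grad = grad * loss_mask.unsqueeze(-1) * (grad_output / denominator)
    return loss, grad.to(logits.dtype)
```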

Base automatically changed from reference_model_preprocessing to main May 2, 2025 17:42
@jlamypoirier jlamypoirier marked this pull request as ready for review May 2, 2025 22:00
@jlamypoirier jlamypoirier requested review from oleksost and tscholak May 2, 2025 22:00
@oleksost
Contributor

oleksost commented May 5, 2025

This assert should be removed, no?

Contributor

@oleksost oleksost left a comment


Works well for me.

@jlamypoirier jlamypoirier merged commit f08ac90 into main May 7, 2025
2 checks passed
@jlamypoirier jlamypoirier deleted the distillation_loss_mask branch May 7, 2025 20:12