Fix FSDP mixed precision semantics and add user warning #21361
Conversation
Codecov Report
✅ All modified and coverable lines are covered by tests.

@@            Coverage Diff            @@
##           master   #21361     +/-   ##
=========================================
- Coverage      87%      87%      -0%
=========================================
  Files         269      269
  Lines       23744    23736       -8
=========================================
- Hits        20572    20561      -11
- Misses       3172     3175       +3
```python
if precision != "32-true":
    rank_zero_warn(
        f"FSDPPrecision `{precision}` runs computations in reduced precision "
        "(e.g., float16/bfloat16) while keeping model weights stored in full precision. "
        "These modes are still experimental and may produce slightly different accuracy or stability "
        "compared to full precision (`precision='32-true'`)."
    )
```
let's not do this :D
```diff
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "16-mixed"):
```
let's instead add a warning here for 16-true and bf16-true that FSDP always keeps a full-precision reference of the weights
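A minimal sketch of what that suggestion could look like, reusing the dtype mapping from the diff above. The class skeleton and the warning wording are illustrative, not the actual patch; `rank_zero_warn` and `MixedPrecision` are the real utilities used by the plugin:

```python
import torch
from lightning.fabric.utilities.rank_zero import rank_zero_warn
from torch.distributed.fsdp import MixedPrecision


class FSDPPrecision:
    """Skeleton of the precision plugin; only the parts relevant here."""

    def __init__(self, precision: str) -> None:
        self.precision = precision

    @property
    def mixed_precision_config(self) -> MixedPrecision:
        # Reviewer's suggestion: for the "-true" modes, warn that FSDP still
        # keeps a full-precision reference copy of the weights.
        if self.precision in ("16-true", "bf16-true"):
            rank_zero_warn(  # hypothetical wording
                f"FSDP with `precision={self.precision!r}` still keeps a full-precision "
                "reference of the model weights; only compute, gradient reduction, and "
                "buffers use the reduced dtype."
            )
        if self.precision in ("16-true", "16-mixed"):
            param_dtype = reduce_dtype = buffer_dtype = torch.float16
        elif self.precision in ("bf16-true", "bf16-mixed"):
            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
        elif self.precision == "32-true":
            param_dtype = reduce_dtype = buffer_dtype = torch.float32
        else:
            raise ValueError(f"Unsupported precision: {self.precision!r}")
        return MixedPrecision(
            param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype
        )
```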
What does this PR do?
Fixes #<issue_number>
Aligns the FSDP "-mixed" precision modes with their "-true" counterparts, since FSDP keeps the model parameters stored in full precision in all cases. Adds a warning for all reduced-precision modes indicating their experimental status.
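For context, this is roughly what the aligned modes translate to at the `torch.distributed.fsdp` level. A minimal sketch: the single-process gloo group and the tiny model are illustrative only, and running FSDP on CPU assumes a reasonably recent PyTorch (older versions may require a CUDA device):

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Single-process process group so the example runs without a launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# After this PR, "16-mixed" and "16-true" both map to this config:
mp = MixedPrecision(
    param_dtype=torch.float16,   # compute dtype for forward/backward
    reduce_dtype=torch.float16,  # dtype for gradient reduction
    buffer_dtype=torch.float16,  # dtype buffers are cast to
)
model = FSDP(nn.Linear(8, 8), mixed_precision=mp)

# The sharded parameters FSDP stores remain in full precision; the cast to
# `param_dtype` happens only on the gathered copies used for computation.
print(next(model.parameters()).dtype)  # torch.float32

dist.destroy_process_group()
```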
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines.
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--21361.org.readthedocs.build/en/21361/