Fix the deprecated torch.cuda.amp module #21
Conversation
Pull Request Overview
This PR modernizes the GradScaler import by migrating from the deprecated torch.cuda.amp.GradScaler to the newer torch.amp.GradScaler API, which is device-agnostic. The changes preserve backward compatibility by wrapping the new API with a partial function that defaults to the CUDA device when neither NPU nor MLU is available.
- Updated the GradScaler import from `torch.cuda.amp` to `torch.amp` in production code and tests
- Added a device-specific wrapper using `functools.partial` to keep CUDA as the default device
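The partial-wrapping idea described above can be sketched as follows. `DemoScaler` is a hypothetical stand-in for `torch.amp.GradScaler` (used here so the snippet runs without PyTorch); in the new API the first positional argument selects the device:

```python
from functools import partial

# Hypothetical stand-in for the new torch.amp.GradScaler signature,
# whose first positional argument now selects the device.
class DemoScaler:
    def __init__(self, device='cuda', enabled=True):
        self.device = device
        self.enabled = enabled

# Mirror of the PR's approach: bind 'cuda' so existing call sites
# that never passed a device keep their old CUDA-first behaviour.
GradScaler = partial(DemoScaler, 'cuda')

scaler = GradScaler(enabled=False)
print(scaler.device)  # cuda
```

Because the device is bound via `functools.partial` rather than a subclass, the wrapper stays a drop-in replacement for the old `torch.cuda.amp.GradScaler` call sites.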
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| mmengine/optim/optimizer/amp_optimizer_wrapper.py | Updated GradScaler import to use torch.amp and added partial wrapper for CUDA device specification |
| tests/test_optim/test_optimizer/test_optimizer_wrapper.py | Updated test imports to use the new torch.amp.GradScaler API |
Documentation build overview
Files changed (2 in total): 2 modified, 0 added, 0 deleted
If you fix the linting error I will merge.
@lauriebax Lint fixed. And, as described in open-mmlab#1676:
> I suggest to update
This sub-PR is related to open-mmlab#1665
Brief
According to PyTorch:
> `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
This includes two related replacements:
- `amp_optimizer_wrapper`
- `test_optimizer_wrapper`

PyTest Result After this PR
```shell
pytest tests/test_optim/test_optimizer/test_optimizer_wrapper.py
```