fix(transformers/diffusers): accelerate model loading by optimizing the checkpoint-loading strategy with mindspore 2.7.1 #1422
base: master
Conversation
Summary of Changes

Hello @Cui-yshoho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a series of optimizations aimed at drastically reducing the time it takes to load large models within the MindSpore framework. By refining how model checkpoints are handled, implementing direct memory manipulation for parameter data types, and using a dynamic patching mechanism for default data types during initialization, the PR achieves substantial performance gains, making the model loading process much more efficient.
Code Review
This pull request introduces significant optimizations to accelerate model loading in MindSpore, achieving impressive performance gains. The core strategies involve monkey-patching MindSpore internals and leveraging newer framework features like no_init_parameters and optimized checkpoint loading. While the performance improvements are valuable, the current implementation introduces some critical risks. Specifically, the monkey-patching is duplicated and applied globally, which can be a maintenance and stability concern. Furthermore, the temporary patching of nn.Cell.__init__ is not exception-safe, potentially leaving the application in an inconsistent state if an error occurs. My review includes critical feedback to address these safety and maintainability issues, along with other suggestions for code simplification and consistency.
```python
ms.Parameter._data = ms.Tensor.data
ms.Parameter.data_ptr = ms.Tensor.data_ptr
```
These monkey patches for ms.Parameter are duplicated across multiple files (mindone/diffusers/models/model_loading_utils.py, mindone/diffusers/models/modeling_utils.py, and mindone/transformers/modeling_utils.py). This is risky and makes maintenance difficult. Please centralize this logic in a single location, for example, within the new mindone/utils/modeling_patch.py file. It would be even safer to wrap this patching logic in a context manager to limit its scope, similar to how patch_nn_default_dtype is used.
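A context manager along those lines might look like the following sketch. `patch_parameter_data` is a hypothetical name, and it assumes `ms.Tensor.data` and `ms.Tensor.data_ptr` exist as described in this PR for MindSpore 2.7.1; it is not the PR's actual implementation.

```python
import contextlib

import mindspore as ms


@contextlib.contextmanager
def patch_parameter_data():
    """Temporarily expose Tensor's in-place data accessors on Parameter."""
    # Remember what Parameter itself defined (if anything) so we can undo the patch.
    saved = {name: ms.Parameter.__dict__.get(name) for name in ("_data", "data_ptr")}
    ms.Parameter._data = ms.Tensor.data
    ms.Parameter.data_ptr = ms.Tensor.data_ptr
    try:
        yield
    finally:
        # Restore (or remove) the attributes so the patch never leaks out of scope.
        for name, attr in saved.items():
            if attr is None:
                delattr(ms.Parameter, name)
            else:
                setattr(ms.Parameter, name, attr)
```

Loading code could then wrap only the checkpoint-loading path, e.g. `with patch_parameter_data(): ...`, so the rest of the application never sees the patched class.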
```python
ms.Parameter._data = ms.Tensor.data
ms.Parameter.data_ptr = ms.Tensor.data_ptr
```
These monkey patches for ms.Parameter are duplicated across multiple files. This is a maintenance risk and can lead to inconsistent behavior. Please centralize this logic in a single location, such as mindone/utils/modeling_patch.py, and consider using a context manager to control its scope for better safety.
```python
if mindspore_dtype is not None:
    patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
with no_init_parameters():
    model = cls.from_config(config, **unused_kwargs)
if mindspore_dtype is not None:
    restore_nn_default_dtype()
```
The patching of nn.Cell.__init__ methods is not exception-safe. If cls.from_config raises an exception, restore_nn_default_dtype() will not be called, leaving many mindspore.nn classes in a patched state. This can cause unpredictable behavior in other parts of the application. Please use a try...finally block to ensure the restoration logic is always executed.
Suggested change:

```diff
 if mindspore_dtype is not None:
     patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
-with no_init_parameters():
-    model = cls.from_config(config, **unused_kwargs)
-if mindspore_dtype is not None:
-    restore_nn_default_dtype()
+try:
+    with no_init_parameters():
+        model = cls.from_config(config, **unused_kwargs)
+finally:
+    if mindspore_dtype is not None:
+        restore_nn_default_dtype()
```
```python
ms.Parameter._data = ms.Tensor.data
ms.Parameter.data_ptr = ms.Tensor.data_ptr
```
These monkey patches for ms.Parameter are duplicated across multiple files. This is a maintenance risk and can lead to inconsistent behavior. Please centralize this logic in a single location, such as mindone/utils/modeling_patch.py, and consider using a context manager to control its scope for better safety.
```python
if mindspore_dtype is not None:
    patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
with no_init_parameters():
    model = cls(config, **kwargs)
if mindspore_dtype is not None:
    restore_nn_default_dtype()
```
The patching of nn.Cell.__init__ methods is not exception-safe. If cls(config, **kwargs) raises an exception inside the with no_init_parameters() block, restore_nn_default_dtype() will not be called. This leaves mindspore.nn classes in a patched state, which can cause unpredictable behavior. Please use a try...finally block to ensure the restoration logic is always executed.
Suggested change:

```diff
 if mindspore_dtype is not None:
     patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
-with no_init_parameters():
-    model = cls(config, **kwargs)
-if mindspore_dtype is not None:
-    restore_nn_default_dtype()
+try:
+    with no_init_parameters():
+        model = cls(config, **kwargs)
+finally:
+    if mindspore_dtype is not None:
+        restore_nn_default_dtype()
```
```python
if mindspore_dtype is not None:
    patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
with no_init_parameters():
    model = cls(config, *model_args, **model_kwargs)
if mindspore_dtype is not None:
    restore_nn_default_dtype()
```
The patching of nn.Cell.__init__ methods is not exception-safe. If cls(config, *model_args, **model_kwargs) raises an exception, restore_nn_default_dtype() will not be called, leaving many mindspore.nn classes in a patched state. This can cause unpredictable behavior in other parts of the application. Please use a try...finally block to ensure the restoration logic is always executed.
Suggested change:

```diff
 if mindspore_dtype is not None:
     patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
-with no_init_parameters():
-    model = cls(config, *model_args, **model_kwargs)
-if mindspore_dtype is not None:
-    restore_nn_default_dtype()
+try:
+    with no_init_parameters():
+        model = cls(config, *model_args, **model_kwargs)
+finally:
+    if mindspore_dtype is not None:
+        restore_nn_default_dtype()
```
```python
def half(self):
    for p in self.get_parameters():
        p.set_dtype(ms.float16)
        p._data = p.to(device="CPU", dtype=ms.float16)
    return self


def float(self):
    for p in self.get_parameters():
        p.set_dtype(ms.float32)
        p._data = p.to(device="CPU", dtype=ms.float32)
    return self
```
The half() and float() methods are missing a check to see if the parameter's dtype is already the target dtype. This check is present in the to() method and helps avoid unnecessary tensor conversions and data copies. Please add this check for consistency and performance.
Suggested change:

```diff
 def half(self):
     for p in self.get_parameters():
-        p.set_dtype(ms.float16)
-        p._data = p.to(device="CPU", dtype=ms.float16)
+        if p.dtype != ms.float16:
+            p._data = p.to(device="CPU", dtype=ms.float16)
     return self

 def float(self):
     for p in self.get_parameters():
-        p.set_dtype(ms.float32)
-        p._data = p.to(device="CPU", dtype=ms.float32)
+        if p.dtype != ms.float32:
+            p._data = p.to(device="CPU", dtype=ms.float32)
     return self
```
```python
def float(self):
    for p in self.get_parameters():
        p.set_dtype(ms.float32)
        p._data = p.to(device="CPU", dtype=ms.float32)
    return self


def half(self):
    for p in self.get_parameters():
        p.set_dtype(ms.float16)
        p._data = p.to(device="CPU", dtype=ms.float16)
    return self
```
The float() and half() methods are missing a check to see if the parameter's dtype is already the target dtype. This check is present in the to() method and helps avoid unnecessary tensor conversions and data copies. Please add this check for consistency and performance.
Suggested change:

```diff
 def float(self):
     for p in self.get_parameters():
-        p.set_dtype(ms.float32)
-        p._data = p.to(device="CPU", dtype=ms.float32)
+        if p.dtype != ms.float32:
+            p._data = p.to(device="CPU", dtype=ms.float32)
     return self

 def half(self):
     for p in self.get_parameters():
-        p.set_dtype(ms.float16)
-        p._data = p.to(device="CPU", dtype=ms.float16)
+        if p.dtype != ms.float16:
+            p._data = p.to(device="CPU", dtype=ms.float16)
     return self
```
```python
import mindspore as ms
from mindspore import nn, ops
from mindspore.ops import Cast
```
```python
from mindspore import Parameter, Tensor, mint, nn, ops
from mindspore.nn import CrossEntropyLoss, Identity
from mindspore.nn.utils import no_init_parameters
from mindspore.ops import Cast
```
Force-pushed from 3718397 to 3b08431.
```python
p.set_dtype(ms.float32)
p._data = p.to(device="CPU", dtype=ms.float32)
```
- Does this move the parameter to the CPU? Where is the operation that moves it back afterwards?
- Also, why operate on `p._data` rather than `p.data`?
- The main purpose here is to change the dtype of the network itself, which normally happens before the network executes. First, because deferred parameter initialization is used, problems occur if a fixed device is not set here. Second, ZeRO-3 operates on the CPU, and the step that converts weights to the network's dtype also uses the CPU throughout, so this keeps the operations consistent. Third, before the network has executed, all parameters are still on the CPU except the ones we assign with mint or ops, so nothing needs to be moved back.
- `p._data` is a new feature in MindSpore 2.7.1. Only this method and `set_dtype` allow in-place modification of the parameter itself, and `p._data` can replace the memory address directly, which is very fast.
`p._data` is not a public interface. It would be better to wrap it in a helper function so that the appropriate method is called on 2.7.0 versus 2.7.1 and later.
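A version-gated helper along those lines might look like the following sketch. `set_parameter_data` is a hypothetical name; it assumes `Parameter._data` is available from MindSpore 2.7.1 onward as described in this PR, and that the `packaging` library is available for version comparison.

```python
import mindspore as ms
from packaging import version


def set_parameter_data(param: ms.Parameter, tensor: ms.Tensor) -> None:
    """Replace a parameter's data, using the fast in-place path when available."""
    if version.parse(ms.__version__) >= version.parse("2.7.1"):
        # MindSpore >= 2.7.1: swap the underlying memory directly (per this PR).
        param._data = tensor
    else:
        # Older MindSpore: fall back to the public assignment APIs.
        param.set_dtype(tensor.dtype)
        param.set_data(tensor)
```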
OK, I'll update it.
```python
ms.Parameter._data = ms.Tensor.data
ms.Parameter.data_ptr = ms.Tensor.data_ptr
```
What does the `ms.Tensor.data` attribute represent? I couldn't find where it is defined.
This is the new feature in MindSpore 2.7.1: with this definition plus the later `param._data` assignment, the parameter's dtype can be converted in place directly.
```python
v.set_dtype(local_state[k].dtype)
v._data = v.to(device="CPU", dtype=local_state[k].dtype)
```
Why is this operation moved to the CPU, and where is it moved back afterwards?
This is because ZeRO-3 currently shards the weights on the CPU, so the operation has to run on the CPU here. Previously cpu_cast was used, but it is not as fast as `to` itself; MindSpore 2.7.1 adds a `device` argument to `tensor.to`, so `to(device="cpu")` is used here. The resulting param_dict is then loaded into the network, after which the forward pass and so on proceed as usual.
mindone/utils/modeling_patch.py (outdated)
```python
    setattr(attr, "__init__", _new_init)


def restore_nn_default_dtype():
```
Would it be better to name this `unpatch_xx`?
OK, I'll change it.
Force-pushed from 3b08431 to 9cadec7.
Force-pushed from 9cadec7 to f5ea810.
What does this PR do?
Fixes # (issue)
`_data` can directly replace the memory address.
Note: experiments were run on Ascend Atlas 800T A2 machines with MindSpore 2.7.1.
MindSpore 2.7.1 is recommended to avoid many known issues.
This change reduces the loading time of THUDM/CogVideoX-5b from 408 s to 43 s (≈ 90 % faster).
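For context, a rough timing check of the optimized loading path might look like the sketch below. It assumes mindone.diffusers exposes `CogVideoXPipeline` with a diffusers-style `from_pretrained` that accepts `mindspore_dtype`; those names are not shown in this PR and are used here only for illustration.

```python
import time

import mindspore as ms
from mindone.diffusers import CogVideoXPipeline  # assumed mirror of the diffusers API

start = time.perf_counter()
# Passing mindspore_dtype exercises the patched default-dtype initialization path.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", mindspore_dtype=ms.bfloat16)
print(f"Loaded in {time.perf_counter() - start:.1f} s")
```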
Adds # (feature)
Before submitting
What's New. Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@xxx