
Conversation

@Cui-yshoho
Contributor

What does this PR do?

Fixes # (issue)

  • A monkey patch speeds up setting the default dtype of the network layers.
  • no_init_parameters is used to skip parameter initialization.
  • _data can replace the underlying memory directly.
    Note: experiments were run on Ascend Atlas 800T A2 machines with MindSpore 2.7.1.
    MindSpore 2.7.1 is recommended to avoid many known issues.

This change reduces the loading time of THUDM/CogVideoX-5b from 408 s to 43 s (≈ 90 % faster). A sketch of how the pieces fit together is shown below.
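
For context, here is a rough sketch of how these pieces could combine when loading a diffusers-style model. It is an illustration only: the helper names (patch_nn_default_dtype, restore_nn_default_dtype) and the mindone.utils.modeling_patch location follow this PR's diffs and review, but the exact signatures, the from_config call, and the checkpoint path are assumptions, and it relies on MindSpore 2.7.1 features (Tensor.to(device=...), the Parameter._data alias).

import mindspore as ms
from mindspore.nn.utils import no_init_parameters

# Utility added in this PR (exact signatures assumed).
from mindone.utils.modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype

# From this PR: expose Tensor's storage hooks on Parameter so weights can be rebound in place.
ms.Parameter._data = ms.Tensor.data
ms.Parameter.data_ptr = ms.Tensor.data_ptr

def fast_load(cls, config, ckpt_path, mindspore_dtype=ms.bfloat16):
    # 1. New layers default to the target dtype while the network is being built.
    patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
    try:
        # 2. Build the network without initializing (and later discarding) parameter data.
        with no_init_parameters():
            model = cls.from_config(config)
    finally:
        restore_nn_default_dtype()

    # 3. Load the safetensors checkpoint with MindSpore's native loader.
    param_dict = ms.load_checkpoint(ckpt_path, format="safetensors")

    # 4. Rebind each parameter's storage directly instead of going through set_dtype.
    for name, param in model.parameters_and_names():
        if name in param_dict:
            param._data = param_dict[name].to(device="CPU", dtype=mindspore_dtype)
    return model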

Adds # (feature)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you make sure to update the documentation with your changes? E.g. record bug fixes or new features in What's New. Here are the documentation guidelines.
  • Did you build and run the code without any errors?
  • Did you report the running environment (NPU type / MindSpore version) and performance in the doc? (It is best to record this for data loading, model inference, or training tasks.)
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@xxx

@Cui-yshoho Cui-yshoho requested a review from vigo999 as a code owner November 7, 2025 07:23
@gemini-code-assist
Contributor

Summary of Changes

Hello @Cui-yshoho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a series of optimizations aimed at drastically reducing the time it takes to load large models within the MindSpore framework. By refining how model checkpoints are handled, implementing direct memory manipulation for parameter data types, and utilizing a dynamic patching mechanism for default data types during initialization, the PR achieves substantial performance gains, making the model loading process much more efficient.

Highlights

  • Accelerated Model Loading: The primary goal of this PR is to significantly accelerate model loading times, particularly for large models like THUDM/CogVideoX-5b, achieving approximately a 90% speedup (from 408s to 43s).
  • Optimized Checkpoint Strategy: The checkpoint loading strategy has been optimized by leveraging MindSpore's native ms.load_checkpoint for safetensors and directly manipulating parameter data (_data) for dtype conversion, bypassing slower set_dtype methods.
  • Dynamic Default Dtype Patching: A new utility (modeling_patch.py) has been introduced to dynamically patch nn.Cell constructors, allowing for the temporary enforcement of a default dtype during model initialization. This is used in conjunction with no_init_parameters to prevent unnecessary parameter initialization (see the sketch after this list).
  • MindSpore 2.7.1 Compatibility: The changes are specifically tested and recommended for MindSpore 2.7.1, addressing known issues and optimizing performance within this version.
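
To make the dtype-patching highlight concrete, here is a minimal, hypothetical sketch of what such a utility could look like. It is not the PR's actual modeling_patch.py: the registry name and the set of patched layers (a few layers whose constructors accept a dtype argument) are illustrative only.

import mindspore as ms
from mindspore import nn

_ORIG_INITS = {}  # hypothetical registry of the original __init__ methods
# Illustrative subset of layers whose constructors expose a `dtype` argument.
_DTYPE_LAYERS = (nn.Dense, nn.Embedding)

def patch_nn_default_dtype(dtype=ms.float16, force=False):
    # Make the listed nn layers default to `dtype` at construction time.
    for cls_ in _DTYPE_LAYERS:
        orig_init = cls_.__init__
        _ORIG_INITS[cls_] = orig_init

        def _new_init(self, *args, _orig=orig_init, **kwargs):
            if force:
                kwargs["dtype"] = dtype            # override whatever the caller passed
            else:
                kwargs.setdefault("dtype", dtype)  # only fill in a missing dtype
            _orig(self, *args, **kwargs)

        setattr(cls_, "__init__", _new_init)

def restore_nn_default_dtype():
    # Restore the original constructors (the PR uses this name; a reviewer suggests unpatch_xx below).
    for cls_, orig_init in _ORIG_INITS.items():
        setattr(cls_, "__init__", orig_init)
    _ORIG_INITS.clear()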

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces significant optimizations to accelerate model loading in MindSpore, achieving impressive performance gains. The core strategies involve monkey-patching MindSpore internals and leveraging newer framework features like no_init_parameters and optimized checkpoint loading. While the performance improvements are valuable, the current implementation introduces some critical risks. Specifically, the monkey-patching is duplicated and applied globally, which can be a maintenance and stability concern. Furthermore, the temporary patching of nn.Cell.__init__ is not exception-safe, potentially leaving the application in an inconsistent state if an error occurs. My review includes critical feedback to address these safety and maintainability issues, along with other suggestions for code simplification and consistency.

ms.Parameter._data = ms.Tensor.data
ms.Parameter.data_ptr = ms.Tensor.data_ptr
Contributor

critical

These monkey patches for ms.Parameter are duplicated across multiple files (mindone/diffusers/models/model_loading_utils.py, mindone/diffusers/models/modeling_utils.py, and mindone/transformers/modeling_utils.py). This is risky and makes maintenance difficult. Please centralize this logic in a single location, for example, within the new mindone/utils/modeling_patch.py file. It would be even safer to wrap this patching logic in a context manager to limit its scope, similar to how patch_nn_default_dtype is used.
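
For illustration, a context manager along these lines could scope the patch (hypothetical sketch, not code from the PR):

from contextlib import contextmanager

import mindspore as ms

@contextmanager
def parameter_data_patch():
    # Temporarily expose Tensor's data/data_ptr on Parameter, then restore whatever was there before.
    originals = {}
    for name, replacement in (("_data", ms.Tensor.data), ("data_ptr", ms.Tensor.data_ptr)):
        originals[name] = ms.Parameter.__dict__.get(name)  # None if Parameter did not define it itself
        setattr(ms.Parameter, name, replacement)
    try:
        yield
    finally:
        for name, original in originals.items():
            if original is None:
                delattr(ms.Parameter, name)  # drop the attribute we added
            else:
                setattr(ms.Parameter, name, original)

# Usage: weight loading inside the block can use param._data; the patch does not leak outside it.
# with parameter_data_patch():
#     model = load_model(...)  # hypothetical loading call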

Comment on lines +66 to +67
ms.Parameter._data = ms.Tensor.data
ms.Parameter.data_ptr = ms.Tensor.data_ptr
Contributor

critical

These monkey patches for ms.Parameter are duplicated across multiple files. This is a maintenance risk and can lead to inconsistent behavior. Please centralize this logic in a single location, such as mindone/utils/modeling_patch.py, and consider using a context manager to control its scope for better safety.

Comment on lines 859 to 864
if mindspore_dtype is not None:
    patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
with no_init_parameters():
    model = cls.from_config(config, **unused_kwargs)
if mindspore_dtype is not None:
    restore_nn_default_dtype()
Contributor

critical

The patching of nn.Cell.__init__ methods is not exception-safe. If cls.from_config raises an exception, restore_nn_default_dtype() will not be called, leaving many mindspore.nn classes in a patched state. This can cause unpredictable behavior in other parts of the application. Please use a try...finally block to ensure the restoration logic is always executed.

Suggested change
-if mindspore_dtype is not None:
-    patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
-with no_init_parameters():
-    model = cls.from_config(config, **unused_kwargs)
-if mindspore_dtype is not None:
-    restore_nn_default_dtype()
+if mindspore_dtype is not None:
+    patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
+try:
+    with no_init_parameters():
+        model = cls.from_config(config, **unused_kwargs)
+finally:
+    if mindspore_dtype is not None:
+        restore_nn_default_dtype()

ms.Parameter._data = ms.Tensor.data
ms.Parameter.data_ptr = ms.Tensor.data_ptr
Contributor

critical

These monkey patches for ms.Parameter are duplicated across multiple files. This is a maintenance risk and can lead to inconsistent behavior. Please centralize this logic in a single location, such as mindone/utils/modeling_patch.py, and consider using a context manager to control its scope for better safety.

Comment on lines 1167 to 1171
if mindspore_dtype is not None:
    patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
with no_init_parameters():
    model = cls(config, **kwargs)
if mindspore_dtype is not None:
    restore_nn_default_dtype()
Contributor

critical

The patching of nn.Cell.__init__ methods is not exception-safe. If cls(config, **kwargs) raises an exception inside the with no_init_parameters() block, restore_nn_default_dtype() will not be called. This leaves mindspore.nn classes in a patched state, which can cause unpredictable behavior. Please use a try...finally block to ensure the restoration logic is always executed.

Suggested change
-if mindspore_dtype is not None:
-    patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
-with no_init_parameters():
-    model = cls(config, **kwargs)
-if mindspore_dtype is not None:
-    restore_nn_default_dtype()
+if mindspore_dtype is not None:
+    patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
+try:
+    with no_init_parameters():
+        model = cls(config, **kwargs)
+finally:
+    if mindspore_dtype is not None:
+        restore_nn_default_dtype()

Comment on lines 2768 to 2772
if mindspore_dtype is not None:
    patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
with no_init_parameters():
    model = cls(config, *model_args, **model_kwargs)
if mindspore_dtype is not None:
    restore_nn_default_dtype()
Contributor

critical

The patching of nn.Cell.__init__ methods is not exception-safe. If cls(config, *model_args, **model_kwargs) raises an exception, restore_nn_default_dtype() will not be called, leaving many mindspore.nn classes in a patched state. This can cause unpredictable behavior in other parts of the application. Please use a try...finally block to ensure the restoration logic is always executed.

Suggested change
-if mindspore_dtype is not None:
-    patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
-with no_init_parameters():
-    model = cls(config, *model_args, **model_kwargs)
-if mindspore_dtype is not None:
-    restore_nn_default_dtype()
+if mindspore_dtype is not None:
+    patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
+try:
+    with no_init_parameters():
+        model = cls(config, *model_args, **model_kwargs)
+finally:
+    if mindspore_dtype is not None:
+        restore_nn_default_dtype()

Comment on lines 924 to 932
def half(self):
    for p in self.get_parameters():
        p.set_dtype(ms.float16)
        p._data = p.to(device="CPU", dtype=ms.float16)
    return self

def float(self):
    for p in self.get_parameters():
        p.set_dtype(ms.float32)
        p._data = p.to(device="CPU", dtype=ms.float32)
    return self
Contributor

high

The half() and float() methods are missing a check to see if the parameter's dtype is already the target dtype. This check is present in the to() method and helps avoid unnecessary tensor conversions and data copies. Please add this check for consistency and performance.

Suggested change
-def half(self):
-    for p in self.get_parameters():
-        p.set_dtype(ms.float16)
-        p._data = p.to(device="CPU", dtype=ms.float16)
-    return self
-def float(self):
-    for p in self.get_parameters():
-        p.set_dtype(ms.float32)
-        p._data = p.to(device="CPU", dtype=ms.float32)
-    return self
+def half(self):
+    for p in self.get_parameters():
+        if p.dtype != ms.float16:
+            p._data = p.to(device="CPU", dtype=ms.float16)
+    return self
+def float(self):
+    for p in self.get_parameters():
+        if p.dtype != ms.float32:
+            p._data = p.to(device="CPU", dtype=ms.float32)
+    return self

Comment on lines 522 to 529
def float(self):
    for p in self.get_parameters():
        p.set_dtype(ms.float32)
        p._data = p.to(device="CPU", dtype=ms.float32)
    return self

def half(self):
    for p in self.get_parameters():
        p.set_dtype(ms.float16)
        p._data = p.to(device="CPU", dtype=ms.float16)
    return self
Contributor

high

The float() and half() methods are missing a check to see if the parameter's dtype is already the target dtype. This check is present in the to() method and helps avoid unnecessary tensor conversions and data copies. Please add this check for consistency and performance.

Suggested change
-def float(self):
-    for p in self.get_parameters():
-        p.set_dtype(ms.float32)
-        p._data = p.to(device="CPU", dtype=ms.float32)
-    return self
-def half(self):
-    for p in self.get_parameters():
-        p.set_dtype(ms.float16)
-        p._data = p.to(device="CPU", dtype=ms.float16)
-    return self
+def float(self):
+    for p in self.get_parameters():
+        if p.dtype != ms.float32:
+            p._data = p.to(device="CPU", dtype=ms.float32)
+    return self
+def half(self):
+    for p in self.get_parameters():
+        if p.dtype != ms.float16:
+            p._data = p.to(device="CPU", dtype=ms.float16)
+    return self


import mindspore as ms
from mindspore import nn, ops
from mindspore.ops import Cast
Contributor

medium

The Cast operator is imported but not used in this file. Please remove the unused import to keep the code clean.

from mindspore import Parameter, Tensor, mint, nn, ops
from mindspore.nn import CrossEntropyLoss, Identity
from mindspore.nn.utils import no_init_parameters
from mindspore.ops import Cast
Contributor

medium

The Cast operator is imported but not used in this file. Please remove the unused import.

@Cui-yshoho Cui-yshoho force-pushed the load_checkpoint_faster_v2 branch from 3718397 to 3b08431 Compare November 7, 2025 07:26
Comment on lines -922 to +931
-p.set_dtype(ms.float32)
+p._data = p.to(device="CPU", dtype=ms.float32)
Collaborator

  1. Does this move the parameters to the CPU? If so, where are they moved back later?
  2. Also, why is the operation done on p._data rather than p.data?

Contributor Author

  1. The main purpose here is to change the dtype of the network itself, which normally happens before the network runs. Because deferred initialization is used, problems occur if a fixed device is not specified here. Also, ZeRO-3 operates on the CPU, and the step that converts weights to the network's dtype likewise runs on the CPU, so the same operation is used here. Finally, before the network has executed, all parameters are still on the CPU except for the ones we assign with mint or ops.
  2. p._data is a new feature added in MindSpore 2.7.1. Only this method and set_dtype can modify a Parameter in place, and p._data can replace the memory address directly, which is very fast.

Collaborator

p._data is not a public interface. It would be better to wrap it in a helper function so that 2.7.0 and 2.7.1 (or later) call different methods.

Contributor Author

OK, I'll change it.
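
As a minimal sketch, such a version-guarded wrapper might look like this (hypothetical helper name; assumes Tensor.to(device=...) and the Parameter._data alias are only available on MindSpore 2.7.1+, and that the packaging library is available for version parsing):

import mindspore as ms
from packaging import version  # assumed available; any version comparison would do

_FAST_DATA_PATH = version.parse(ms.__version__) >= version.parse("2.7.1")

def set_param_dtype(param, dtype):
    # Cast `param` to `dtype` in place, using the private fast path only when supported.
    if param.dtype == dtype:
        return
    if _FAST_DATA_PATH:
        # 2.7.1+: rebind the underlying storage directly (requires ms.Parameter._data = ms.Tensor.data).
        param._data = param.to(device="CPU", dtype=dtype)
    else:
        # Older versions: fall back to the public, slower API.
        param.set_dtype(dtype)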

Comment on lines +116 to +117
ms.Parameter._data = ms.Tensor.data
ms.Parameter.data_ptr = ms.Tensor.data_ptr
Collaborator

What does the ms.Tensor.data attribute represent? I can't seem to find where it is defined.

Contributor Author

This is the new MindSpore 2.7.1 feature: with this definition plus the later param._data assignment, a parameter's dtype can be converted in place directly.

Comment on lines -376 to +381
-v.set_dtype(local_state[k].dtype)
+v._data = v.to(device="CPU", dtype=local_state[k].dtype)
Collaborator

Why is this operation moved to the CPU, and where is it moved back later?

Contributor Author

This is because ZeRO-3 weight sharding is currently done on the CPU, so this step must also run on the CPU. cpu_cast was used before, but it is not as fast as to itself; MindSpore 2.7.1 adds a device argument to Tensor.to, so to(device="cpu") is used here. The param_dict produced here is then loaded into the network, after which the forward pass and everything else proceed as usual.
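
For illustration, the CPU-side cast described here could look roughly like this (hedged sketch; variable names follow the diff above, and it assumes MindSpore 2.7.1's Tensor.to(device=...) plus the Parameter._data alias from this PR):

def cast_loaded_params_on_cpu(param_dict, local_state):
    # Align each checkpoint tensor with the dtype of the matching network parameter, on CPU,
    # before the param_dict is loaded into the network for ZeRO-3 sharding.
    for k, v in param_dict.items():
        if k in local_state and v.dtype != local_state[k].dtype:
            v._data = v.to(device="CPU", dtype=local_state[k].dtype)
    return param_dict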

setattr(attr, "__init__", _new_init)


def restore_nn_default_dtype():
Collaborator

Would it be better to name this unpatch_xx?

Contributor Author

OK, I'll change it.

@Cui-yshoho Cui-yshoho force-pushed the load_checkpoint_faster_v2 branch from 3b08431 to 9cadec7 Compare November 7, 2025 08:57
@Cui-yshoho Cui-yshoho force-pushed the load_checkpoint_faster_v2 branch from 9cadec7 to f5ea810 Compare November 7, 2025 09:00