[MOE] Move weight transpose to wake_up for RL scenarios #4147
base: main
Conversation
Signed-off-by: lhp-deep <[email protected]>
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request aims to move the weight transposition logic for MoE layers to the wake_up function, enhancing support for Reinforcement Learning (RL) scenarios. The changes span an example file, the MoE operator implementation, and the worker logic. While the overall direction seems correct, I've identified a critical bug in the new code added to vllm_ascend/worker/worker_v1.py. The logic for retrieving a parameter's parent module is flawed and will lead to a runtime error. This needs to be addressed.
vllm_ascend/worker/worker_v1.py
Outdated
parent_module = model
parts = name.split('.')
param_name = parts[-1]
module_path = parts[-1]

for part in module_path:
    parent_module = getattr(parent_module, part)
The logic to retrieve the parent module of the w2_weight parameter is incorrect. module_path is assigned the parameter name string (parts[-1]), and the subsequent loop iterates over the characters of this string. This will cause getattr to fail at runtime.
You should iterate over the module path components to correctly traverse the model hierarchy. A cleaner way to achieve this is by using model.get_submodule().
Suggested change:
-parent_module = model
-parts = name.split('.')
-param_name = parts[-1]
-module_path = parts[-1]
-for part in module_path:
-    parent_module = getattr(parent_module, part)
+parts = name.split('.')
+param_name = parts[-1]
+parent_module = model.get_submodule(".".join(parts[:-1]))
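To make the failure mode and the suggested fix concrete, here is a small self-contained sketch; the TinyModel/TinyExpert classes and the "layers.0.w2_weight" path are made up for illustration and are not part of vllm_ascend:

```python
import torch
import torch.nn as nn

class TinyExpert(nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for an MoE expert weight such as w2_weight.
        self.w2_weight = nn.Parameter(torch.randn(4, 8))

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([TinyExpert()])

model = TinyModel()
name = "layers.0.w2_weight"   # as produced by model.named_parameters()
parts = name.split('.')
param_name = parts[-1]        # "w2_weight"

# Original logic: module_path holds the parameter *name*, so the loop walks
# its characters ('w', '2', '_', ...) and getattr() raises AttributeError.
module_path = parts[-1]
parent_module = model
try:
    for part in module_path:
        parent_module = getattr(parent_module, part)
except AttributeError as exc:
    print(f"buggy loop fails: {exc}")

# Suggested fix: resolve the parent module from the dotted path prefix.
parent_module = model.get_submodule(".".join(parts[:-1]))
print(type(parent_module).__name__, hasattr(parent_module, param_name))
# -> TinyExpert True
```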
vllm_ascend/worker/worker_v1.py
Outdated
parent_module = model
parts = name.split('.')
param_name = parts[-1]
module_path = parts[-1]

for part in module_path:
    parent_module = getattr(parent_module, part)
Similar to the w2_weight block, the logic to retrieve the parent module for w13_weight is incorrect. module_path is assigned the parameter name string, and the loop iterates over its characters, which will cause a runtime error. You should use model.get_submodule() to correctly get the parent module.
Suggested change:
-parent_module = model
-parts = name.split('.')
-param_name = parts[-1]
-module_path = parts[-1]
-for part in module_path:
-    parent_module = getattr(parent_module, part)
+parts = name.split('.')
+param_name = parts[-1]
+parent_module = model.get_submodule(".".join(parts[:-1]))
What this PR does / why we need it?
In reinforcement learning scenarios, the current inference path applies a transpose operation to the MoE weights. For a cleaner architecture, this PR moves that weight transpose logic into the wake_up function.
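For reference, a rough sketch of the kind of change described, not the PR's actual code: apart from the w13_weight/w2_weight parameter names and the get_submodule() lookup discussed in the review above, the helper name and transpose details here are assumptions.

```python
import torch.nn as nn

def transpose_moe_weights_on_wake_up(model: nn.Module) -> None:
    """Hypothetical helper: re-apply the MoE weight transpose when the worker wakes up."""
    # Materialize the parameter list first so replacing parameters below does
    # not mutate the dict being iterated.
    for name, param in list(model.named_parameters()):
        if not name.endswith(("w13_weight", "w2_weight")):
            continue
        parts = name.split('.')
        param_name = parts[-1]
        parent_module = model.get_submodule(".".join(parts[:-1]))
        # Swap the last two dims and keep the result contiguous; gradients are
        # not needed for inference-only weights.
        transposed = nn.Parameter(param.data.transpose(-2, -1).contiguous(),
                                  requires_grad=False)
        setattr(parent_module, param_name, transposed)
```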
Does this PR introduce any user-facing change?
How was this patch tested?