-
Notifications
You must be signed in to change notification settings - Fork 516
[Main][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. #3669
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a fused matmul/reduce-scatter kernel to optimize performance, which is a great initiative. The implementation correctly identifies the specific CANN version and quantization methods to apply the optimization. However, I've found a critical correctness issue where the bias term is ignored in the unquantized path, which could lead to incorrect model outputs. Additionally, there are several opportunities to improve code quality by removing dead code, placing imports correctly, and refactoring duplicated logic. Addressing these points will make the new optimized code paths more robust and maintainable.
vllm_ascend/ops/linear_op.py
Outdated
| output_parallel = torch.empty(x.shape[0] // self.layer.tp_size, | ||
| self.layer.weight.shape[0], | ||
| dtype=self.layer.params_dtype, | ||
| device=x.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according Contributing and Testing. |
fe6ecca to
f4cdbbc
Compare
Co-authored-by: ZhaoJiangJiang <[email protected]> Co-authored-by: rjg-lyh <[email protected]> Co-authored-by: raintBN-91 <[email protected]> Signed-off-by: ZYang6263 <[email protected]> Changes to be committed: modified: vllm_ascend/ops/linear_op.py modified: vllm_ascend/ops/register_custom_ops.py Signed-off-by: ZYang6263 <[email protected]>
What this PR does / why we need it?
This PR boosts performance by introducing a fused kernel for the matrix matmul and reduce scatter operations. It supports both unquantized (e.g., BFloat16) and W8A8 quantized models.
Does this PR introduce any user-facing change?
How was this patch tested?