add test_moe_sum_reduce_kernel_npu and test_fwd_kernel_stage2_npu #186
base: main
Conversation
Summary of Changes

Hello @1801ZDL, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces significant performance enhancements for Triton kernels operating on NPU devices within the attention and Mixture-of-Experts (MoE) layers. It refines the execution strategy for attention kernels by better utilizing NPU core parallelism and streamlines MoE sum reduction through vectorized operations, aiming for more efficient deep learning model execution.
Code Review
This pull request introduces test files for two new NPU-optimized Triton kernels. The optimizations focus on improving performance by adjusting kernel launch configurations and data access patterns. My review identified several critical issues in the test setups, including undefined variables that would cause runtime failures, and incorrect tensor shapes and strides that would lead to incorrect computations. Furthermore, the test files contain significant amounts of commented-out code and are not configured to run the new kernels by default, which reduces their immediate usefulness. I have provided specific suggestions to address these correctness and maintainability issues.
```python
stride_obs = O.stride(1)
stride_oh = O.stride(2)
```
The strides for the output tensor O are calculated incorrectly. For an output tensor O of shape (batch, num_heads, Lv), stride_obs should correspond to the batch dimension and stride_oh to the head dimension. You are using O.stride(1) and O.stride(2), but it should be O.stride(0) and O.stride(1). This will lead to incorrect memory access and wrong results.
Suggested change:

```diff
-stride_obs = O.stride(1)
-stride_oh = O.stride(2)
+stride_obs = O.stride(0)
+stride_oh = O.stride(1)
```
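A quick sanity check of the intended mapping, using illustrative (hypothetical) shapes and a contiguous output tensor:

```python
import torch

# Illustrative shapes only: for a contiguous (batch, num_heads, Lv) tensor,
# stride(0) steps over batches and stride(1) steps over heads.
O = torch.empty(4, 8, 128)
assert O.stride(0) == 8 * 128  # batch stride (stride_obs)
assert O.stride(1) == 128      # head stride (stride_oh)
```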
```python
input_stride_1 = input_ptr.stride(1)
input_stride_2 = input_ptr.stride(2)

output_ptr = torch.zeros(size=(token_num, token_num), dtype=torch.float32, device='cpu')
```
The output tensor output_ptr is initialized with an incorrect shape of (token_num, token_num). According to the logic and the assertions in the reference moe_sum_reduce_triton function, the output shape should be (token_num, hidden_dim). This will cause incorrect results and potentially memory corruption.
Suggested change:

```diff
-output_ptr = torch.zeros(size=(token_num, token_num), dtype=torch.float32, device='cpu')
+output_ptr = torch.zeros(size=(token_num, hidden_dim), dtype=torch.float32, device='cpu')
```
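For reference, a plain PyTorch sketch of the expected output shape, assuming the kernel sums over the top-k dimension and then applies `routed_scaling_factor` (shapes here are illustrative, not the test's values):

```python
import torch

# Illustrative sizes; the real test uses its own token_num / topk_num / hidden_dim.
token_num, topk_num, hidden_dim = 16, 8, 1024
routed_scaling_factor = 2.5

x = torch.randn(token_num, topk_num, hidden_dim, dtype=torch.float32)

# Assumed reference semantics: reduce over the top-k dimension, then scale.
out = x.sum(dim=1) * routed_scaling_factor
assert out.shape == (token_num, hidden_dim)  # not (token_num, token_num)
```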
```python
BLOCK_DIM = hidden_dim // 2
NUM_STAGE = 1

grid = (npu_num_core, 1, 1)
```
The grid for launching the kernel is configured as (npu_num_core, 1, 1). However, the kernel uses tl.program_id(1) to index into the hidden_dim. With the second dimension of the grid being 1, dim_block_id will always be 0. This means the kernel will only process the first BLOCK_DIM elements of the hidden_dim, leading to incomplete and incorrect results. The grid should be 2D to cover the entire hidden_dim.
Suggested change:

```diff
-grid = (npu_num_core, 1, 1)
+grid = (npu_num_core, triton.cdiv(hidden_dim, BLOCK_DIM))
```
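A small sketch of why the second grid dimension matters (sizes are hypothetical; `BLOCK_DIM` mirrors the test's `hidden_dim // 2` choice):

```python
import triton

# Hypothetical sizes for illustration only.
hidden_dim, npu_num_core = 1024, 40
BLOCK_DIM = hidden_dim // 2

# With a 2D grid, tl.program_id(1) ranges over ceil(hidden_dim / BLOCK_DIM)
# blocks, so every BLOCK_DIM-wide slice of the hidden dimension is covered.
grid = (npu_num_core, triton.cdiv(hidden_dim, BLOCK_DIM))
assert grid[1] * BLOCK_DIM >= hidden_dim
```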
```python
# Calling scenario of this Triton operator (sglang v0.4.8 python/sglang/srt/layers/attention/triton_ops/decode_attention.py)
def _decode_softmax_reducev_fwd(
    logits,
    lse,
    q,
    o,
    v_buffer,
    kv_indptr,
    num_kv_splits,
    max_kv_splits,
):
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    MAX_KV_SPLITS = max_kv_splits

    extra_kargs = {}
    if _is_hip:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}

    # grid = (batch, head_num)

    # Add a new input parameter BATCH_GROUP_SIZE to control the amount of data processed per block;
    # meanwhile, remove the division of head_num and integrate it into the operator, processing it
    # within the same block using a for-loop.
    small_batch = batch // 5
    grid = (small_batch, )
    BATCH_GROUP_SIZE = (batch + small_batch - 1) // small_batch
    _fwd_kernel_stage2[grid](
        logits,
        lse,
        o,
        kv_indptr,
        num_kv_splits,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        MAX_KV_SPLITS=MAX_KV_SPLITS,
        MIN_BLOCK_KV=_MIN_BLOCK_KV,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
        num_warps=4,
        num_stages=2,
        BATCH_GROUP_SIZE=BATCH_GROUP_SIZE,  # new input parameter
        num_heads=head_num,  # new input parameter
        **extra_kargs,
    )
```
This function _decode_softmax_reducev_fwd is not called within the file. It also contains references to undefined variables, which will cause runtime errors:

- `_is_hip` on line 190.
- `_MIN_BLOCK_KV` on line 213.

If this function is intended for future use or as a reference, these errors should be fixed. If it's not needed, it should be removed to improve code clarity.
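One way to make the reference function self-contained is to define the missing module-level names. A minimal sketch, with values assumed here (32 matches the `MIN_BLOCK_KV` used in the test below, and the HIP check is the usual torch ROCm detection):

```python
import torch

# Assumed definitions so _decode_softmax_reducev_fwd can run as a reference.
_is_hip = torch.version.hip is not None  # False on NPU/CPU/CUDA builds
_MIN_BLOCK_KV = 32                       # mirrors MIN_BLOCK_KV in the test
```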
```python
def fn_triton(grid, input_data):
    # _fwd_kernel_stage2_npu[grid](**input_data)

    # The following code is used for comparison with the original operator.
    _fwd_kernel_stage2[grid](**input_data)
    return


def test_fwd_kernel_stage2():
    batch = 160
    num_heads = 8
    MAX_KV_SPLITS = 16
    MIN_BLOCK_KV = 32
    BLOCK_DV = 128
    Lv = 128

    Mid_O = torch.randn(size=(batch, num_heads, MAX_KV_SPLITS, Lv), dtype=torch.float32, device='cpu')
    Mid_O_1 = torch.randn(size=(batch, num_heads, MAX_KV_SPLITS), dtype=torch.float32, device='cpu')
    O = torch.randn(size=(batch, num_heads, Lv), dtype=torch.bfloat16, device='cpu')
    kv_indptr = torch.arange(0, batch + 1, dtype=torch.int32, device='cpu')
    num_kv_splits = torch.full(size=(batch,), fill_value=8, dtype=torch.int32, device='cpu')
    sink_ptr = torch.zeros(size=(8,), dtype=torch.float32, device='cpu')
    stride_mid_ob = Mid_O.stride(0)
    stride_mid_oh = Mid_O.stride(1)
    stride_mid_os = Mid_O.stride(2)
    stride_obs = O.stride(1)
    stride_oh = O.stride(2)
    HAS_SINK = 0

    device = torch.npu.current_device()
    device_properties = driver.active.utils.get_device_properties(device)
    npu_num_core = device_properties["num_vectorcore"]
    BATCH_GROUP_SIZE = (batch + npu_num_core - 1) // npu_num_core

    grid = (npu_num_core, 1,)
    # input_data = {
    #     'Mid_O': Mid_O,
    #     'Mid_O_1': Mid_O_1,
    #     'O': O,
    #     'kv_indptr': kv_indptr,
    #     'num_kv_splits': num_kv_splits,
    #     'stride_mid_ob': stride_mid_ob,
    #     'stride_mid_oh': stride_mid_oh,
    #     'stride_mid_os': stride_mid_os,
    #     'stride_obs': stride_obs,
    #     'stride_oh': stride_oh,
    #     'MAX_KV_SPLITS': MAX_KV_SPLITS,
    #     'MIN_BLOCK_KV': MIN_BLOCK_KV,
    #     'BLOCK_DV': BLOCK_DV,
    #     'Lv': Lv,
    #     'sink_ptr': sink_ptr,
    #     'HAS_SINK': HAS_SINK,
    #     'BATCH_GROUP_SIZE': BATCH_GROUP_SIZE,
    #     'num_heads': num_heads
    # }

    # The following code is used to compare the performance of the original operator.
    # Please modify the operator call in fn_triton accordingly.
    grid = (batch, num_heads,)
    input_data = {
        'Mid_O': Mid_O,
        'Mid_O_1': Mid_O_1,
        'O': O,
        'kv_indptr': kv_indptr,
        'num_kv_splits': num_kv_splits,
        'stride_mid_ob': stride_mid_ob,
        'stride_mid_oh': stride_mid_oh,
        'stride_mid_os': stride_mid_os,
        'stride_obs': stride_obs,
        'stride_oh': stride_oh,
        'MAX_KV_SPLITS': MAX_KV_SPLITS,
        'MIN_BLOCK_KV': MIN_BLOCK_KV,
        'BLOCK_DV': BLOCK_DV,
        'Lv': Lv,
        'sink_ptr': sink_ptr,
        'HAS_SINK': HAS_SINK,
    }

    input_data = convert_tensor_with_device_type(input_data, device_type='npu')

    profiling_test(fn_triton, (grid, input_data))
```
The test setup in test_fwd_kernel_stage2 and fn_triton is confusing and does not test the new _fwd_kernel_stage2_npu kernel by default.

- `fn_triton` calls the original `_fwd_kernel_stage2` kernel, while the call to the new NPU kernel is commented out.
- `test_fwd_kernel_stage2` prepares `input_data` for the original kernel, while the data preparation for the new kernel is commented out.

This makes it difficult to run the test for the new kernel. Please refactor the test to make it easy to test and profile the new kernel. For example, you could have separate test functions for the original and the new kernel, or use a parameter to switch between them. A correctness check comparing the outputs of both kernels would also be very valuable.
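A sketch of one possible shape for such a refactor. This is hedged, not a drop-in replacement: `make_inputs_original` and `make_inputs_npu` are hypothetical helpers wrapping the two `input_data` blocks already in this file, and `batch`, `num_heads`, and `npu_num_core` are assumed to come from the same setup; the kernel names are the ones used in this PR.

```python
import pytest
import torch

@pytest.mark.parametrize("use_npu_kernel", [False, True])
def test_fwd_kernel_stage2(use_npu_kernel):
    # Reference run with the original kernel and its (batch, num_heads) grid.
    ref = make_inputs_original()
    _fwd_kernel_stage2[(batch, num_heads)](**ref)

    # Kernel under test: either the original or the new NPU variant.
    if use_npu_kernel:
        data = make_inputs_npu()
        _fwd_kernel_stage2_npu[(npu_num_core, 1)](**data)
    else:
        data = make_inputs_original()
        _fwd_kernel_stage2[(batch, num_heads)](**data)

    # Correctness check: both kernels should produce the same output tensor O.
    torch.testing.assert_close(data["O"], ref["O"], rtol=1e-2, atol=1e-2)
```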
```python
import sys
import pytest
import triton
import torch
import triton.language as tl
import os
```
```python
def _moe_sum_reduce_kernel_vec(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: tl.constexpr,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
```
The kernel _moe_sum_reduce_kernel_vec accepts several parameters that are unused: input_stride_2 (line 33), output_stride_1 (line 36), and NUM_STAGE (line 43). If these are not necessary for the computation, they should be removed from the function signature and the call site to simplify the code. If they are required for API compatibility, please add a comment explaining why they are present.
_fwd_kernel_stage2_npu:
Optimize the kernel's internal logic: the grid is now sized to the number of NPU cores, and two new input parameters, `BATCH_GROUP_SIZE` and `num_heads`, are added. `BATCH_GROUP_SIZE` controls how many batches each block processes, and `num_heads` controls the number of iterations of the for-loop inside the operator, reducing the number of blocks that need to be launched.
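A quick check of the batching math behind this scheme, with illustrative sizes (batch matches the test file; npu_num_core is assumed): with one program per NPU core and `BATCH_GROUP_SIZE = ceil(batch / npu_num_core)`, every batch index is visited exactly once by the in-kernel loop.

```python
# Illustrative sizes; npu_num_core is an assumption for this sketch.
batch, npu_num_core = 160, 40
BATCH_GROUP_SIZE = (batch + npu_num_core - 1) // npu_num_core  # ceil division

covered = {core * BATCH_GROUP_SIZE + i
           for core in range(npu_num_core)
           for i in range(BATCH_GROUP_SIZE)
           if core * BATCH_GROUP_SIZE + i < batch}
assert covered == set(range(batch))  # full batch coverage, no index missed
```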
_moe_sum_reduce_kernel_npu:
Optimize the data loading and accumulation logic of this Triton operator: replace the internal element-wise for-loop loads with batched 2D loads and batched computation.
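A hedged sketch of what batched 2D loading can look like in Triton. This is not the PR's actual kernel body; it assumes `topk_num` and `BLOCK_DIM` are powers of two, a contiguous hidden dimension, and a 2D grid of (token, hidden-dim block) programs.

```python
import triton
import triton.language as tl

@triton.jit
def _moe_sum_reduce_sketch(
    input_ptr, output_ptr,
    input_stride_0, input_stride_1,   # strides of the (token, topk, hidden) input
    output_stride_0,                  # stride of the (token, hidden) output
    hidden_dim,
    routed_scaling_factor: tl.constexpr,
    topk_num: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
):
    token_id = tl.program_id(0)
    dim_block_id = tl.program_id(1)

    dim_offs = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
    dim_mask = dim_offs < hidden_dim
    topk_offs = tl.arange(0, topk_num)

    # One 2D tile of shape (topk_num, BLOCK_DIM) instead of a per-expert loop;
    # the hidden dimension is assumed contiguous (stride 1).
    ptrs = (input_ptr
            + token_id * input_stride_0
            + topk_offs[:, None] * input_stride_1
            + dim_offs[None, :])
    tile = tl.load(ptrs, mask=dim_mask[None, :], other=0.0)

    # Reduce over the top-k axis, scale, and store one BLOCK_DIM-wide slice.
    acc = tl.sum(tile, axis=0) * routed_scaling_factor
    tl.store(output_ptr + token_id * output_stride_0 + dim_offs, acc, mask=dim_mask)
```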