
Conversation


@1801ZDL 1801ZDL commented Nov 15, 2025

_fwd_kernel_stage2_npu:
Optimize the kernel's internal logic: the launch grid can now be divided across the available NPU cores. Two new kernel parameters are added: "BATCH_GROUP_SIZE" controls how many batches each block processes, and "num_heads" controls the number of iterations of the kernel's internal for-loop, reducing the number of blocks that have to be launched.
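
As a rough illustration of the grouping idea (a minimal sketch only, not the kernel body from this PR; the real _fwd_kernel_stage2_npu also performs the reduction over the KV splits, and the Out buffer here is purely illustrative):

import triton
import triton.language as tl

@triton.jit
def _stage2_grouping_sketch(
    Out,                             # illustrative output buffer of shape (batch, num_heads)
    batch,
    BATCH_GROUP_SIZE: tl.constexpr,  # batches handled by one block
    num_heads: tl.constexpr,         # heads iterated inside the block
):
    # One program per NPU core instead of one program per (batch, head) pair.
    core_id = tl.program_id(0)
    for i in range(BATCH_GROUP_SIZE):
        batch_id = core_id * BATCH_GROUP_SIZE + i
        if batch_id < batch:
            for h in range(num_heads):
                # ... the per-(batch, head) reduction over KV splits would go here ...
                tl.store(Out + batch_id * num_heads + h, 1.0)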

_moe_sum_reduce_kernel_npu:
Optimize the data loading and accumulation logic of this Triton operator: replace the internal for-loop-based element loading with batched loading and batched computation over 2D matrices.
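
A minimal sketch of the batched 2D access pattern (assumptions: the input layout is (token_num, topk_num, hidden_dim) and contiguous, and the launcher uses a 2D grid of (cdiv(token_num, BLOCK_M), cdiv(hidden_dim, BLOCK_DIM)); the real kernel takes explicit stride arguments, as shown in the review below):

import triton
import triton.language as tl

@triton.jit
def _moe_sum_reduce_vec_sketch(
    input_ptr, output_ptr,
    token_num, hidden_dim,
    topk_num: tl.constexpr,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_DIM: tl.constexpr,
):
    # Each program accumulates one (BLOCK_M, BLOCK_DIM) tile across the topk experts
    # with whole-tile loads, instead of looping over rows element by element.
    pid_m = tl.program_id(0)
    pid_d = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = pid_d * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
    mask = (offs_m[:, None] < token_num) & (offs_d[None, :] < hidden_dim)

    acc = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
    for k in range(topk_num):
        ptrs = input_ptr + (offs_m[:, None] * topk_num + k) * hidden_dim + offs_d[None, :]
        acc += tl.load(ptrs, mask=mask, other=0.0).to(tl.float32)

    out_ptrs = output_ptr + offs_m[:, None] * hidden_dim + offs_d[None, :]
    tl.store(out_ptrs, acc * routed_scaling_factor, mask=mask)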

@gemini-code-assist
Contributor

Summary of Changes

Hello @1801ZDL, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant performance enhancements for Triton kernels operating on NPU devices within the attention and Mixture-of-Experts (MoE) layers. It refines the execution strategy for attention kernels by better utilizing NPU core parallelism and streamlines MoE sum reduction through vectorized operations, aiming for more efficient deep learning model execution.

Highlights

  • NPU Kernel Optimization for Attention: The _fwd_kernel_stage2_npu Triton kernel is optimized to support grid division based on NPU cores and introduces BATCH_GROUP_SIZE and num_heads parameters for improved batch processing and reduced block division.
  • Vectorized MoE Sum Reduction: The _moe_sum_reduce_kernel_npu Triton kernel is enhanced with vectorized data loading and batch computation for 2D matrices, replacing inefficient for-loop-based loading.
  • New Test Cases and Profiling: Dedicated test files (test_fwd_kernel_stage2.py and test_moe_sum_reduce_kernel.py) have been added to validate the functionality and performance of the optimized NPU kernels, including profiling capabilities.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces test files for two new NPU-optimized Triton kernels. The optimizations focus on improving performance by adjusting kernel launch configurations and data access patterns. My review identified several critical issues in the test setups, including undefined variables that would cause runtime failures, and incorrect tensor shapes and strides that would lead to incorrect computations. Furthermore, the test files contain significant amounts of commented-out code and are not configured to run the new kernels by default, which reduces their immediate usefulness. I have provided specific suggestions to address these correctness and maintainability issues.

Comment on lines +274 to +275
stride_obs = O.stride(1)
stride_oh = O.stride(2)

critical

The strides for the output tensor O are calculated incorrectly. For an output tensor O of shape (batch, num_heads, Lv), stride_obs should correspond to the batch dimension and stride_oh to the head dimension. You are using O.stride(1) and O.stride(2), but it should be O.stride(0) and O.stride(1). This will lead to incorrect memory access and wrong results.

Suggested change
- stride_obs = O.stride(1)
- stride_oh = O.stride(2)
+ stride_obs = O.stride(0)
+ stride_oh = O.stride(1)

input_stride_1 = input_ptr.stride(1)
input_stride_2 = input_ptr.stride(2)

output_ptr = torch.zeros(size=(token_num, token_num), dtype=torch.float32, device='cpu')

critical

The output tensor output_ptr is initialized with an incorrect shape of (token_num, token_num). According to the logic and the assertions in the reference moe_sum_reduce_triton function, the output shape should be (token_num, hidden_dim). This will cause incorrect results and potentially memory corruption.

Suggested change
- output_ptr = torch.zeros(size=(token_num, token_num), dtype=torch.float32, device='cpu')
+ output_ptr = torch.zeros(size=(token_num, hidden_dim), dtype=torch.float32, device='cpu')

BLOCK_DIM = hidden_dim // 2
NUM_STAGE = 1

grid = (npu_num_core, 1, 1)

critical

The grid for launching the kernel is configured as (npu_num_core, 1, 1). However, the kernel uses tl.program_id(1) to index into the hidden_dim. With the second dimension of the grid being 1, dim_block_id will always be 0. This means the kernel will only process the first BLOCK_DIM elements of the hidden_dim, leading to incomplete and incorrect results. The grid should be 2D to cover the entire hidden_dim.

Suggested change
- grid = (npu_num_core, 1, 1)
+ grid = (npu_num_core, triton.cdiv(hidden_dim, BLOCK_DIM))

Comment on lines +172 to +221
# Calling scenario of this Triton operator (sglang v0.4.8 python/sglang/srt/layers/attention/triton_ops/decode_attention.py)
def _decode_softmax_reducev_fwd(
    logits,
    lse,
    q,
    o,
    v_buffer,
    kv_indptr,
    num_kv_splits,
    max_kv_splits,
):
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    MAX_KV_SPLITS = max_kv_splits

    extra_kargs = {}
    if _is_hip:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}

    # grid = (batch, head_num)

    # Add a new input parameter BATCH_GROUP_SIZE to control the amount of data processed per block;
    # meanwhile, remove the head_num grid dimension and integrate it into the operator,
    # processing all heads within the same block using a for-loop.
    small_batch = batch // 5
    grid = (small_batch,)
    BATCH_GROUP_SIZE = (batch + small_batch - 1) // small_batch
    _fwd_kernel_stage2[grid](
        logits,
        lse,
        o,
        kv_indptr,
        num_kv_splits,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        MAX_KV_SPLITS=MAX_KV_SPLITS,
        MIN_BLOCK_KV=_MIN_BLOCK_KV,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
        num_warps=4,
        num_stages=2,
        BATCH_GROUP_SIZE=BATCH_GROUP_SIZE,  # new input parameter
        num_heads=head_num,  # new input parameter
        **extra_kargs,
    )

high

This function _decode_softmax_reducev_fwd is not called within the file. It also contains references to undefined variables which will cause runtime errors:

  • _is_hip on line 190.
  • _MIN_BLOCK_KV on line 213.

If this function is intended for future use or as a reference, these errors should be fixed. If it's not needed, it should be removed to improve code clarity.
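
If the function is kept as a reference, the missing names would need module-level definitions; a possible stub (values assumed and to be verified against the upstream sglang file):

# Assumed module-level stubs so the reference function is at least importable;
# take the exact values/imports from the upstream sglang decode_attention.py.
_is_hip = False      # upstream derives this from the platform at import time
_MIN_BLOCK_KV = 32   # matches the MIN_BLOCK_KV value used in test_fwd_kernel_stage2 below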

Comment on lines +250 to +328
def fn_triton(grid, input_data):
    # _fwd_kernel_stage2_npu[grid](**input_data)

    # The following code is used for comparison with the original operator.
    _fwd_kernel_stage2[grid](**input_data)
    return


def test_fwd_kernel_stage2():
    batch = 160
    num_heads = 8
    MAX_KV_SPLITS = 16
    MIN_BLOCK_KV = 32
    BLOCK_DV = 128
    Lv = 128

    Mid_O = torch.randn(size=(batch, num_heads, MAX_KV_SPLITS, Lv), dtype=torch.float32, device='cpu')
    Mid_O_1 = torch.randn(size=(batch, num_heads, MAX_KV_SPLITS), dtype=torch.float32, device='cpu')
    O = torch.randn(size=(batch, num_heads, Lv), dtype=torch.bfloat16, device='cpu')
    kv_indptr = torch.arange(0, batch + 1, dtype=torch.int32, device='cpu')
    num_kv_splits = torch.full(size=(batch,), fill_value=8, dtype=torch.int32, device='cpu')
    sink_ptr = torch.zeros(size=(8,), dtype=torch.float32, device='cpu')
    stride_mid_ob = Mid_O.stride(0)
    stride_mid_oh = Mid_O.stride(1)
    stride_mid_os = Mid_O.stride(2)
    stride_obs = O.stride(1)
    stride_oh = O.stride(2)
    HAS_SINK = 0

    device = torch.npu.current_device()
    device_properties = driver.active.utils.get_device_properties(device)
    npu_num_core = device_properties["num_vectorcore"]
    BATCH_GROUP_SIZE = (batch + npu_num_core - 1) // npu_num_core

    grid = (npu_num_core, 1,)
    # input_data = {
    #     'Mid_O': Mid_O,
    #     'Mid_O_1': Mid_O_1,
    #     'O': O,
    #     'kv_indptr': kv_indptr,
    #     'num_kv_splits': num_kv_splits,
    #     'stride_mid_ob': stride_mid_ob,
    #     'stride_mid_oh': stride_mid_oh,
    #     'stride_mid_os': stride_mid_os,
    #     'stride_obs': stride_obs,
    #     'stride_oh': stride_oh,
    #     'MAX_KV_SPLITS': MAX_KV_SPLITS,
    #     'MIN_BLOCK_KV': MIN_BLOCK_KV,
    #     'BLOCK_DV': BLOCK_DV,
    #     'Lv': Lv,
    #     'sink_ptr': sink_ptr,
    #     'HAS_SINK': HAS_SINK,
    #     'BATCH_GROUP_SIZE': BATCH_GROUP_SIZE,
    #     'num_heads': num_heads
    # }

    # The following code is used to compare the performance of the original operator.
    # Please modify the operator call in fn_triton accordingly.
    grid = (batch, num_heads,)
    input_data = {
        'Mid_O': Mid_O,
        'Mid_O_1': Mid_O_1,
        'O': O,
        'kv_indptr': kv_indptr,
        'num_kv_splits': num_kv_splits,
        'stride_mid_ob': stride_mid_ob,
        'stride_mid_oh': stride_mid_oh,
        'stride_mid_os': stride_mid_os,
        'stride_obs': stride_obs,
        'stride_oh': stride_oh,
        'MAX_KV_SPLITS': MAX_KV_SPLITS,
        'MIN_BLOCK_KV': MIN_BLOCK_KV,
        'BLOCK_DV': BLOCK_DV,
        'Lv': Lv,
        'sink_ptr': sink_ptr,
        'HAS_SINK': HAS_SINK,
    }

    input_data = convert_tensor_with_device_type(input_data, device_type='npu')

    profiling_test(fn_triton, (grid, input_data))

high

The test setup in test_fwd_kernel_stage2 and fn_triton is confusing and doesn't test the new _fwd_kernel_stage2_npu kernel by default.

  • fn_triton calls the original _fwd_kernel_stage2 kernel, while the call to the new NPU kernel is commented out.
  • test_fwd_kernel_stage2 prepares input_data for the original kernel, while the data preparation for the new kernel is commented out.

This makes it difficult to run the test for the new kernel. Please refactor the test to make it easy to test and profile the new kernel. For example, you could have separate test functions for the original and the new kernel, or use a parameter to switch between them. Also, a correctness check comparing the outputs of both kernels would be very valuable.
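
One possible shape for such a refactor (a sketch only; build_stage2_inputs is a hypothetical helper that would wrap the tensor setup shown above and return the grid and kwargs for the selected variant):

import pytest

@pytest.mark.parametrize("use_npu_kernel", [False, True])
def test_fwd_kernel_stage2(use_npu_kernel):
    grid, input_data = build_stage2_inputs(npu=use_npu_kernel)  # hypothetical helper
    input_data = convert_tensor_with_device_type(input_data, device_type='npu')

    kernel = _fwd_kernel_stage2_npu if use_npu_kernel else _fwd_kernel_stage2
    kernel[grid](**input_data)

    # Correctness check (sketch): both variants write into O, so the NPU kernel's
    # output can be compared against a reference run of the original kernel, e.g.
    # torch.testing.assert_close(o_npu, o_ref, rtol=1e-3, atol=1e-3)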

Comment on lines 6 to 11
import sys
import pytest
import triton
import torch
import triton.language as tl
import os

medium

The modules sys (line 6), pytest (line 7), and os (line 11) are imported but never used in this file. Please remove these unused imports to improve code clarity.

Comment on lines 6 to 7
import sys
import pytest

medium

The modules sys and pytest are imported but never used in this file. Please remove them to keep the code clean.

Comment on lines 29 to 44
def _moe_sum_reduce_kernel_vec(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: tl.constexpr,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):

medium

The kernel _moe_sum_reduce_kernel_vec accepts several parameters that are unused: input_stride_2 (line 33), output_stride_1 (line 36), and NUM_STAGE (line 43). If these are not necessary for the computation, they should be removed from the function signature and the call site to simplify the code. If they are required for API compatibility, please add a comment explaining why they are present.
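
For reference, the trimmed signature could look like the following (sketch; keep the parameters instead if they are needed for drop-in compatibility with the original kernel's call site):

@triton.jit
def _moe_sum_reduce_kernel_vec(
    input_ptr,
    input_stride_0,
    input_stride_1,
    output_ptr,
    output_stride_0,
    token_num: int,
    topk_num: tl.constexpr,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
):
    ...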
