-
Notifications
You must be signed in to change notification settings - Fork 376
WIP: intial checkin for cutile kernel support #3963
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
lanluo-nvidia
wants to merge
3
commits into
main
Choose a base branch
from
lluo/cutile_13.1
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
+941
−19
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/automatic_plugin/cutile/attention.py 2025-12-12 22:37:00.909120+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/automatic_plugin/cutile/attention.py 2025-12-12 22:37:42.243426+00:00
@@ -16,20 +16,25 @@
ConstBool = ct.Constant[bool]
# --- FMHA Kernel Implementation ---
@ct.kernel(occupancy=2)
-def fmha_kernel(Q, K, V, Out,
- qk_scale: float,
- input_pos: int,
- TILE_D: ConstInt, # TILE_D = hidden_size
- H: ConstInt,
- TILE_M: ConstInt,
- TILE_N: ConstInt,
- QUERY_GROUP_SIZE: ConstInt,
- CAUSAL: ConstBool,
- EVEN_K: ConstBool):
+def fmha_kernel(
+ Q,
+ K,
+ V,
+ Out,
+ qk_scale: float,
+ input_pos: int,
+ TILE_D: ConstInt, # TILE_D = hidden_size
+ H: ConstInt,
+ TILE_M: ConstInt,
+ TILE_N: ConstInt,
+ QUERY_GROUP_SIZE: ConstInt,
+ CAUSAL: ConstBool,
+ EVEN_K: ConstBool,
+):
"""
cuTile kernel for Fused Multi-Head Attention (FMHA).
Computes attention output for a specific batch item and head, using tiling and online softmax.
"""
# Map block IDs to batch and head indices
@@ -57,11 +62,13 @@
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)
# Load query tile for this batch, head, and M-chunk
q = ct.load(
Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
- ).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D]
+ ).reshape(
+ (TILE_M, TILE_D)
+ ) # [TILE_M, TILE_D]
# loop over k, v and update accumulator
m_end = input_pos + (bid_x + 1) * TILE_M
k_seqlen = K.shape[2]
if CAUSAL:
@@ -76,16 +83,18 @@
# Loop over K, V blocks (N-dimension chunks)
for j in range(0, Tc):
# --- Compute QK product ---
k = ct.load(
- K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
+ K,
+ index=(batch_idx, off_kv_h, 0, j),
+ shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2),
latency=2,
)
k = k.reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N]
- qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32)
+ qk = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32)
qk = ct.mma(q, k, qk) # [TILE_M, TILE_N]
# --- Apply Causal Masking ---
if (CAUSAL or not EVEN_K) and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
@@ -113,16 +122,20 @@
# scale acc
acc = acc * alpha # [TILE_M, TILE_N]
# --- Compute PV product ---
v = ct.load(
- V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
+ V,
+ index=(batch_idx, off_kv_h, j, 0),
+ shape=(1, 1, TILE_N, TILE_D),
latency=4,
- ).reshape((TILE_N, TILE_D)) # [TILE_N, TILE_D]
+ ).reshape(
+ (TILE_N, TILE_D)
+ ) # [TILE_N, TILE_D]
p = p.astype(Q.dtype)
acc = ct.mma(p, v, acc) # [TILE_M, TILE_N]
m_i = m_ij # [TILE_M, 1]
# --- Final Normalization and Store ---
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
- ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
\ No newline at end of file
+ ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/automatic_plugin/cutile/matmul.py 2025-12-12 22:37:00.909120+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/automatic_plugin/cutile/matmul.py 2025-12-12 22:37:42.315077+00:00
@@ -23,14 +23,18 @@
bid_n = (bid % num_bid_in_group) // group_size_m
return bid_m, bid_n
@ct.kernel(num_ctas=ct.ByTarget(sm_100=2))
-def matmul_kernel(A, B, C,
- tm: ConstInt, # Tile size along M dimension (rows of C)
- tn: ConstInt, # Tile size along N dimension (columns of C)
- tk: ConstInt): # Tile size along K dimension (inner product dimension)
+def matmul_kernel(
+ A,
+ B,
+ C,
+ tm: ConstInt, # Tile size along M dimension (rows of C)
+ tn: ConstInt, # Tile size along N dimension (columns of C)
+ tk: ConstInt,
+): # Tile size along K dimension (inner product dimension)
"""
cuTile kernel for performing matrix multiplication C = A @ B.
This kernel uses a tiled approach, where each CUDA thread block (CTA)
computes a `tm` x `tn` tile of the output matrix C. The computation
@@ -72,16 +76,20 @@
# are loaded, multiplied, and accumulated.
for k in range(num_tiles_k):
# Load tile from matrix A.
# The `index=(bidx, k_tile_idx)` specifies which (M-tile, K-tile) to load
# from global memory A. `shape=(tm, tk)` defines the size of this tile.
- a = ct.load(A, index=(bidx, k), shape=(tm, tk), padding_mode=zero_pad).astype(dtype)
+ a = ct.load(A, index=(bidx, k), shape=(tm, tk), padding_mode=zero_pad).astype(
+ dtype
+ )
# Load tile from matrix B.
# The `index=(k_tile_idx, bidy)` specifies which (K-tile, N-tile) to load
# from global memory B. `shape=(tk, tn)` defines the size of this tile.
- b = ct.load(B, index=(k, bidy), shape=(tk, tn), padding_mode=zero_pad).astype(dtype)
+ b = ct.load(B, index=(k, bidy), shape=(tk, tn), padding_mode=zero_pad).astype(
+ dtype
+ )
# Perform Matrix Multiplication for the current tiles.
# `ct.mma` computes the product of the two loaded tiles and accumulates the result.
accumulator = ct.mma(a, b, accumulator)
@@ -93,13 +101,13 @@
# The `(bidx, bidy)` directly corresponds to the tile's position in the 2D output matrix.
ct.store(C, index=(bidx, bidy), tile=accumulator)
@ct.kernel
-def matmul_split_k_kernel(A, B, C, LOCKS, COUNTS,
- tm: ConstInt, tn: ConstInt, tk: ConstInt,
- SPLIT_K: ConstInt):
+def matmul_split_k_kernel(
+ A, B, C, LOCKS, COUNTS, tm: ConstInt, tn: ConstInt, tk: ConstInt, SPLIT_K: ConstInt
+):
GROUP_SIZE_M = 8
M = A.shape[0]
N = B.shape[1]
bidx, bidy = swizzle_2d(M, N, tm, tn, GROUP_SIZE_M)
bidz = ct.bid(1)
@@ -110,20 +118,25 @@
# Convert fp32 to tf32 to use tensorcore
dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype
for k in range(bidz, num_tiles, SPLIT_K):
- a = ct.load(A, index=(bidx, k), shape=(tm, tk),
- padding_mode=zero_pad).astype(dtype)
- b = ct.load(B, index=(k, bidy), shape=(tk, tn),
- padding_mode=zero_pad).astype(dtype)
+ a = ct.load(A, index=(bidx, k), shape=(tm, tk), padding_mode=zero_pad).astype(
+ dtype
+ )
+ b = ct.load(B, index=(k, bidy), shape=(tk, tn), padding_mode=zero_pad).astype(
+ dtype
+ )
sum = ct.mma(a, b, sum)
sum = ct.astype(sum, C.dtype)
lock_offset = ct.bid(0)
count_offset = lock_offset
- while ct.atomic_cas(LOCKS, lock_offset, 0, 1, memory_order=ct.MemoryOrder.ACQUIRE) == 1:
+ while (
+ ct.atomic_cas(LOCKS, lock_offset, 0, 1, memory_order=ct.MemoryOrder.ACQUIRE)
+ == 1
+ ):
pass
count = ct.gather(COUNTS, count_offset)
if count == 0:
ct.store(C, index=(bidx, bidy), tile=sum)
else:
@@ -154,19 +167,23 @@
zero_pad = ct.PaddingMode.ZERO
# K-dimension loop
for k in range(num_k_tiles):
# Load tiles with 3D index and 3D shape
# A is (Batch, M, K), load (1, tm, tk) tile
- a = ct.load(A, index=(pid_batch, pidx, k), shape=(1, tm, tk), padding_mode=zero_pad)
+ a = ct.load(
+ A, index=(pid_batch, pidx, k), shape=(1, tm, tk), padding_mode=zero_pad
+ )
a = ct.reshape(a, (tm, tk)) # Reshape to 2D for ct.mma
# B is (Batch, K, N), load (1, tk, tn) tile
- b = ct.load(B, index=(pid_batch, k, pidy), shape=(1, tk, tn), padding_mode=zero_pad)
+ b = ct.load(
+ B, index=(pid_batch, k, pidy), shape=(1, tk, tn), padding_mode=zero_pad
+ )
b = ct.reshape(b, (tk, tn)) # Reshape to 2D for ct.mma
accumulator = ct.mma(a, b, acc=accumulator)
# Convert to output dtype and store
result = ct.astype(accumulator, C.dtype)
# Store with 3D index and 3D shape, C is (Batch, M, N)
result_3d = ct.reshape(result, (1, tm, tn))
- ct.store(C, index=(pid_batch, pidx, pidy), tile=result_3d)
\ No newline at end of file
+ ct.store(C, index=(pid_batch, pidx, pidy), tile=result_3d)
--- /home/runner/work/TensorRT/TensorRT/tools/llm/torchtrt_ext/attention.py 2025-12-12 22:37:00.923120+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/torchtrt_ext/attention.py 2025-12-12 22:37:47.051575+00:00
@@ -16,20 +16,25 @@
ConstBool = ct.Constant[bool]
# --- FMHA Kernel Implementation ---
@ct.kernel(occupancy=2)
-def fmha_kernel(Q, K, V, Out,
- qk_scale: float,
- input_pos: int,
- TILE_D: ConstInt, # TILE_D = hidden_size
- H: ConstInt,
- TILE_M: ConstInt,
- TILE_N: ConstInt,
- QUERY_GROUP_SIZE: ConstInt,
- CAUSAL: ConstBool,
- EVEN_K: ConstBool):
+def fmha_kernel(
+ Q,
+ K,
+ V,
+ Out,
+ qk_scale: float,
+ input_pos: int,
+ TILE_D: ConstInt, # TILE_D = hidden_size
+ H: ConstInt,
+ TILE_M: ConstInt,
+ TILE_N: ConstInt,
+ QUERY_GROUP_SIZE: ConstInt,
+ CAUSAL: ConstBool,
+ EVEN_K: ConstBool,
+):
"""
cuTile kernel for Fused Multi-Head Attention (FMHA).
Computes attention output for a specific batch item and head, using tiling and online softmax.
"""
# Map block IDs to batch and head indices
@@ -57,11 +62,13 @@
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)
# Load query tile for this batch, head, and M-chunk
q = ct.load(
Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
- ).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D]
+ ).reshape(
+ (TILE_M, TILE_D)
+ ) # [TILE_M, TILE_D]
# loop over k, v and update accumulator
m_end = input_pos + (bid_x + 1) * TILE_M
k_seqlen = K.shape[2]
if CAUSAL:
@@ -76,16 +83,18 @@
# Loop over K, V blocks (N-dimension chunks)
for j in range(0, Tc):
# --- Compute QK product ---
k = ct.load(
- K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
+ K,
+ index=(batch_idx, off_kv_h, 0, j),
+ shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2),
latency=2,
)
k = k.reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N]
- qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32)
+ qk = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32)
qk = ct.mma(q, k, qk) # [TILE_M, TILE_N]
# --- Apply Causal Masking ---
if (CAUSAL or not EVEN_K) and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
@@ -113,16 +122,20 @@
# scale acc
acc = acc * alpha # [TILE_M, TILE_N]
# --- Compute PV product ---
v = ct.load(
- V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
+ V,
+ index=(batch_idx, off_kv_h, j, 0),
+ shape=(1, 1, TILE_N, TILE_D),
latency=4,
- ).reshape((TILE_N, TILE_D)) # [TILE_N, TILE_D]
+ ).reshape(
+ (TILE_N, TILE_D)
+ ) # [TILE_N, TILE_D]
p = p.astype(Q.dtype)
acc = ct.mma(p, v, acc) # [TILE_M, TILE_N]
m_i = m_ij # [TILE_M, 1]
# --- Final Normalization and Store ---
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
- ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
\ No newline at end of file
+ ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tools/llm/torchtrt_ext/attention.py 2025-12-12 22:57:38.287784+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/torchtrt_ext/attention.py 2025-12-12 22:58:21.945158+00:00
@@ -16,20 +16,25 @@
ConstBool = ct.Constant[bool]
# --- FMHA Kernel Implementation ---
@ct.kernel(occupancy=2)
-def fmha_kernel(Q, K, V, Out,
- qk_scale: float,
- input_pos: int,
- TILE_D: ConstInt, # TILE_D = hidden_size
- H: ConstInt,
- TILE_M: ConstInt,
- TILE_N: ConstInt,
- QUERY_GROUP_SIZE: ConstInt,
- CAUSAL: ConstBool,
- EVEN_K: ConstBool):
+def fmha_kernel(
+ Q,
+ K,
+ V,
+ Out,
+ qk_scale: float,
+ input_pos: int,
+ TILE_D: ConstInt, # TILE_D = hidden_size
+ H: ConstInt,
+ TILE_M: ConstInt,
+ TILE_N: ConstInt,
+ QUERY_GROUP_SIZE: ConstInt,
+ CAUSAL: ConstBool,
+ EVEN_K: ConstBool,
+):
"""
cuTile kernel for Fused Multi-Head Attention (FMHA).
Computes attention output for a specific batch item and head, using tiling and online softmax.
"""
# Map block IDs to batch and head indices
@@ -57,11 +62,13 @@
acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)
# Load query tile for this batch, head, and M-chunk
q = ct.load(
Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
- ).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D]
+ ).reshape(
+ (TILE_M, TILE_D)
+ ) # [TILE_M, TILE_D]
# loop over k, v and update accumulator
m_end = input_pos + (bid_x + 1) * TILE_M
k_seqlen = K.shape[2]
if CAUSAL:
@@ -76,16 +83,18 @@
# Loop over K, V blocks (N-dimension chunks)
for j in range(0, Tc):
# --- Compute QK product ---
k = ct.load(
- K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
+ K,
+ index=(batch_idx, off_kv_h, 0, j),
+ shape=(1, 1, TILE_D, TILE_N),
order=(0, 1, 3, 2),
latency=2,
)
k = k.reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N]
- qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32)
+ qk = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32)
qk = ct.mma(q, k, qk) # [TILE_M, TILE_N]
# --- Apply Causal Masking ---
if (CAUSAL or not EVEN_K) and j >= mask_start:
offs_n = j * TILE_N + offs_n_tile
@@ -113,16 +122,20 @@
# scale acc
acc = acc * alpha # [TILE_M, TILE_N]
# --- Compute PV product ---
v = ct.load(
- V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
+ V,
+ index=(batch_idx, off_kv_h, j, 0),
+ shape=(1, 1, TILE_N, TILE_D),
latency=4,
- ).reshape((TILE_N, TILE_D)) # [TILE_N, TILE_D]
+ ).reshape(
+ (TILE_N, TILE_D)
+ ) # [TILE_N, TILE_D]
p = p.astype(Q.dtype)
acc = ct.mma(p, v, acc) # [TILE_M, TILE_N]
m_i = m_ij # [TILE_M, 1]
# --- Final Normalization and Store ---
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
- ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
\ No newline at end of file
+ ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
cla signed
component: tests
Issues re: Tests
WIP
Work is in progress, pull request should not be merged yet
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: