[https://nvbugs/5456493][feat] add 6KD fp8 dense #9174
base: main
Conversation
Signed-off-by: CarstyYou <[email protected]>
📝 Walkthrough

Introduces a comprehensive SM120 CUDA kernel implementation for FP8 block-scaled GEMM with TMA load/store operations and barrier-synchronized producer-consumer stages, and integrates it into the existing FP8 GEMM dispatch pipeline and the PyTorch linear layer for SM120-specific execution paths.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant Host as Host (PyTorch)
    participant Dispatch as FP8 GEMM Dispatcher
    participant SM120 as SM120 Kernel
    participant Device as GPU Device

    Host->>Dispatch: fp8_gemm_run (arch=120)
    Dispatch->>Dispatch: Check architecture
    alt arch == 120
        Dispatch->>Dispatch: gemm_dispatch_sm120()
        Dispatch->>SM120: Setup kernel params & TMA descriptors
        Dispatch->>Device: cudaLaunchKernelEx (grid, block, smem)
    else arch == 89
        Dispatch->>Dispatch: gemm_dispatch_sm89()
    end
    Device->>SM120: Threads prefetch TMA descriptors
    SM120->>Device: Prefetch kernel configuration
    rect rgba(100, 150, 200, 0.3)
        note right of SM120: Producer-Consumer Stages
        SM120->>Device: Load A, B, SFA, SFB via TMA
        SM120->>Device: Barrier sync (producer ready)
        SM120->>Device: Compute (consumer via TiledMMA)
        SM120->>Device: Barrier sync (update phase)
    end
    rect rgba(150, 100, 200, 0.3)
        note right of SM120: Epilogue & Writeback
        SM120->>SM120: epilogue_with_smem()
        SM120->>Device: TMA store D (results)
        SM120->>Device: Barrier finalize
    end
    Device-->>Host: Kernel complete, results in global memory
```
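As a rough host-side sketch of the dispatch step shown above (the function names mirror those discussed in the review comments below, but the bodies here are placeholders, not the PR's actual implementation):

```cpp
#include <cstdio>

// Placeholder stand-ins for the real dispatchers described in the review.
int getSMVersion() { return 120; }              // assumption: queried from the device in the real code
void gemm_dispatch_sm89()  { std::puts("Ada-style FP8 blockscale path"); }
void gemm_dispatch_sm120() { std::puts("SM120 TMA + block-scaled FP8 path"); }

// fp8_gemm_run-style entry point: pick the kernel family by SM architecture.
void fp8_gemm_run_sketch()
{
    int const sm = getSMVersion();
    if (sm == 120)
        gemm_dispatch_sm120();   // new SM120 path added by this PR
    else if (sm == 89)
        gemm_dispatch_sm89();    // existing SM89 path
    // other architectures follow their existing paths (not shown here)
}

int main() { fp8_gemm_run_sketch(); return 0; }
```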
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Areas requiring special attention:
Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Actionable comments posted: 0
🧹 Nitpick comments (4)
tensorrt_llm/_torch/modules/linear.py (1)
25-27: SM120 FP8 blockscale path wiring looks consistent; minor SM version reuse nit

The new SM120 branch in `FP8BlockScalesLinearMethod.apply` (using `per_token_quant_and_transform` + `fp8_block_scaling_gemm`) and the extension of `post_load_weights` to call `resmooth_to_fp8_e8m0` / `transform_sf_into_required_layout` for `get_sm_version() == 120` look consistent with the SM120 C++ kernel's expectations (FP8 activations + int32 SFA/SFB layouts). Two small nits you might consider:

- `is_sm_100f()` already calls `get_sm_version()` internally, and you now also call `get_sm_version()` explicitly in both `apply` and `post_load_weights`. If this ends up on a hot path, caching the SM version once per process (or passing it down) would avoid redundant device-property queries.
- For SM120 you unconditionally use the blockscale GEMM path, ignoring `use_cute_dsl_blockscaling_mm` / `disable_deep_gemm`. That's fine if SM120 doesn't support deep-gemm yet, but it's worth a short comment if that's intentional, so it's obvious why SM120 is treated differently from SM100f.

Also applies to: 29-29, 631-649, 735-737
cpp/tensorrt_llm/thop/fp8BlockScalingGemm.cpp (1)
115-151: RTX 6000 (SM120) GEMM entry: int32 scales and input validation

The SM120 path here looks structurally correct: FP8 mat dtypes, int32 scale dtypes, M/N/K bounds, and K%128 / N%16 constraints all match the new kernel's requirements, and the SM switch now correctly routes SM 120 to this function.

Two things to double-check:

First, scales are validated as `Int32` but then passed to the runner as

```cpp
float const* mat1ScalePtr = reinterpret_cast<float const*>(mat1Scale.data_ptr());
float const* mat2ScalePtr = reinterpret_cast<float const*>(mat2Scale.data_ptr());
```

This relies on the downstream path (`CutlassFp8BlockScaleGemmRunner` → `fp8_gemm_run` → `gemm_dispatch_sm120`) treating the pointer purely as an opaque 32-bit stream and reinterpreting it to `int32_t*` before use. Please confirm there is no code in the runner stack that assumes FP32 semantics for scales on SM120; otherwise we'd be feeding it bit-patterns from int32 tensors. A small standalone illustration follows this comment.

Second, unlike the Ada/Hopper paths, this function doesn't use the `CHECK_INPUT` / `CHECK_TH_CUDA` helpers for `mat1Scale` / `mat2Scale` (only `scalar_type()` is checked). Given these scales are produced by your own quantization utilities and should already be CUDA-resident and contiguous, this is probably fine, but if you want parity with the other paths you might consider adding the same device/contiguity checks.

Also applies to: 240-251
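To make the first concern concrete, here is a minimal, self-contained sketch (synthetic scale values, standard C++ only, not the PR's code) of why the `float const*` cast is only safe as long as every consumer reinterprets the pointer back to `int32_t const*` rather than reading FP32 values through it:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

int main()
{
    // Stand-in for the int32-encoded block scales carried by mat1Scale /
    // mat2Scale (values are synthetic, chosen only for illustration).
    std::vector<int32_t> scales = {1, 16, 127};

    // What the thop entry point does: forward them typed as float const*.
    float const* asFloat = reinterpret_cast<float const*>(scales.data());

    // Safe use (what the SM120 runner stack must do): treat the pointer as an
    // opaque 32-bit stream and reinterpret it back to int32_t before reading.
    int32_t const* roundTripped = reinterpret_cast<int32_t const*>(asFloat);

    for (size_t i = 0; i < scales.size(); ++i)
    {
        // Reading the same bits as FP32 (done via memcpy so the demo itself
        // stays well-defined) yields denormal garbage, not usable scales.
        float misread;
        std::memcpy(&misread, &scales[i], sizeof(float));
        std::printf("int32 %d -> round-tripped %d, misread as float %g\n",
            static_cast<int>(scales[i]), static_cast<int>(roundTripped[i]),
            static_cast<double>(misread));
    }
    return 0;
}
```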
cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh (1)
30-30: SM120 dispatch integration is coherent; consider documenting ld assumptions and using `num_device_sms`

The new SM120 path is wired cleanly: `gemm_dispatch_sm120` correctly wraps `SM120BlockScaledKernel<KT>`, sets dynamic shared memory to `GemmKernel::kSmemSize`, and `fp8_gemm_run` now chooses sm89 vs sm120 kernels explicitly based on `getSMVersion()`. A couple of small points to keep in mind:

- For SM120, `gemm_dispatch_sm120` doesn't take or use leading dimensions; it assumes the K-major, contiguous layouts that `fp8_block_scaling_gemm_rtx_6000` already passes (lda=ldb=k, ldd=n). If you ever reuse this dispatch path from another call site, it would be good either to assert `ld_a == shape_k` / `ld_b == shape_k` / `ld_d == shape_n` at the boundary (see the sketch below) or to extend the SM120 builder to honor arbitrary ld*.
- The `num_device_sms` argument here is only used to lazily initialize `kNumDeviceSMs` and doesn't affect grid selection for SM120, whereas other "deep gemm" paths feed it into autotuning. That's fine for now, but if you later want SM-count-aware autotuning for SM120, this is the place to thread it through.

Also applies to: 637-685, 1696-1705
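As a minimal sketch of the suggested boundary check (a hypothetical helper using plain `assert` rather than the repository's own checking macros):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical guard for the SM120 dispatch boundary: the kernel builder
// assumes K-major, densely packed A/B and a densely packed D, i.e.
// lda == ldb == k and ldd == n. Asserting this before calling
// gemm_dispatch_sm120 makes any future caller with padded strides fail loudly.
inline void check_sm120_contiguous_layout(
    int64_t ld_a, int64_t ld_b, int64_t ld_d, int64_t shape_k, int64_t shape_n)
{
    assert(ld_a == shape_k && "SM120 path expects A to be K-major and unpadded");
    assert(ld_b == shape_k && "SM120 path expects B to be K-major and unpadded");
    assert(ld_d == shape_n && "SM120 path expects D to be N-major and unpadded");
}
```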
cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/6kd_blockwise_gemm/sm120_utils.cuh (1)
37-373: SM120BlockScaledBuilder is cohesive; watch CopyAtomC element type if TMA store gets enabled

The builder nicely packages all SM120-specific types (FP8 inputs, int32 scale load, UE8M0 scale compute, BF16 accum/output), layouts, and TMA descriptors for A/B/SFA/SFB/D, and the SFA/SFB partitioning helpers align with the kernel's usage.

One small thing to keep in mind: `CopyAtomC` is currently defined as `using CopyAtomC = Copy_Atom<SM90_U32x2_STSM_N, cutlass::half_t>;` while `ElementD` is `cute::bfloat16_t`. Because the SM120 kernel currently uses `epilogue_with_smem` and the TMA store path is commented out, this mismatch is benign. If you later switch to the `tma_store` epilogue, it's worth revisiting `CopyAtomC` (and any associated smem layout assumptions) so the copy atom's value type matches the actual `ElementD` type.
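If the `tma_store` epilogue is ever revived, a compile-time guard along these lines would catch the mismatch at build time. The snippet below is a self-contained sketch with placeholder types standing in for the CUTE/CUTLASS ones; it is not the builder's actual code:

```cpp
#include <cstdint>
#include <type_traits>

// Self-contained stand-ins for the CUTE/CUTLASS types referenced by the
// builder (cute::Copy_Atom, cutlass::half_t, cute::bfloat16_t). Names and
// shapes here are illustrative only.
struct SM90_U32x2_STSM_N_Op {};
template <class Op, class ValType>
struct CopyAtom { using value_type = ValType; };
struct half_t { uint16_t bits; };
struct bfloat16_t { uint16_t bits; };

using ElementD = bfloat16_t;

// With the PR's current definition the value type is half_t; switching it to
// ElementD (as below) lets a static_assert guard the tma_store epilogue.
using CopyAtomC = CopyAtom<SM90_U32x2_STSM_N_Op, ElementD>;

static_assert(std::is_same_v<CopyAtomC::value_type, ElementD>,
    "CopyAtomC value type must match ElementD before enabling the TMA store epilogue");

int main() { return 0; }
```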
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/6kd_blockwise_gemm/sm120_fp8_gemm_1d2d.cuh (1 hunks)
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/6kd_blockwise_gemm/sm120_utils.cuh (1 hunks)
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh (3 hunks)
- cpp/tensorrt_llm/thop/fp8BlockScalingGemm.cpp (2 hunks)
- tensorrt_llm/_torch/modules/linear.py (3 hunks)
🧰 Additional context used
🧠 Learnings (8)
📓 Common learnings
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/multimem.h:20-30
Timestamp: 2025-09-23T15:13:48.819Z
Learning: TRT-LLM targets modern CUDA toolkits that support FP8 datatypes, so cuda_fp8.h can be included unconditionally without version guards in TRT-LLM code.
📚 Learning: 2025-08-08T22:03:40.707Z
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1198-1209
Timestamp: 2025-08-08T22:03:40.707Z
Learning: In the CUTLASS MoE kernels (cpp/tensorrt_llm/cutlass_extensions), when `layout_info.fusion` is set to `TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE`, the `router_scales` parameter must be non-null by design. The fused finalize kernel epilogue does not perform nullptr checks and requires valid router scales to function correctly. This is an implicit contract that callers must satisfy when enabling the FINALIZE fusion mode.
Applied to files:
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/6kd_blockwise_gemm/sm120_fp8_gemm_1d2d.cuh
- cpp/tensorrt_llm/thop/fp8BlockScalingGemm.cpp
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/6kd_blockwise_gemm/sm120_utils.cuh
📚 Learning: 2025-09-23T15:13:48.819Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/multimem.h:20-30
Timestamp: 2025-09-23T15:13:48.819Z
Learning: TRT-LLM targets modern CUDA toolkits that support FP8 datatypes, so cuda_fp8.h can be included unconditionally without version guards in TRT-LLM code.
Applied to files:
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh
- tensorrt_llm/_torch/modules/linear.py
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/6kd_blockwise_gemm/sm120_fp8_gemm_1d2d.cuh
- cpp/tensorrt_llm/thop/fp8BlockScalingGemm.cpp
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/6kd_blockwise_gemm/sm120_utils.cuh
📚 Learning: 2025-08-19T03:35:20.866Z
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4616-4626
Timestamp: 2025-08-19T03:35:20.866Z
Learning: In the MOE profiler TMA workspace preparation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu), the overlapping of TMA WS regions for NONE and FINALIZE variants is deliberate design to save memory space, as confirmed by djns99. The comment "reuse the same pointers to save space" reflects this intentional behavior.
Applied to files:
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/6kd_blockwise_gemm/sm120_fp8_gemm_1d2d.cuh
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/6kd_blockwise_gemm/sm120_utils.cuh
📚 Learning: 2025-09-23T15:01:00.070Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/config.cu:15-17
Timestamp: 2025-09-23T15:01:00.070Z
Learning: In TensorRT-LLM NCCL device kernels, the <sstream> header is not needed as an explicit include in config.cu because it's provided transitively through other headers. Local compilation testing confirms this works without the explicit include.
Applied to files:
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/6kd_blockwise_gemm/sm120_fp8_gemm_1d2d.cuh
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/6kd_blockwise_gemm/sm120_utils.cuh
📚 Learning: 2025-09-19T21:28:13.751Z
Learnt from: jhaotingc
Repo: NVIDIA/TensorRT-LLM PR: 7856
File: cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp:159-166
Timestamp: 2025-09-19T21:28:13.751Z
Learning: In TensorRT-LLM blockScaleMoe routing (cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu), the DeepSeek routing method performs reinterpret_cast<float*>(routingLogits) at line 89, which could cause issues if routing_logits are BF16. However, Qwen3-FP8 models use RenormalizeNaive routing method and are not affected by this dtype casting issue.
Applied to files:
cpp/tensorrt_llm/thop/fp8BlockScalingGemm.cpp
📚 Learning: 2025-08-08T05:10:38.906Z
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp:0-0
Timestamp: 2025-08-08T05:10:38.906Z
Learning: The ScaledAccPerRowBiasPerColScaleScatter fusion in CUTLASS extensions (cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp) is specifically designed for per-column scaling factors only, so it uses a fixed Stride<_0,_1,int64_t> rather than conditional stride logic.
Applied to files:
cpp/tensorrt_llm/thop/fp8BlockScalingGemm.cpp
📚 Learning: 2025-08-08T05:06:31.596Z
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp:36-36
Timestamp: 2025-08-08T05:06:31.596Z
Learning: CUTLASS extension files (under cpp/tensorrt_llm/cutlass_extensions/) follow CUTLASS coding style conventions, including using #pragma once instead of TRTLLM_ prefixed header guards, even though they are .hpp files.
Applied to files:
cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/6kd_blockwise_gemm/sm120_utils.cuh
🧬 Code graph analysis (1)
tensorrt_llm/_torch/modules/linear.py (3)
tensorrt_llm/quantization/utils/fp8_utils.py (3)
- per_token_quant_and_transform (447-520)
- resmooth_to_fp8_e8m0 (82-92)
- transform_sf_into_required_layout (169-217)

tensorrt_llm/_utils.py (2)
- get_sm_version (740-742)
- is_sm_100f (746-749)

cpp/tensorrt_llm/thop/fp8BlockScalingGemm.cpp (2)
- fp8_block_scaling_gemm (240-253)
- fp8_block_scaling_gemm (240-241)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (1)
cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/6kd_blockwise_gemm/sm120_fp8_gemm_1d2d.cuh (1)
26-519: SM120BlockScaledKernel structure looks sound and matches the SM90-style pipeline

The SM120 kernel wrapper is internally consistent: Params/Arguments construction, TMA prefetch, AB/SF double-buffering with dedicated producer/consumer barriers, the main K-loop (including math tail), and the epilogue path via shared memory all line up with the SM120 builder's layouts and staging parameters. I don't see any obvious correctness issues in the block-level orchestration.
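For readers less familiar with this staging pattern, the sketch below shows the same double-buffering idea in a heavily simplified form: two shared-memory buffers alternate roles, the load of tile t+1 is issued before the compute on tile t, and a barrier separates the phases. It uses plain global loads and `__syncthreads()` instead of the kernel's TMA copies and per-stage barriers, and the GEMV shape and names are illustrative only.

```cuda
#include <cuda_runtime.h>

// Simplified GEMV mainloop: y[row] = sum_k A[row, k] * x[k], one row per thread.
// Tiles of x are staged through two shared-memory buffers so the load of tile
// t+1 is issued before the compute on tile t -- a much weaker stand-in for the
// TMA + per-stage-barrier producer/consumer pipeline of the SM120 kernel.
// Example launch (assuming rows % kThreads == 0 and K % kTileK == 0):
//   double_buffered_gemv<128, 256><<<rows / 256, 256>>>(A, x, y, K);
template <int kTileK, int kThreads>
__global__ void double_buffered_gemv(float const* A, float const* x, float* y, int K)
{
    __shared__ float smem_x[2][kTileK];

    int const row = blockIdx.x * kThreads + threadIdx.x;
    int const tid = threadIdx.x;
    int const num_tiles = K / kTileK;

    // Prologue: cooperatively fill buffer 0 with the first tile of x.
    for (int i = tid; i < kTileK; i += kThreads)
        smem_x[0][i] = x[i];
    __syncthreads();

    float acc = 0.f;
    for (int t = 0; t < num_tiles; ++t)
    {
        int const cur = t & 1;
        int const nxt = cur ^ 1;

        // Producer step: start filling the other buffer with tile t+1.
        if (t + 1 < num_tiles)
            for (int i = tid; i < kTileK; i += kThreads)
                smem_x[nxt][i] = x[(t + 1) * kTileK + i];

        // Consumer step: each thread consumes the whole current tile for its row.
        for (int i = 0; i < kTileK; ++i)
            acc += A[row * K + t * kTileK + i] * smem_x[cur][i];

        // Barrier: tile t+1 is resident and tile t is fully read before the
        // two buffers swap roles in the next iteration.
        __syncthreads();
    }

    y[row] = acc;
}
```

In the actual SM120 kernel, the block-wide barrier's role is played by the dedicated per-stage producer/consumer barriers mentioned above, paired with asynchronous TMA loads rather than per-thread global reads.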
Signed-off-by: CarstyYou <[email protected]>
Signed-off-by: CarstyYou <[email protected]>
Signed-off-by: CarstyYou <[email protected]>
./bot run

@coderabbitai generate docstrings

✅ Actions performed: Initiated docstring generation; will generate only if new commits exist.

Caution: Docstrings generation FAILED. An unexpected error occurred while opening a pull request: Reference update failed - https://docs.github.com/rest/git/refs#create-a-reference

/bot run

PR_Github #24840 [ run ] triggered by Bot. Commit:

PR_Github #24841 Bot args parsing error: usage: /bot [-h]

/bot run

PR_Github #24842 [ run ] triggered by Bot. Commit:

PR_Github #24840 [ run ] completed with state
Summary by CodeRabbit
Description
Test Coverage
Has passed unit tests.

Has passed the GSM8K e2e precision check.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provide a user-friendly way for developers to interact with a Jenkins server.

Run `/bot [-h|--help]` to print this help message.

See details below for each supported subcommand.

`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Example: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.

kill

`kill`: Kill all running builds associated with the pull request.

skip

`skip --comment COMMENT`: Skip testing for the latest commit on the pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

`reuse-pipeline`: Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.