
Conversation

@nvchenghaoz nvchenghaoz (Collaborator) commented Nov 14, 2025

Summary by CodeRabbit

Release Notes

  • New Features
    • Added automatic CUDA hardware capability detection to enable device-specific acceleration on compatible GPUs
    • Implemented optimized fused CUDA kernel-based routing for improved inference performance

For Nemotron MOE, 1k/1k/8:

  • Baseline (without FP8 KV cache): 0.5945
  • With FP8 KV cache: 0.6075
  • With FP8 KV cache + cuBLAS kernel: 0.6417
  • With FP8 KV cache + CUDA (non-cuBLAS) kernel + NemotronHTopkRouter patch: 0.6733
  • With FP8 KV cache + cuBLAS kernel + NemotronHTopkRouter patch: 0.7288

coderabbitai bot commented Nov 14, 2025

📝 Walkthrough

Two files in the auto_deploy module are modified: quant.py adds runtime CUDA capability detection and conditional enable_cuda_core flag based on SM 8.9 or 12.0 support, narrowing the CUDA-core path to specific hardware; nemotron_h.py introduces a new optimized forward method for NemotronHTopkRouter using fused CUDA kernel-based top-k routing, registered via module patching.

Changes

Cohort / File(s) and summary:

  • CUDA Hardware Detection & Optimization — tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
    Adds runtime CUDA device capability detection, determines whether the device is SM 8.9 or SM 12.0, and sets the enable_cuda_core flag accordingly. Updates the CUDA-core path selection logic to require both input size ≤ 8 and enable_cuda_core; otherwise it falls back to cuBLAS (see the sketch after this list).
  • Model Routing Optimization — tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
    Introduces a new _nemotron_h_topk_router_forward method implementing fused CUDA kernel-based top-k routing. Workflow: reshapes inputs, computes router logits via a linear transformation, calls the noaux_tc_op kernel for top-k weights and indices, and returns the results. Registered via CUSTOM_MODULE_PATCHES for runtime patching of NemotronHTopkRouter.
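
As a rough illustration of the quant.py gating described above, here is a minimal dispatch sketch. The capability check mirrors the PR snippet quoted later in this review and the ≤ 8 threshold comes from the summary, while select_fp8_gemm_path, run_cuda_core_gemm, and run_cublas_gemm are hypothetical placeholders, and "input size" is assumed to mean the number of rows in the activation:

    import torch
    from typing import Callable

    def select_fp8_gemm_path(
        input_fp8: torch.Tensor,
        run_cuda_core_gemm: Callable[[torch.Tensor], torch.Tensor],
        run_cublas_gemm: Callable[[torch.Tensor], torch.Tensor],
        small_batch_threshold: int = 8,
    ) -> torch.Tensor:
        """Sketch only: CUDA-core kernel for small inputs on SM 8.9 / 12.0, else cuBLAS."""
        enable_cuda_core = False
        if torch.cuda.is_available():
            capability = torch.cuda.get_device_capability(0)
            # enable cuda core for sm89 and sm120
            enable_cuda_core = capability in ((8, 9), (12, 0))
        if enable_cuda_core and input_fp8.shape[0] <= small_batch_threshold:
            return run_cuda_core_gemm(input_fp8)
        return run_cublas_gemm(input_fp8)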

Sequence Diagram

sequenceDiagram
    participant caller as Caller
    participant forward as _nemotron_h_topk_router_forward
    participant reshape as Reshape Input
    participant linear as Router Linear
    participant kernel as noaux_tc_op Kernel
    participant return as Return Results
    
    caller->>forward: hidden_states
    forward->>reshape: reshape(hidden_states)
    reshape->>linear: reshaped_input
    linear->>linear: compute router logits
    linear->>kernel: logits
    rect rgba(100, 150, 200, 0.2)
        note over kernel: CUDA kernel execution<br/>(top-k selection)
    end
    kernel->>kernel: extract top-k weights & indices
    kernel->>return: weights, indices
    return->>caller: (indices, weights)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • quant.py: Review CUDA capability detection logic; verify SM 8.9 and 12.0 are correct target architectures and that input size threshold of ≤ 8 is appropriate
  • nemotron_h.py: Verify noaux_tc_op kernel call correctness; confirm module patching mechanism via CUSTOM_MODULE_PATCHES properly replaces the original forward method at runtime; check shape transformations and tensor operations

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Description check — ⚠️ Warning: The PR description is incomplete and does not follow the template; it lacks a Description section explaining what changes were made and why, and provides only benchmark results without context for the code modifications. Resolution: add a comprehensive Description section explaining the rationale for the changes to quant.py and the nemotron_h.py patch, and include a Test Coverage section listing the tests that validate the modifications.
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 75.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (1 passed)
  • Title check — ✅ Passed: The title clearly summarizes the main change (performance improvement for small batch sizes in AutoDeploy) and follows the required format with the [None][feat] prefix.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (1)

91-117: Fused router forward looks consistent with existing MOE interface; consider reshape instead of view.

The new _nemotron_h_topk_router_forward keeps the contract that self.gate(hidden_states) returns (topk_indices, topk_weights) in the shape expected by torch_ops.auto_deploy.torch_moe, without extra reshaping, which is aligned with the NemotronH MOE usage pattern. Based on learnings.

One minor robustness tweak: hidden_states = hidden_states.view(-1, self.config.hidden_size) assumes that hidden_states is contiguous. To avoid surprises if a non‑contiguous tensor ever reaches this router, using reshape (or .contiguous().view(...)) would be safer:

-    hidden_states = hidden_states.view(-1, self.config.hidden_size)
+    hidden_states = hidden_states.reshape(-1, self.config.hidden_size)

Behavior otherwise looks correct: logits are computed in fp32, and noaux_tc_op receives logits along with e_score_correction_bias, n_group, topk_group, top_k, and routed_scaling_factor in a sensible order.
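
For readers without the diff open, here is a hedged reconstruction of what the patched forward does, based only on the description in this comment. The kernel is passed in as a callable so the sketch stays self-contained; the attribute names, argument order, and return order of noaux_tc_op are assumptions, so the real implementation may differ:

    import torch
    import torch.nn.functional as F
    from typing import Callable, Tuple

    def fused_topk_router_forward(
        hidden_states: torch.Tensor,
        router_weight: torch.Tensor,  # [n_experts, hidden_size]
        e_score_correction_bias: torch.Tensor,
        n_group: int,
        topk_group: int,
        top_k: int,
        routed_scaling_factor: float,
        noaux_tc_op: Callable[..., Tuple[torch.Tensor, torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sketch of the fused routing path described above."""
        hidden_size = router_weight.shape[1]
        # reshape (not view) tolerates non-contiguous inputs, per the nitpick above
        hidden_states = hidden_states.reshape(-1, hidden_size)
        # router logits are computed in fp32
        router_logits = F.linear(hidden_states.float(), router_weight.float())
        # fused top-k selection; argument order follows the review description
        topk_weights, topk_indices = noaux_tc_op(
            router_logits,
            e_score_correction_bias,
            n_group,
            topk_group,
            top_k,
            routed_scaling_factor,
        )
        # gate contract: indices and weights go straight into torch_moe
        return topk_indices, topk_weights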

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cc4c980 and 2761887.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (2 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
Learnt from: jhaotingc
Repo: NVIDIA/TensorRT-LLM PR: 7856
File: cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp:159-166
Timestamp: 2025-09-19T21:28:13.751Z
Learning: In TensorRT-LLM blockScaleMoe routing (cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu), the DeepSeek routing method performs reinterpret_cast<float*>(routingLogits) at line 89, which could cause issues if routing_logits are BF16. However, Qwen3-FP8 models use RenormalizeNaive routing method and are not affected by this dtype casting issue.
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py:98-116
Timestamp: 2025-10-20T17:07:18.745Z
Learning: In NemotronH models (tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py), the gate (self.gate) returns topk_indices and topk_weights that are already in the correct shape to be passed directly to torch_ops.auto_deploy.torch_moe without needing to reshape them when hidden_states is flattened.
📚 Learning: 2025-09-23T15:13:48.819Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/multimem.h:20-30
Timestamp: 2025-09-23T15:13:48.819Z
Learning: TRT-LLM targets modern CUDA toolkits that support FP8 datatypes, so cuda_fp8.h can be included unconditionally without version guards in TRT-LLM code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
📚 Learning: 2025-09-23T15:12:38.312Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/thop/allreduceOp.cpp:352-446
Timestamp: 2025-09-23T15:12:38.312Z
Learning: In TensorRT-LLM NCCL device implementation, NCCL version 2.28+ requirements are handled at runtime in the nccl_device/config layer rather than with compile-time guards. This allows the allreduceOp to remain version-agnostic and delegates version compatibility validation to the appropriate lower-level components that can gracefully handle unsupported configurations.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
📚 Learning: 2025-10-20T17:07:18.745Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py:98-116
Timestamp: 2025-10-20T17:07:18.745Z
Learning: In NemotronH models (tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py), the gate (self.gate) returns topk_indices and topk_weights that are already in the correct shape to be passed directly to torch_ops.auto_deploy.torch_moe without needing to reshape them when hidden_states is flattened.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (1)
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h (3)
  • n_group (241-241)
  • topk_group (243-243)
  • top_k (240-240)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (1)

169-169: Patch registration for NemotronHTopkRouter forward looks correct.

Binding _nemotron_h_topk_router_forward via CUSTOM_MODULE_PATCHES["NemotronHTopkRouter"] is consistent with the other NemotronH patches and should seamlessly swap in the fused router implementation at load time.

Comment on lines +107 to +113
enable_cuda_core = False
if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability(torch.device("cuda:0"))
    # enable cuda core for sm89 and sm120
    enable_cuda_core = (capability[0] == 8 and capability[1] == 9) or (
        capability[0] == 12 and capability[1] == 0
    )
@coderabbitai coderabbitai bot commented Nov 14, 2025

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

Use the actual tensor device for capability checks (avoid hard‑coding cuda:0).

The enable_cuda_core gating looks good for constraining the CUDA‑core kernel to SM 8.9 and 12.0, but using torch.device("cuda:0") can mis-detect capabilities when the op runs on another GPU (e.g., cuda:1) in a heterogeneous multi‑GPU setup. That could enable the CUDA‑core path on an unsupported device or unnecessarily disable it on a supported one.

Consider deriving the device from the inputs instead, for example:

-    enable_cuda_core = False
-    if torch.cuda.is_available():
-        capability = torch.cuda.get_device_capability(torch.device("cuda:0"))
+    enable_cuda_core = False
+    if input.is_cuda:
+        capability = torch.cuda.get_device_capability(input.device)
         # enable cuda core for sm89 and sm120
         enable_cuda_core = (capability[0] == 8 and capability[1] == 9) or (
             capability[0] == 12 and capability[1] == 0
         )

(Using weight_fp8.device would also be fine if that’s more canonical here.)

Also, the inline comment # NOTE: this kernel work with n % 2 == 0 as well?? reads like a TODO/question; either clarify it or remove it once behavior is confirmed to avoid confusion later.

Also applies to: 116-118


🏁 Script executed:

#!/bin/bash
# Check if file exists and view the full context around the flagged lines
if [ -f "tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py" ]; then
  echo "=== File exists, showing context (lines 100-125) ==="
  sed -n '100,125p' "tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py"
  echo ""
  echo "=== Full function signature ==="
  head -120 "tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py" | tail -30
else
  echo "File not found"
fi

Length of output: 2603


🏁 Script executed:

#!/bin/bash
# Check function definition and inputs to understand what tensors are available
echo "=== Searching for function containing line 107 ==="
rg -n "def " "tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py" | head -20

Length of output: 688


🏁 Script executed:

#!/bin/bash
# Look for how this function is called to understand multi-GPU context
echo "=== Searching for imports and usage patterns ==="
rg -n "torch.cuda.get_device_capability|enable_cuda_core" "tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py" -A2 -B2

Length of output: 814


🏁 Script executed:

#!/bin/bash
# Check if input tensors are available and their device info
echo "=== Broader context to understand tensor parameters ==="
sed -n '80,140p' "tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py"

Length of output: 2409


🏁 Script executed:

#!/bin/bash
# Get the full function signature of trtllm_quant_fp8_linear
echo "=== Full function signature (lines 56-75) ==="
sed -n '56,75p' "tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py"

Length of output: 793


🏁 Script executed:

#!/bin/bash
# Check for device-related assertions or documentation in the function
echo "=== Searching for device assertions or comments ==="
rg -n "device|cuda|\.to\(|\.cuda\(\)" "tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py" -A1 -B1 | head -50

Length of output: 1909


🏁 Script executed:

#!/bin/bash
# Search for callers of this function to understand multi-GPU usage patterns
echo "=== Searching for callers of trtllm_quant_fp8_linear ==="
rg -n "trtllm_quant_fp8_linear" --type py -g '!*.pyc' | grep -v "def trtllm_quant_fp8_linear" | head -20

Length of output: 562


🏁 Script executed:

#!/bin/bash
# Check if there's any device synchronization or multi-GPU handling
echo "=== Checking for distributed or multi-GPU patterns ==="
rg -n "DataParallel|DistributedDataParallel|distributed\|all_reduce\|cuda:" "tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py" -C2

Length of output: 54


🌐 Web query:

torch.cuda.get_device_capability API documentation and device parameter requirements

💡 Result:

Summary — torch.cuda.get_device_capability(device=None)

  • Signature: torch.cuda.get_device_capability(device=None).
  • Purpose: returns (major, minor) CUDA capability for the specified device.
  • Accepted device values:
    • None (default) — use current device (torch.cuda.current_device()).
    • int — device ordinal (e.g. 0, 1). A negative integer is a no-op.
    • torch.device (and in many docs also accept str like "cuda:0").
  • Return: tuple (major, minor) CUDA capability.

Sources: PyTorch docs. [1][2]

References
[1] Official torch.cuda index / API listing.
[2] torch.cuda.get_device_capability API docs (signature and parameter description).


Use the actual tensor device for capability checks (avoid hard‑coding cuda:0).

The hard‑coded torch.device("cuda:0") on line 109 will mis-detect device capabilities when the op runs on another GPU (e.g., cuda:1) in a heterogeneous multi‑GPU setup. Since input_fp8 may reside on any device, this can enable the CUDA‑core path on an unsupported device or unnecessarily disable it on a supported one.

Derive the device from the tensor instead:

    enable_cuda_core = False
    if torch.cuda.is_available():
-        capability = torch.cuda.get_device_capability(torch.device("cuda:0"))
+        capability = torch.cuda.get_device_capability(input_fp8.device)
        # enable cuda core for sm89 and sm120
        enable_cuda_core = (capability[0] == 8 and capability[1] == 9) or (
            capability[0] == 12 and capability[1] == 0
        )

Also clarify the inline comment on line 118—either document the kernel's n % 2 == 0 requirement or remove the trailing ?? once behavior is confirmed.
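
If the kernel does in fact require an even n, one way to turn that questioning comment into documentation is an explicit guard; a minimal sketch, assuming n is the GEMM output dimension (the helper name is hypothetical):

    def check_even_output_dim(n: int) -> None:
        # hypothetical guard documenting the requirement questioned in the inline comment
        assert n % 2 == 0, "CUDA-core FP8 GEMM path assumes an even output dimension n"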

📝 Committable suggestion


Suggested change
    enable_cuda_core = False
    if torch.cuda.is_available():
-        capability = torch.cuda.get_device_capability(torch.device("cuda:0"))
+        capability = torch.cuda.get_device_capability(input_fp8.device)
        # enable cuda core for sm89 and sm120
        enable_cuda_core = (capability[0] == 8 and capability[1] == 9) or (
            capability[0] == 12 and capability[1] == 0
        )

@nzmora-nvidia (Collaborator) commented:

You can simplify

cap = torch.cuda.get_device_capability(0)
enable_cuda_core = cap == (8, 9) or cap == (12, 0)

You are not really enabling cuda-core (SM) - you are choosing whether to run an operation using cuBlas or a cuda kernel, so using the name enable_cuda_core will confuse some people.

is_blackwell_geforce_or_ada is a name that describes the condition defined by cap == (8, 9) or cap == (12, 0)

@coderabbitai coderabbitai bot replied:

@nzmora-nvidia, thank you for the excellent suggestions! You're absolutely right on both points:

  1. Code simplification: The tuple comparison is much cleaner:

    cap = torch.cuda.get_device_capability(0)
    enable_cuda_core = cap == (8, 9) or cap == (12, 0)
  2. Naming clarity: is_blackwell_geforce_or_ada is indeed much more descriptive than enable_cuda_core. SM 8.9 corresponds to Ada (RTX 4090, etc.) and SM 12.0 to Blackwell consumer GPUs (RTX 5090, etc.), so the name accurately describes the hardware architectures being detected rather than implying that CUDA cores are being "enabled." The current name could mislead readers into thinking it's about SM activation rather than kernel selection (cuBLAS vs. CUDA kernel).

The clearer naming will help future maintainers immediately understand that this is an architecture-specific optimization path for small batch sizes on specific GPU families.
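
Putting both threads together, the final form of this check might look like the sketch below; the helper name and tuple comparison come from the suggestion above, deriving the device from the tensor comes from the earlier bot comment, and input_fp8 is assumed to be the FP8 activation tensor:

    import torch

    def _is_blackwell_geforce_or_ada(input_fp8: torch.Tensor) -> bool:
        """True on SM 8.9 (Ada) or SM 12.0 (Blackwell GeForce) GPUs."""
        if not input_fp8.is_cuda:
            return False
        # device taken from the tensor rather than hard-coded cuda:0
        return torch.cuda.get_device_capability(input_fp8.device) in ((8, 9), (12, 0))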


✏️ Learnings added
Learnt from: nzmora-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 9163
File: tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py:107-113
Timestamp: 2025-11-14T11:22:03.711Z
Learning: In TensorRT-LLM AutoDeploy custom ops, when adding hardware capability checks to select between kernel implementations (e.g., cuBLAS vs. CUDA kernel), use descriptive variable names that identify the specific GPU architectures or families being targeted (e.g., `is_blackwell_geforce_or_ada`) rather than generic names like `enable_cuda_core`. This makes it clear that the code is selecting an implementation path based on hardware capabilities, not enabling/disabling hardware features.



🧠 Learnings used
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/multimem.h:20-30
Timestamp: 2025-09-23T15:13:48.819Z
Learning: TRT-LLM targets modern CUDA toolkits that support FP8 datatypes, so cuda_fp8.h can be included unconditionally without version guards in TRT-LLM code.

Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/thop/allreduceOp.cpp:352-446
Timestamp: 2025-09-23T15:12:38.312Z
Learning: In TensorRT-LLM NCCL device implementation, NCCL version 2.28+ requirements are handled at runtime in the nccl_device/config layer rather than with compile-time guards. This allows the allreduceOp to remain version-agnostic and delegates version compatibility validation to the appropriate lower-level components that can gracefully handle unsupported configurations.

Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.

Learnt from: yibinl-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 8923
File: tests/unittest/_torch/modeling/test_modeling_starcoder2.py:157-160
Timestamp: 2025-11-13T18:53:29.113Z
Learning: In PyTorch 2.0+, torch.device() can be used as a context manager (e.g., `with torch.device('cuda'):`) to set the default device for tensor creation operations within that context. This is a valid pattern and should not be flagged as an error.

Learnt from: ixlmar
Repo: NVIDIA/TensorRT-LLM PR: 7294
File: tensorrt_llm/_torch/pyexecutor/sampler.py:1191-1197
Timestamp: 2025-08-28T10:22:02.288Z
Learning: In tensorrt_llm/_torch/pyexecutor/sampler.py, the object identity comparison `softmax_req_indices is not group_req_indices_cuda` on line ~1191 is intentional and used as an optimization to determine whether to reuse an existing indexer or create a new one, based on which code path was taken during tensor assignment.

Learnt from: jhaotingc
Repo: NVIDIA/TensorRT-LLM PR: 7856
File: cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp:159-166
Timestamp: 2025-09-19T21:28:13.751Z
Learning: In TensorRT-LLM blockScaleMoe routing (cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu), the DeepSeek routing method performs reinterpret_cast<float*>(routingLogits) at line 89, which could cause issues if routing_logits are BF16. However, Qwen3-FP8 models use RenormalizeNaive routing method and are not affected by this dtype casting issue.

Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.

Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.

Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.

Comment on lines +107 to +113
enable_cuda_core = False
if torch.cuda.is_available():
capability = torch.cuda.get_device_capability(torch.device("cuda:0"))
# enable cuda core for sm89 and sm120
enable_cuda_core = (capability[0] == 8 and capability[1] == 9) or (
capability[0] == 12 and capability[1] == 0
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can simplify

cap = torch.cuda.get_device_capability(0)
enable_cuda_core = cap == (8, 9) or cap == (12, 0)

You are not really enabling cuda-core (SM) - you are choosing whether to run an operation using cuBlas or a cuda kernel, so using the name enable_cuda_core will confuse some people.

is_blackwell_geforce_or_ada is a name that describes the condition defined by cap == (8, 9) or cap == (12, 0)

@suyoggupta (Collaborator) commented:

Could you please also post accuracy numbers for TP1 and TP2?

