Conversation

lanluo-nvidia
Collaborator

Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters; see the formatter sketch after this checklist)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

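The lint failures the bot reports below are black-style reformatting diffs. As an assumption (the PR body does not name the exact tooling), here is a minimal way to preview the same reformat locally; the file path is taken from the CI output:

# Minimal sketch: preview the reformat the CI style check reports below.
# Assumption: the check applies black formatting (inferred from the diff output);
# adjust if the project runs a different formatter or a pre-commit hook instead.
import black
from pathlib import Path

path = Path("py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py")
source = path.read_text()
formatted = black.format_str(source, mode=black.Mode())
if formatted != source:
    print(f"{path} would be reformatted")  # same verdict the CI check gives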
@github-actions github-actions bot added component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: build system Issues re: Build system component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels May 15, 2025
@github-actions github-actions bot requested a review from gs-olive May 15, 2025 17:28
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-15 17:28:16.606815+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-15 17:28:40.517973+00:00
@@ -140,12 +140,11 @@
    return dequantized_data


# TODO: to remove it this is to make sure our global scale and block scale calculation is correct during debugging
def _test_weights_scaling_factor(
-    weights_tensor: torch.Tensor, 
-    global_scale: torch.Tensor
+    weights_tensor: torch.Tensor, global_scale: torch.Tensor
) -> None:

    import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor
    import modelopt.onnx.quantization.quant_utils as quant_utils

@@ -192,11 +191,13 @@
    """

    import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor

    block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor(
-        weights_tensor, 16, global_scale,
+        weights_tensor,
+        16,
+        global_scale,
    )[0]

    weights_tensor_scaled = nvfp4_tensor.NVFP4QTensor.quantize(
        weights_tensor,
        16,
@@ -205,11 +206,13 @@
    )[0]._quantized_data

    block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8")
    global_scale = to_torch(global_scale, None)
    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
-    weights_fp4_represented_in_uint8 = get_trt_tensor(ctx, weights_tensor_scaled, name + "_weights_fp4_represented_in_uint8")
+    weights_fp4_represented_in_uint8 = get_trt_tensor(
+        ctx, weights_tensor_scaled, name + "_weights_fp4_represented_in_uint8"
+    )

    # dequantize block scale from fp8 to float32
    dequantize_block_scale_layer = ctx.net.add_dequantize(
        block_scale_fp8,
        global_scale,
@@ -248,6 +251,5 @@
    )  # amax is calculated from input_tensor.abs().amax().float()
    global_scale = torch.divide(amax, 6 * 448)
    if global_scale == 0:
        global_scale = 1.0
    return global_scale
-

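For reference, the global-scale calculation at the tail of this diff is small enough to sketch on its own. A minimal standalone version of that step, copying the amax / (6 * 448) formula and the 1.0 fallback from the diff (reading 6 as the FP4 E2M1 maximum and 448 as the FP8 E4M3 maximum is an assumption, not something the diff states):

# Standalone sketch of the global-scale step shown in the diff above.
import torch

def global_scale_sketch(tensor: torch.Tensor) -> torch.Tensor:
    amax = tensor.abs().amax().float()           # per the diff's trailing comment
    global_scale = torch.divide(amax, 6 * 448)   # FP4 max (6) * FP8 max (448), assumed
    if global_scale == 0:                        # all-zero tensor: fall back to 1.0
        global_scale = torch.tensor(1.0)
    return global_scale

print(global_scale_sketch(torch.ones(128, 64, dtype=torch.float16)))  # ~3.72e-04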
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-15 21:33:37.025993+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-15 21:33:59.004002+00:00
@@ -140,12 +140,11 @@
    return dequantized_data


# TODO: to remove it this is to make sure our global scale and block scale calculation is correct during debugging
def _test_weights_scaling_factor(
-    weights_tensor: torch.Tensor, 
-    global_scale: torch.Tensor
+    weights_tensor: torch.Tensor, global_scale: torch.Tensor
) -> None:

    import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor
    import modelopt.onnx.quantization.quant_utils as quant_utils

@@ -192,11 +191,13 @@
    """

    import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor

    block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor(
-        weights_tensor, 16, global_scale,
+        weights_tensor,
+        16,
+        global_scale,
    )[0]

    weights_tensor_scaled = nvfp4_tensor.NVFP4QTensor.quantize(
        weights_tensor,
        16,
@@ -205,11 +206,13 @@
    )[0]._quantized_data

    block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8")
    global_scale = to_torch(global_scale, None)
    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
-    weights_fp4_represented_in_uint8 = get_trt_tensor(ctx, weights_tensor_scaled, name + "_weights_fp4_represented_in_uint8")
+    weights_fp4_represented_in_uint8 = get_trt_tensor(
+        ctx, weights_tensor_scaled, name + "_weights_fp4_represented_in_uint8"
+    )

    # dequantize block scale from fp8 to float32
    dequantize_block_scale_layer = ctx.net.add_dequantize(
        block_scale_fp8,
        global_scale,
@@ -248,6 +251,5 @@
    )  # amax is calculated from input_tensor.abs().amax().float()
    global_scale = torch.divide(amax, 6 * 448)
    if global_scale == 0:
        global_scale = 1.0
    return global_scale
-

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-16 17:17:53.756341+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-16 17:18:21.840287+00:00
@@ -107,11 +107,13 @@

    """
    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")

    if input_tensor.dtype not in [trt.DataType.HALF, trt.DataType.FLOAT]:
-        raise ValueError(f"Currently try float16, float32 only on input tensor for now. Unsupported dtype: {input_tensor.dtype}")
+        raise ValueError(
+            f"Currently try float16, float32 only on input tensor for now. Unsupported dtype: {input_tensor.dtype}"
+        )
    # dynamic quantize input tensor to fp4
    dynamic_quantize_layer = ctx.net.add_dynamic_quantize(
        input_tensor,
        axis,
        block_size,
@@ -194,17 +196,19 @@
    Returns:
        quantized data tensor in fp4
    """

    import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor
-    
+
    if weights_tensor.dtype == torch.float16:
        original_dtype = trt.DataType.HALF
    elif weights_tensor.dtype == torch.float32:
        original_dtype = trt.DataType.FLOAT
    else:
-        raise ValueError(f"Currently try float16, float32 only on weights tensor. Unsupported dtype: {weights_tensor.dtype}")
+        raise ValueError(
+            f"Currently try float16, float32 only on weights tensor. Unsupported dtype: {weights_tensor.dtype}"
+        )

    block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor(
        weights_tensor,
        16,
        global_scale,
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-16 17:17:53.783341+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-16 17:18:27.298406+00:00
@@ -213,11 +213,13 @@
    from modelopt.torch.quantization.utils import export_torch_mode

    class SimpleNetwork(torch.nn.Module):
        def __init__(self):
            super(SimpleNetwork, self).__init__()
-            self.linear1 = torch.nn.Linear(in_features=64, out_features=32, bias=False, dtype=torch.float16)
+            self.linear1 = torch.nn.Linear(
+                in_features=64, out_features=32, bias=False, dtype=torch.float16
+            )

        def forward(self, x):
            x = self.linear1(x)
            return x

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-18 17:54:24.708675+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-18 17:54:58.520847+00:00
@@ -235,11 +235,11 @@
    print(f"lan added pytorch output_pyt: {output_pyt}")

    quant_cfg = mtq.NVFP4_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has qdq nodes at this point
-    
+
    torch.onnx.export(model, input_tensor, "mtq_model.onnx")

    with torch.no_grad():
        with export_torch_mode():
            exp_program = torch.export.export(model, (input_tensor,), strict=False)

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-18 21:19:00.783067+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-18 21:19:23.297120+00:00
@@ -214,19 +214,25 @@
    block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor(
        weights_tensor,
        16,
        global_scale,
    )[0]
-    print(f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=} {global_scale=}")
-    print(f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}")
+    print(
+        f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=} {global_scale=}"
+    )
+    print(
+        f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}"
+    )
    weights_tensor_fp4 = nvfp4_tensor.NVFP4QTensor.quantize(
        weights_tensor,
        16,
        block_scale_fp8,
        global_scale,
    )[0]._quantized_data
-    print(f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=} {weights_tensor_fp4=}")
+    print(
+        f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=} {weights_tensor_fp4=}"
+    )
    block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8")
    global_scale = to_torch(global_scale, None)
    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
    weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_fp4, name + "_weights_fp4")
    # dequantize block scale from fp8 to original dtype (default is float32)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-18 21:19:00.810067+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-18 21:19:28.498117+00:00
@@ -229,22 +229,28 @@

    input_tensor = torch.ones(128, 64, dtype=torch.float16).cuda()

    print(f"lan added amax: {input_tensor.abs().amax()}")
    model = SimpleNetwork().eval().cuda()
-    model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda())
-    model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=torch.float16).cuda())
+    model.linear1.weight = torch.nn.Parameter(
+        torch.ones(32, 64, dtype=torch.float16).cuda()
+    )
+    model.linear1.bias = torch.nn.Parameter(
+        torch.zeros(128, 32, dtype=torch.float16).cuda()
+    )
    output_pyt = model(input_tensor)
-    print(f"lan added model input: {input_tensor=}")    
+    print(f"lan added model input: {input_tensor=}")
    print(f"lan added model weight: {model.linear1.weight=}")
    print(f"lan added model bias: {model.linear1.bias=}")
-    print(f"lan added pytorch output_pyt: {output_pyt} {output_pyt.dtype=} {output_pyt.shape=}")
+    print(
+        f"lan added pytorch output_pyt: {output_pyt} {output_pyt.dtype=} {output_pyt.shape=}"
+    )

    quant_cfg = mtq.NVFP4_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has qdq nodes at this point
-    
+
    torch.onnx.export(model, input_tensor, "mtq_model.onnx")

    with torch.no_grad():
        with export_torch_mode():
            exp_program = torch.export.export(model, (input_tensor,), strict=False)

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-20 22:04:08.054204+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-20 22:04:33.547147+00:00
@@ -214,19 +214,25 @@
    block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor(
        weights_tensor,
        16,
        global_scale,
    )[0]
-    print(f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=} {global_scale=}")
-    print(f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}")
+    print(
+        f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=} {global_scale=}"
+    )
+    print(
+        f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}"
+    )
    weights_tensor_fp4 = nvfp4_tensor.NVFP4QTensor.quantize(
        weights_tensor,
        16,
        block_scale_fp8,
        global_scale,
    )[0]._quantized_data
-    print(f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=} {weights_tensor_fp4=}")
+    print(
+        f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=} {weights_tensor_fp4=}"
+    )
    block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8")
    global_scale = to_torch(global_scale, None)
    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
    weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_fp4, name + "_weights_fp4")
    # dequantize block scale from fp8 to original dtype (default is float32)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-20 22:04:08.081205+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-20 22:04:39.052999+00:00
@@ -229,22 +229,28 @@

    input_tensor = torch.ones(128, 64, dtype=torch.float16).cuda()

    print(f"lan added amax: {input_tensor.abs().amax()}")
    model = SimpleNetwork().eval().cuda()
-    model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda())
-    model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=torch.float16).cuda())
+    model.linear1.weight = torch.nn.Parameter(
+        torch.ones(32, 64, dtype=torch.float16).cuda()
+    )
+    model.linear1.bias = torch.nn.Parameter(
+        torch.zeros(128, 32, dtype=torch.float16).cuda()
+    )
    output_pyt = model(input_tensor)
-    print(f"lan added model input: {input_tensor=}")    
+    print(f"lan added model input: {input_tensor=}")
    print(f"lan added model weight: {model.linear1.weight=}")
    print(f"lan added model bias: {model.linear1.bias=}")
-    print(f"lan added pytorch output_pyt: {output_pyt} {output_pyt.dtype=} {output_pyt.shape=}")
+    print(
+        f"lan added pytorch output_pyt: {output_pyt} {output_pyt.dtype=} {output_pyt.shape=}"
+    )

    quant_cfg = mtq.NVFP4_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has qdq nodes at this point
-    
+
    torch.onnx.export(model, input_tensor, "mtq_model.onnx")

    with torch.no_grad():
        with export_torch_mode():
            exp_program = torch.export.export(model, (input_tensor,), strict=False)

@lanluo-nvidia lanluo-nvidia changed the title Test fp4: Lluo/fp4 try out Test Only fp4: Lluo/fp4 try out May 21, 2025
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-21 21:05:16.522261+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-21 21:05:39.067088+00:00
@@ -580,16 +580,16 @@
            f"Detected torch_executed_modules was non-empty: {torch_executed_modules}"
            "\nThis feature is unimplemented in Torch-TRT Dynamo currently."
        )

    # if use_explicit_typing:
-        # if len(enabled_precisions) != 1 or not any(
-        #     x in enabled_precisions for x in {torch.float32, dtype.f32}
-        # ):
-        #     raise AssertionError(
-        #         f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
-        #     )
+    # if len(enabled_precisions) != 1 or not any(
+    #     x in enabled_precisions for x in {torch.float32, dtype.f32}
+    # ):
+    #     raise AssertionError(
+    #         f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
+    #     )

    if use_fp32_acc:
        logger.debug(
            "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
                     This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation."
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-21 21:05:16.525261+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-21 21:05:39.293128+00:00
@@ -12,10 +12,11 @@
    to_torch,
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor
import os
+

def nvfp4_quantize(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
@@ -219,19 +220,25 @@
    block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor(
        weights_tensor,
        16,
        global_scale,
    )[0]
-    print(f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=} {global_scale=}")
-    print(f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}")
+    print(
+        f"lan added global_scale: {global_scale.shape=} {global_scale.dtype=} {global_scale=}"
+    )
+    print(
+        f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}"
+    )
    weights_tensor_fp4 = nvfp4_tensor.NVFP4QTensor.quantize(
        weights_tensor,
        16,
        block_scale_fp8,
        global_scale,
    )[0]._quantized_data
-    print(f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=} {weights_tensor_fp4=}")
+    print(
+        f"lan added weights_tensor_fp4: {weights_tensor_fp4.shape=} {weights_tensor_fp4.dtype=} {weights_tensor_fp4=}"
+    )
    block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8")
    global_scale = to_torch(global_scale, None)
    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
    weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_fp4, name + "_weights_fp4")
    # dequantize block scale from fp8 to original dtype (default is float32)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-21 21:05:16.552261+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-21 21:05:44.770942+00:00
@@ -228,22 +228,28 @@

    input_tensor = torch.ones(128, 64, dtype=torch.float16).cuda()

    print(f"lan added amax: {input_tensor.abs().amax()}")
    model = SimpleNetwork().eval().cuda()
-    model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda())
-    model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda())
+    model.linear1.weight = torch.nn.Parameter(
+        torch.ones(32, 64, dtype=torch.float16).cuda()
+    )
+    model.linear1.bias = torch.nn.Parameter(
+        torch.ones(128, 32, dtype=torch.float16).cuda()
+    )
    output_pyt = model(input_tensor)
-    print(f"lan added model input: {input_tensor=}")    
+    print(f"lan added model input: {input_tensor=}")
    print(f"lan added model weight: {model.linear1.weight=}")
    print(f"lan added model bias: {model.linear1.bias=}")
-    print(f"lan added pytorch output_pyt: {output_pyt} {output_pyt.dtype=} {output_pyt.shape=}")
+    print(
+        f"lan added pytorch output_pyt: {output_pyt} {output_pyt.dtype=} {output_pyt.shape=}"
+    )

    quant_cfg = mtq.NVFP4_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has qdq nodes at this point
-    
+
    torch.onnx.export(model, input_tensor, "mtq_model.onnx")

    with torch.no_grad():
        with export_torch_mode():
            exp_program = torch.export.export(model, (input_tensor,), strict=False)

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-23 16:51:47.625324+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-23 16:52:10.277549+00:00
@@ -580,16 +580,16 @@
            f"Detected torch_executed_modules was non-empty: {torch_executed_modules}"
            "\nThis feature is unimplemented in Torch-TRT Dynamo currently."
        )

    # if use_explicit_typing:
-        # if len(enabled_precisions) != 1 or not any(
-        #     x in enabled_precisions for x in {torch.float32, dtype.f32}
-        # ):
-        #     raise AssertionError(
-        #         f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
-        #     )
+    # if len(enabled_precisions) != 1 or not any(
+    #     x in enabled_precisions for x in {torch.float32, dtype.f32}
+    # ):
+    #     raise AssertionError(
+    #         f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
+    #     )

    if use_fp32_acc:
        logger.debug(
            "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
                     This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation."
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-23 16:51:47.628324+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-23 16:52:10.511363+00:00
@@ -13,10 +13,11 @@
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor
import os
import torch_tensorrt.dynamo.conversion.impl as impl
+

def nvfp4_quantize(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
@@ -235,22 +236,30 @@
        keep_high_precision=True,
    )
    if enable_transpose:
        block_scale = block_scale.transpose(0, 1)
        weights_tensor_scaled = weights_tensor_scaled.transpose(0, 1)
-    
+
    block_scale_fp8 = block_scale.to(torch.float8_e4m3fn)
    weights_tensor_uint4 = nvfp4_tensor.NVFP4QTensor._cast_fp4(weights_tensor_scaled)
-    weights_tensor_uint8 = (weights_tensor_uint4[..., 1::2] << 4) | weights_tensor_uint4[..., 0::2]
-    
-    print(f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}")
-    print(f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=}")
-    
+    weights_tensor_uint8 = (
+        weights_tensor_uint4[..., 1::2] << 4
+    ) | weights_tensor_uint4[..., 0::2]
+
+    print(
+        f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}"
+    )
+    print(
+        f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=}"
+    )
+
    block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8")
    global_scale = to_torch(global_scale, None)
    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
-    weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_uint8, name + "_weights_fp4")
+    weights_tensor_fp4 = get_trt_tensor(
+        ctx, weights_tensor_uint8, name + "_weights_fp4"
+    )
    # dequantize block scale from fp8 to original dtype (default is float32)
    dequantize_block_scale_layer = ctx.net.add_dequantize(
        block_scale_fp8,
        global_scale,
        original_dtype,
@@ -282,11 +291,18 @@
    print(
        f"lan added dequantize_data_layer: {dequantize_data_layer.to_type=} {dequantize_data_layer.axis=} {dequantize_data_layer.precision=} {dequantize_data_layer.get_output_type(0)=}"
    )
    dequantized_data = dequantize_data_layer.get_output(0)
    if enable_transpose:
-        dequantized_data = impl.permutation.permute(ctx, target, source_ir, name + "_dequantized_data_transposed", dequantized_data, (-1, -2))
+        dequantized_data = impl.permutation.permute(
+            ctx,
+            target,
+            source_ir,
+            name + "_dequantized_data_transposed",
+            dequantized_data,
+            (-1, -2),
+        )
    return dequantized_data


def _calculate_global_scale(
    ctx: ConversionContext,
@@ -302,39 +318,47 @@
    global_scale = torch.divide(amax, 6 * 448)
    if global_scale == 0:
        global_scale = 1.0
    return global_scale

+
def _get_weights_scaling_factor_transposed(
    weights_tensor: torch.Tensor,
    global_scale: torch.Tensor,
    keep_high_precision: bool = False,
) -> torch.Tensor:
    [k, n] = weights_tensor.shape[-2:]
-    assert k % 16 == 0, "Weight shape is not divisible for block size for block quantiation."
-    weights_tensor = weights_tensor.reshape(tuple(weights_tensor.shape[:-2]) + (k // 16, n, 16))
+    assert (
+        k % 16 == 0
+    ), "Weight shape is not divisible for block size for block quantiation."
+    weights_tensor = weights_tensor.reshape(
+        tuple(weights_tensor.shape[:-2]) + (k // 16, n, 16)
+    )
    per_block_amax = weights_tensor.abs().amax(dim=-1).float()
    per_block_scale = per_block_amax / 6.0
    q_per_block_scale = per_block_scale / global_scale
    q_per_block_scale[per_block_scale == 0] = 1.0
    if not keep_high_precision:
        q_per_block_scale = q_per_block_scale.to(torch.float8_e4m3fn)
    return q_per_block_scale

+
def _quantized_weights_transposed(
    input: torch.Tensor,
    weights_scaling_factor: torch.Tensor,
    weights_scaling_factor_2: torch.Tensor,
    keep_high_precision: bool = False,
) -> torch.Tensor:
-    
+
    # Reshape the weight and scale factors
    input = input.view((*tuple(input.shape[:-1]), -1, block_size))

    # Scale weights
    scaled_weight = input / (
-        (weights_scaling_factor.to(torch.float32) * weights_scaling_factor_2).unsqueeze(-1)
+        (weights_scaling_factor.to(torch.float32) * weights_scaling_factor_2).unsqueeze(
+            -1
+        )
    )

    # Reshape weights to original
    scaled_weight = scaled_weight.view((*tuple(scaled_weight.shape[:-2]), -1))

@@ -347,7 +371,5 @@
    return (
        cls(input_shape, input_dtype, packed_weight),
        weights_scaling_factor,
        weights_scaling_factor_2,
    )
-
-    
\ No newline at end of file
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-23 16:51:47.655324+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-23 16:52:16.159101+00:00
@@ -228,17 +228,19 @@

    input_tensor = torch.ones(128, 64, dtype=torch.float16).cuda()

    print(f"lan added amax: {input_tensor.abs().amax()}")
    model = SimpleNetwork().eval().cuda()
-    #model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda())
-    #model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda())
+    # model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda())
+    # model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda())
    output_pyt = model(input_tensor)
-    print(f"lan added model input: {input_tensor=}")    
+    print(f"lan added model input: {input_tensor=}")
    print(f"lan added model weight: {model.linear1.weight=}")
    print(f"lan added model bias: {model.linear1.bias=}")
-    print(f"lan added pytorch output_pyt: {output_pyt} {output_pyt.dtype=} {output_pyt.shape=}")
+    print(
+        f"lan added pytorch output_pyt: {output_pyt} {output_pyt.dtype=} {output_pyt.shape=}"
+    )

    quant_cfg = mtq.NVFP4_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has qdq nodes at this point
    with torch.no_grad():

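The line black keeps re-wrapping in the diffs above, (weights_tensor_uint4[..., 1::2] << 4) | weights_tensor_uint4[..., 0::2], packs two 4-bit codes into each uint8 byte. A minimal round-trip sketch of that packing on arbitrary codes (illustrative only; the modelopt _cast_fp4 step that produces the codes is not reproduced here):

# Sketch: pack pairs of 4-bit codes into uint8 as in the diff above, then unpack
# to confirm the round trip. The codes are arbitrary stand-ins for FP4 values.
import torch

codes = torch.randint(0, 16, (4, 16), dtype=torch.uint8)     # pretend FP4 codes
packed = (codes[..., 1::2] << 4) | codes[..., 0::2]           # two codes per byte

low = packed & 0x0F           # even-indexed codes (low nibble)
high = (packed >> 4) & 0x0F   # odd-indexed codes (high nibble)
unpacked = torch.stack((low, high), dim=-1).reshape(codes.shape)
assert torch.equal(unpacked, codes)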
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-23 16:59:48.524241+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-23 17:00:12.827589+00:00
@@ -580,16 +580,16 @@
            f"Detected torch_executed_modules was non-empty: {torch_executed_modules}"
            "\nThis feature is unimplemented in Torch-TRT Dynamo currently."
        )

    # if use_explicit_typing:
-        # if len(enabled_precisions) != 1 or not any(
-        #     x in enabled_precisions for x in {torch.float32, dtype.f32}
-        # ):
-        #     raise AssertionError(
-        #         f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
-        #     )
+    # if len(enabled_precisions) != 1 or not any(
+    #     x in enabled_precisions for x in {torch.float32, dtype.f32}
+    # ):
+    #     raise AssertionError(
+    #         f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
+    #     )

    if use_fp32_acc:
        logger.debug(
            "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
                     This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation."
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-23 16:59:48.526242+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-23 17:00:13.111447+00:00
@@ -13,10 +13,11 @@
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor
import os
import torch_tensorrt.dynamo.conversion.impl as impl
+

def nvfp4_quantize(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
@@ -236,22 +237,30 @@
        keep_high_precision=True,
    )
    if enable_transpose:
        block_scale = block_scale.transpose(0, 1)
        weights_tensor_scaled = weights_tensor_scaled.transpose(0, 1)
-    
+
    block_scale_fp8 = block_scale.to(torch.float8_e4m3fn)
    weights_tensor_uint4 = nvfp4_tensor.NVFP4QTensor._cast_fp4(weights_tensor_scaled)
-    weights_tensor_uint8 = (weights_tensor_uint4[..., 1::2] << 4) | weights_tensor_uint4[..., 0::2]
-    
-    print(f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}")
-    print(f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=}")
-    
+    weights_tensor_uint8 = (
+        weights_tensor_uint4[..., 1::2] << 4
+    ) | weights_tensor_uint4[..., 0::2]
+
+    print(
+        f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}"
+    )
+    print(
+        f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=}"
+    )
+
    block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8")
    global_scale = to_torch(global_scale, None)
    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
-    weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_uint8, name + "_weights_fp4")
+    weights_tensor_fp4 = get_trt_tensor(
+        ctx, weights_tensor_uint8, name + "_weights_fp4"
+    )
    # dequantize block scale from fp8 to original dtype (default is float32)
    dequantize_block_scale_layer = ctx.net.add_dequantize(
        block_scale_fp8,
        global_scale,
        original_dtype,
@@ -281,11 +290,18 @@
    print(
        f"lan added dequantize_data_layer: {dequantize_data_layer.to_type=} {dequantize_data_layer.axis=} {dequantize_data_layer.precision=} {dequantize_data_layer.get_output_type(0)=}"
    )
    dequantized_data = dequantize_data_layer.get_output(0)
    if enable_transpose:
-        dequantized_data = impl.permutation.permute(ctx, target, source_ir, name + "_dequantized_data_transposed", dequantized_data, (-1, -2))
+        dequantized_data = impl.permutation.permute(
+            ctx,
+            target,
+            source_ir,
+            name + "_dequantized_data_transposed",
+            dequantized_data,
+            (-1, -2),
+        )
    return dequantized_data


def _calculate_global_scale(
    ctx: ConversionContext,
@@ -301,39 +317,47 @@
    global_scale = torch.divide(amax, 6 * 448)
    if global_scale == 0:
        global_scale = 1.0
    return global_scale

+
def _get_weights_scaling_factor_transposed(
    weights_tensor: torch.Tensor,
    global_scale: torch.Tensor,
    keep_high_precision: bool = False,
) -> torch.Tensor:
    [k, n] = weights_tensor.shape[-2:]
-    assert k % 16 == 0, "Weight shape is not divisible for block size for block quantiation."
-    weights_tensor = weights_tensor.reshape(tuple(weights_tensor.shape[:-2]) + (k // 16, n, 16))
+    assert (
+        k % 16 == 0
+    ), "Weight shape is not divisible for block size for block quantiation."
+    weights_tensor = weights_tensor.reshape(
+        tuple(weights_tensor.shape[:-2]) + (k // 16, n, 16)
+    )
    per_block_amax = weights_tensor.abs().amax(dim=-1).float()
    per_block_scale = per_block_amax / 6.0
    q_per_block_scale = per_block_scale / global_scale
    q_per_block_scale[per_block_scale == 0] = 1.0
    if not keep_high_precision:
        q_per_block_scale = q_per_block_scale.to(torch.float8_e4m3fn)
    return q_per_block_scale

+
def _quantized_weights_transposed(
    input: torch.Tensor,
    weights_scaling_factor: torch.Tensor,
    weights_scaling_factor_2: torch.Tensor,
    keep_high_precision: bool = False,
) -> torch.Tensor:
-    
+
    # Reshape the weight and scale factors
    input = input.view((*tuple(input.shape[:-1]), -1, block_size))

    # Scale weights
    scaled_weight = input / (
-        (weights_scaling_factor.to(torch.float32) * weights_scaling_factor_2).unsqueeze(-1)
+        (weights_scaling_factor.to(torch.float32) * weights_scaling_factor_2).unsqueeze(
+            -1
+        )
    )

    # Reshape weights to original
    scaled_weight = scaled_weight.view((*tuple(scaled_weight.shape[:-2]), -1))

@@ -346,7 +370,5 @@
    return (
        cls(input_shape, input_dtype, packed_weight),
        weights_scaling_factor,
        weights_scaling_factor_2,
    )
-
-    
\ No newline at end of file
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-23 16:59:48.553243+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-23 17:00:18.800102+00:00
@@ -228,17 +228,19 @@

    input_tensor = torch.ones(128, 64, dtype=torch.float16).cuda()

    print(f"lan added amax: {input_tensor.abs().amax()}")
    model = SimpleNetwork().eval().cuda()
-    #model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda())
-    #model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda())
+    # model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda())
+    # model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda())
    output_pyt = model(input_tensor)
-    print(f"lan added model input: {input_tensor=}")    
+    print(f"lan added model input: {input_tensor=}")
    print(f"lan added model weight: {model.linear1.weight=}")
    print(f"lan added model bias: {model.linear1.bias=}")
-    print(f"lan added pytorch output_pyt: {output_pyt} {output_pyt.dtype=} {output_pyt.shape=}")
+    print(
+        f"lan added pytorch output_pyt: {output_pyt} {output_pyt.dtype=} {output_pyt.shape=}"
+    )

    quant_cfg = mtq.NVFP4_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has qdq nodes at this point
    with torch.no_grad():

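The _get_weights_scaling_factor_transposed helper reformatted in the diffs above derives one FP8 scale per block of 16 values. A hedged standalone sketch of that math, with the reshape to (k // 16, n, 16), the amax / 6.0 step, and the zero-scale fallback copied from the diff (requires a torch build that provides float8_e4m3fn):

# Sketch of the per-block scale computation from the diff above.
import torch

def per_block_scale_sketch(weights: torch.Tensor, global_scale: torch.Tensor) -> torch.Tensor:
    k, n = weights.shape[-2:]
    assert k % 16 == 0, "K must be divisible by the block size (16)"
    blocks = weights.reshape(*weights.shape[:-2], k // 16, n, 16)  # as in the diff
    per_block_amax = blocks.abs().amax(dim=-1).float()
    per_block_scale = per_block_amax / 6.0            # 6.0 taken from the diff
    q_per_block_scale = per_block_scale / global_scale
    q_per_block_scale[per_block_scale == 0] = 1.0     # avoid zero scales
    return q_per_block_scale.to(torch.float8_e4m3fn)

w = torch.ones(64, 32, dtype=torch.float16)
print(per_block_scale_sketch(w, torch.tensor(1 / (6 * 448))).shape)  # torch.Size([4, 32])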
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/addmm.py	2025-05-23 21:29:17.101078+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/addmm.py	2025-05-23 21:29:40.534529+00:00
@@ -6,10 +6,11 @@
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.types import TRTTensor
import os
+

def addmm(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-23 21:29:17.099077+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-23 21:29:40.828598+00:00
@@ -580,16 +580,16 @@
            f"Detected torch_executed_modules was non-empty: {torch_executed_modules}"
            "\nThis feature is unimplemented in Torch-TRT Dynamo currently."
        )

    # if use_explicit_typing:
-        # if len(enabled_precisions) != 1 or not any(
-        #     x in enabled_precisions for x in {torch.float32, dtype.f32}
-        # ):
-        #     raise AssertionError(
-        #         f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
-        #     )
+    # if len(enabled_precisions) != 1 or not any(
+    #     x in enabled_precisions for x in {torch.float32, dtype.f32}
+    # ):
+    #     raise AssertionError(
+    #         f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
+    #     )

    if use_fp32_acc:
        logger.debug(
            "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
                     This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation."
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/permutation.py	2025-05-23 21:29:17.102077+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/permutation.py	2025-05-23 21:29:41.043949+00:00
@@ -13,10 +13,11 @@
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.fx.types import TRTTensor
import os

+
def permute(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-23 21:29:17.102077+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-23 21:29:41.079546+00:00
@@ -13,10 +13,11 @@
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor
import os
import torch_tensorrt.dynamo.conversion.impl as impl
+

def nvfp4_quantize(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
@@ -206,11 +207,13 @@
        quantized data tensor in fp4
    """
    if os.getenv("DISABLE_STATIC_QUANTIZE", "false").lower() == "true":
        print("lan added disable_static_quantize is set, skipping static quantize")
        return get_trt_tensor(ctx, weights_tensor, name + "_weights")
-    print("lan added static disable_static_quantize is not set, do disable_static_quantize ")
+    print(
+        "lan added static disable_static_quantize is not set, do disable_static_quantize "
+    )
    if os.getenv("ENABLE_TRANSPOSE", "false").lower() == "true":
        print("lan added enable_transpose is set, transposing weights tensor")
        enable_transpose = True
        axis = -2
    else:
@@ -239,22 +242,30 @@
        keep_high_precision=True,
    )
    if enable_transpose:
        block_scale = block_scale.transpose(0, 1)
        weights_tensor_scaled = weights_tensor_scaled.transpose(0, 1)
-    
+
    block_scale_fp8 = block_scale.to(torch.float8_e4m3fn)
    weights_tensor_uint4 = nvfp4_tensor.NVFP4QTensor._cast_fp4(weights_tensor_scaled)
-    weights_tensor_uint8 = (weights_tensor_uint4[..., 1::2] << 4) | weights_tensor_uint4[..., 0::2]
-    
-    print(f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}")
-    print(f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=}")
-    
+    weights_tensor_uint8 = (
+        weights_tensor_uint4[..., 1::2] << 4
+    ) | weights_tensor_uint4[..., 0::2]
+
+    print(
+        f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}"
+    )
+    print(
+        f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=}"
+    )
+
    block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8")
    global_scale = to_torch(global_scale, None)
    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
-    weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_uint8, name + "_weights_fp4")
+    weights_tensor_fp4 = get_trt_tensor(
+        ctx, weights_tensor_uint8, name + "_weights_fp4"
+    )
    # dequantize block scale from fp8 to original dtype (default is float32)
    dequantize_block_scale_layer = ctx.net.add_dequantize(
        block_scale_fp8,
        global_scale,
        original_dtype,
@@ -284,11 +295,18 @@
    print(
        f"lan added dequantize_data_layer: {dequantize_data_layer.to_type=} {dequantize_data_layer.axis=} {dequantize_data_layer.get_input(0).shape=} {dequantize_data_layer.get_input(1).shape=}"
    )
    dequantized_data = dequantize_data_layer.get_output(0)
    if enable_transpose:
-        dequantized_data = impl.permutation.permute(ctx, target, source_ir, name + "_dequantized_data_transposed", dequantized_data, (-1, -2))
+        dequantized_data = impl.permutation.permute(
+            ctx,
+            target,
+            source_ir,
+            name + "_dequantized_data_transposed",
+            dequantized_data,
+            (-1, -2),
+        )
    return dequantized_data


def _calculate_global_scale(
    ctx: ConversionContext,
@@ -304,18 +322,23 @@
    global_scale = torch.divide(amax, 6 * 448)
    if global_scale == 0:
        global_scale = 1.0
    return global_scale

+
def _get_weights_scaling_factor_transposed(
    weights_tensor: torch.Tensor,
    global_scale: torch.Tensor,
    keep_high_precision: bool = False,
) -> torch.Tensor:
    [k, n] = weights_tensor.shape[-2:]
-    assert k % 16 == 0, "Weight shape is not divisible for block size for block quantiation."
-    weights_tensor = weights_tensor.reshape(tuple(weights_tensor.shape[:-2]) + (k // 16, n, 16))
+    assert (
+        k % 16 == 0
+    ), "Weight shape is not divisible for block size for block quantiation."
+    weights_tensor = weights_tensor.reshape(
+        tuple(weights_tensor.shape[:-2]) + (k // 16, n, 16)
+    )
    per_block_amax = weights_tensor.abs().amax(dim=-1).float()
    per_block_scale = per_block_amax / 6.0
    q_per_block_scale = per_block_scale / global_scale
    q_per_block_scale[per_block_scale == 0] = 1.0
    if not keep_high_precision:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-23 21:29:17.128078+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-23 21:29:46.726986+00:00
@@ -13,10 +13,11 @@
from packaging.version import Version

assertions = unittest.TestCase()
import os

+
@pytest.mark.unit
def test_resnet18(ir):
    model = models.resnet18(pretrained=True).eval().to("cuda")
    input = torch.randn((1, 3, 224, 224)).to("cuda")

@@ -226,21 +227,22 @@
        """Simple calibration function for testing."""
        model(input_tensor)

    input_tensor = torch.ones(128, 64, dtype=torch.float16).cuda()

-    
    model = SimpleNetwork().eval().cuda()
-    model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda())
-    #model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda())
+    model.linear1.weight = torch.nn.Parameter(
+        torch.ones(32, 64, dtype=torch.float16).cuda()
+    )
+    # model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda())
    print(f"lan added amax: {input_tensor.abs().amax()=}")
    print(f"lan added amax: {model.linear1.weight.abs().amax()=}")
    expected_output = model(input_tensor)
-    print(f"lan added model input: {input_tensor=}")    
+    print(f"lan added model input: {input_tensor=}")
    print(f"lan added model weight: {model.linear1.weight=}")
    print(f"lan added model bias: {model.linear1.bias=}")
-    
+
    quant_cfg = mtq.NVFP4_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has qdq nodes at this point
    with torch.no_grad():
        with export_torch_mode():
@@ -268,15 +270,21 @@
                print("lan added disable_gemm is set, compring result with weights")
                expected_output = model.linear1.weight
            else:
                print("lan added disable_gemm is not set, compring result with pytorch")

-            print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=}")
-            print(f"lan added pytorch output_pyt: {expected_output=} {outexpected_outputput_pyt.dtype=} {expected_output.shape=}")
+            print(
+                f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=}"
+            )
+            print(
+                f"lan added pytorch output_pyt: {expected_output=} {outexpected_outputput_pyt.dtype=} {expected_output.shape=}"
+            )

            abs_diff = torch.abs(expected_output - outputs_trt)
-            print(f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
+            print(
+                f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}"
+            )
            print(f"lan added abs_diff: {abs_diff=}")
            assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8)


@unittest.skipIf(

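The tail of the test diff above compares the TensorRT output against the PyTorch reference with an absolute-difference printout and a loose allclose. A minimal sketch of that check on placeholder tensors (float32 here, whereas the test uses float16 on CUDA; the rtol/atol values are copied from the diff and read as debugging tolerances rather than a recommendation):

# Sketch of the output comparison used at the end of the test diff above.
import torch

expected_output = torch.ones(128, 32)
outputs_trt = expected_output + 0.05 * torch.randn_like(expected_output)  # placeholder

abs_diff = torch.abs(expected_output - outputs_trt)
print(f"max abs diff: {abs_diff.max().item():.4f}, mean: {abs_diff.mean().item():.4f}")
assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8)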
@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/addmm.py	2025-05-23 21:40:54.274947+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/addmm.py	2025-05-23 21:41:18.027955+00:00
@@ -6,10 +6,11 @@
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.types import TRTTensor
import os
+

def addmm(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-23 21:40:54.272947+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2025-05-23 21:41:18.270807+00:00
@@ -580,16 +580,16 @@
            f"Detected torch_executed_modules was non-empty: {torch_executed_modules}"
            "\nThis feature is unimplemented in Torch-TRT Dynamo currently."
        )

    # if use_explicit_typing:
-        # if len(enabled_precisions) != 1 or not any(
-        #     x in enabled_precisions for x in {torch.float32, dtype.f32}
-        # ):
-        #     raise AssertionError(
-        #         f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
-        #     )
+    # if len(enabled_precisions) != 1 or not any(
+    #     x in enabled_precisions for x in {torch.float32, dtype.f32}
+    # ):
+    #     raise AssertionError(
+    #         f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
+    #     )

    if use_fp32_acc:
        logger.debug(
            "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
                     This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation."
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-23 21:40:54.275948+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-23 21:41:18.545512+00:00
@@ -13,10 +13,11 @@
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor
import os
import torch_tensorrt.dynamo.conversion.impl as impl
+

def nvfp4_quantize(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
@@ -206,11 +207,13 @@
        quantized data tensor in fp4
    """
    if os.getenv("DISABLE_STATIC_QUANTIZE", "false").lower() == "true":
        print("lan added disable_static_quantize is set, skipping static quantize")
        return get_trt_tensor(ctx, weights_tensor, name + "_weights")
-    print("lan added static disable_static_quantize is not set, do disable_static_quantize ")
+    print(
+        "lan added static disable_static_quantize is not set, do disable_static_quantize "
+    )
    if os.getenv("ENABLE_TRANSPOSE", "false").lower() == "true":
        print("lan added enable_transpose is set, transposing weights tensor")
        enable_transpose = True
        axis = -2
    else:
@@ -239,22 +242,30 @@
        keep_high_precision=True,
    )
    if enable_transpose:
        block_scale = block_scale.transpose(0, 1)
        weights_tensor_scaled = weights_tensor_scaled.transpose(0, 1)
-    
+
    block_scale_fp8 = block_scale.to(torch.float8_e4m3fn)
    weights_tensor_uint4 = nvfp4_tensor.NVFP4QTensor._cast_fp4(weights_tensor_scaled)
-    weights_tensor_uint8 = (weights_tensor_uint4[..., 1::2] << 4) | weights_tensor_uint4[..., 0::2]
-    
-    print(f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}")
-    print(f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=}")
-    
+    weights_tensor_uint8 = (
+        weights_tensor_uint4[..., 1::2] << 4
+    ) | weights_tensor_uint4[..., 0::2]
+
+    print(
+        f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}"
+    )
+    print(
+        f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=}"
+    )
+
    block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8")
    global_scale = to_torch(global_scale, None)
    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
-    weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_uint8, name + "_weights_fp4")
+    weights_tensor_fp4 = get_trt_tensor(
+        ctx, weights_tensor_uint8, name + "_weights_fp4"
+    )
    # dequantize block scale from fp8 to original dtype (default is float32)
    dequantize_block_scale_layer = ctx.net.add_dequantize(
        block_scale_fp8,
        global_scale,
        original_dtype,
@@ -284,11 +295,18 @@
    print(
        f"lan added dequantize_data_layer: {dequantize_data_layer.to_type=} {dequantize_data_layer.axis=} {dequantize_data_layer.get_input(0).shape=} {dequantize_data_layer.get_input(1).shape=}"
    )
    dequantized_data = dequantize_data_layer.get_output(0)
    if enable_transpose:
-        dequantized_data = impl.permutation.permute(ctx, target, source_ir, name + "_dequantized_data_transposed", dequantized_data, (-1, -2))
+        dequantized_data = impl.permutation.permute(
+            ctx,
+            target,
+            source_ir,
+            name + "_dequantized_data_transposed",
+            dequantized_data,
+            (-1, -2),
+        )
    return dequantized_data


def _calculate_global_scale(
    ctx: ConversionContext,
@@ -304,18 +322,23 @@
    global_scale = torch.divide(amax, 6 * 448)
    if global_scale == 0:
        global_scale = 1.0
    return global_scale

+
def _get_weights_scaling_factor_transposed(
    weights_tensor: torch.Tensor,
    global_scale: torch.Tensor,
    keep_high_precision: bool = False,
) -> torch.Tensor:
    [k, n] = weights_tensor.shape[-2:]
-    assert k % 16 == 0, "Weight shape is not divisible for block size for block quantiation."
-    weights_tensor = weights_tensor.reshape(tuple(weights_tensor.shape[:-2]) + (k // 16, n, 16))
+    assert (
+        k % 16 == 0
+    ), "Weight shape is not divisible for block size for block quantiation."
+    weights_tensor = weights_tensor.reshape(
+        tuple(weights_tensor.shape[:-2]) + (k // 16, n, 16)
+    )
    per_block_amax = weights_tensor.abs().amax(dim=-1).float()
    per_block_scale = per_block_amax / 6.0
    q_per_block_scale = per_block_scale / global_scale
    q_per_block_scale[per_block_scale == 0] = 1.0
    if not keep_high_precision:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/permutation.py	2025-05-23 21:40:54.275948+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/permutation.py	2025-05-23 21:41:18.581377+00:00
@@ -13,10 +13,11 @@
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.fx.types import TRTTensor
import os

+
def permute(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-23 21:40:54.302949+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-23 21:41:23.779400+00:00
@@ -13,10 +13,11 @@
from packaging.version import Version

assertions = unittest.TestCase()
import os

+
@pytest.mark.unit
def test_resnet18(ir):
    model = models.resnet18(pretrained=True).eval().to("cuda")
    input = torch.randn((1, 3, 224, 224)).to("cuda")

@@ -226,21 +227,22 @@
        """Simple calibration function for testing."""
        model(input_tensor)

    input_tensor = torch.ones(128, 64, dtype=torch.float16).cuda()

-    
    model = SimpleNetwork().eval().cuda()
-    model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=torch.float16).cuda())
-    #model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda())
+    model.linear1.weight = torch.nn.Parameter(
+        torch.ones(32, 64, dtype=torch.float16).cuda()
+    )
+    # model.linear1.bias = torch.nn.Parameter(torch.ones(128, 32, dtype=torch.float16).cuda())
    print(f"lan added amax: {input_tensor.abs().amax()=}")
    print(f"lan added amax: {model.linear1.weight.abs().amax()=}")
    expected_output = model(input_tensor)
-    print(f"lan added model input: {input_tensor=}")    
+    print(f"lan added model input: {input_tensor=}")
    print(f"lan added model weight: {model.linear1.weight=}")
    print(f"lan added model bias: {model.linear1.bias=}")
-    
+
    quant_cfg = mtq.NVFP4_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has qdq nodes at this point
    with torch.no_grad():
        with export_torch_mode():
@@ -268,15 +270,21 @@
                print("lan added disable_gemm is set, compring result with weights")
                expected_output = model.linear1.weight
            else:
                print("lan added disable_gemm is not set, compring result with pytorch")

-            print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=}")
-            print(f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=}")
+            print(
+                f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=}"
+            )
+            print(
+                f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=}"
+            )

            abs_diff = torch.abs(expected_output - outputs_trt)
-            print(f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
+            print(
+                f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}"
+            )
            print(f"lan added abs_diff: {abs_diff=}")
            assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8)


@unittest.skipIf(
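One detail worth calling out from the nvfp4_quantize diff above: TensorRT ingests the FP4 weights through a uint8 buffer, so two 4-bit codes are packed into each byte with the shift/OR expression. A small self-contained illustration of that packing and the matching unpack (function names are illustrative, not part of any library):

import torch

def pack_uint4_pairs(u4: torch.Tensor) -> torch.Tensor:
    # Element 2i goes to the low nibble, element 2i+1 to the high nibble,
    # mirroring `(x[..., 1::2] << 4) | x[..., 0::2]` from the diff.
    assert u4.dtype == torch.uint8 and u4.shape[-1] % 2 == 0
    return (u4[..., 1::2] << 4) | u4[..., 0::2]

def unpack_uint4_pairs(packed: torch.Tensor) -> torch.Tensor:
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.stack((low, high), dim=-1).flatten(-2)

codes = torch.tensor([[1, 2, 3, 4]], dtype=torch.uint8)
packed = pack_uint4_pairs(codes)  # tensor([[0x21, 0x43]])
assert torch.equal(unpack_uint4_pairs(packed), codes)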


@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/addmm.py	2025-05-23 23:24:04.125539+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/addmm.py	2025-05-23 23:24:25.584127+00:00
@@ -6,10 +6,11 @@
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.types import TRTTensor
import os
+

def addmm(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2025-05-23 23:24:04.124539+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2025-05-23 23:24:26.119372+00:00
@@ -272,17 +272,23 @@
            builder_config.set_memory_pool_limit(
                trt.MemoryPoolType.DLA_GLOBAL_DRAM,
                self.compilation_settings.dla_global_dram_size,
            )

-        if not self.compilation_settings.use_explicit_typing and dtype.float16 in self.compilation_settings.enabled_precisions:
+        if (
+            not self.compilation_settings.use_explicit_typing
+            and dtype.float16 in self.compilation_settings.enabled_precisions
+        ):
            builder_config.set_flag(trt.BuilderFlag.FP16)

        if dtype.int8 in self.compilation_settings.enabled_precisions:
            builder_config.set_flag(trt.BuilderFlag.INT8)

-        if not self.compilation_settings.use_explicit_typing and dtype.fp8 in self.compilation_settings.enabled_precisions:
+        if (
+            not self.compilation_settings.use_explicit_typing
+            and dtype.fp8 in self.compilation_settings.enabled_precisions
+        ):
            builder_config.set_flag(trt.BuilderFlag.FP8)

        if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
            builder_config.set_flag(trt.BuilderFlag.BF16)
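The _TRTInterpreter hunk above is purely a reformat of an existing guard: when explicit (strong) typing is requested, the FP16/FP8 builder flags are skipped because per-tensor dtypes already determine precision. A condensed sketch of that gating, assuming a hypothetical Settings stand-in for the real compilation settings (the trt.BuilderFlag names are the ones visible in the diff):

from dataclasses import dataclass, field

import tensorrt as trt

@dataclass
class Settings:  # illustrative stand-in for the compilation settings object
    use_explicit_typing: bool = False
    enabled_precisions: set = field(default_factory=set)

def apply_precision_flags(builder_config: trt.IBuilderConfig, s: Settings) -> None:
    # FP16/FP8 flags only apply to weakly typed networks.
    if not s.use_explicit_typing and "fp16" in s.enabled_precisions:
        builder_config.set_flag(trt.BuilderFlag.FP16)
    if "int8" in s.enabled_precisions:
        builder_config.set_flag(trt.BuilderFlag.INT8)
    if not s.use_explicit_typing and "fp8" in s.enabled_precisions:
        builder_config.set_flag(trt.BuilderFlag.FP8)
    if "bf16" in s.enabled_precisions:
        builder_config.set_flag(trt.BuilderFlag.BF16)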

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-23 23:24:04.126539+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/nvfp4_quantize.py	2025-05-23 23:24:26.134310+00:00
@@ -13,10 +13,11 @@
)
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor
import os
import torch_tensorrt.dynamo.conversion.impl as impl
+

def nvfp4_quantize(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
@@ -206,11 +207,13 @@
        quantized data tensor in fp4
    """
    if os.getenv("DISABLE_STATIC_QUANTIZE", "false").lower() == "true":
        print("lan added disable_static_quantize is set, skipping static quantize")
        return get_trt_tensor(ctx, weights_tensor, name + "_weights")
-    print("lan added static disable_static_quantize is not set, do disable_static_quantize ")
+    print(
+        "lan added static disable_static_quantize is not set, do disable_static_quantize "
+    )
    if os.getenv("ENABLE_TRANSPOSE", "false").lower() == "true":
        print("lan added enable_transpose is set, transposing weights tensor")
        enable_transpose = True
        axis = -2
    else:
@@ -239,22 +242,30 @@
        keep_high_precision=True,
    )
    if enable_transpose:
        block_scale = block_scale.transpose(0, 1)
        weights_tensor_scaled = weights_tensor_scaled.transpose(0, 1)
-    
+
    block_scale_fp8 = block_scale.to(torch.float8_e4m3fn)
    weights_tensor_uint4 = nvfp4_tensor.NVFP4QTensor._cast_fp4(weights_tensor_scaled)
-    weights_tensor_uint8 = (weights_tensor_uint4[..., 1::2] << 4) | weights_tensor_uint4[..., 0::2]
-    
-    print(f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}")
-    print(f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=}")
-    
+    weights_tensor_uint8 = (
+        weights_tensor_uint4[..., 1::2] << 4
+    ) | weights_tensor_uint4[..., 0::2]
+
+    print(
+        f"lan added block_scale_fp8: {block_scale_fp8.shape=} {block_scale_fp8.dtype=} {block_scale_fp8=}"
+    )
+    print(
+        f"lan added weights_tensor_uint8: {weights_tensor_uint8.shape=} {weights_tensor_uint8.dtype=}"
+    )
+
    block_scale_fp8 = get_trt_tensor(ctx, block_scale_fp8, name + "_block_scale_fp8")
    global_scale = to_torch(global_scale, None)
    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
-    weights_tensor_fp4 = get_trt_tensor(ctx, weights_tensor_uint8, name + "_weights_fp4")
+    weights_tensor_fp4 = get_trt_tensor(
+        ctx, weights_tensor_uint8, name + "_weights_fp4"
+    )
    # dequantize block scale from fp8 to original dtype (default is float32)
    dequantize_block_scale_layer = ctx.net.add_dequantize(
        block_scale_fp8,
        global_scale,
        original_dtype,
@@ -284,11 +295,18 @@
    print(
        f"lan added dequantize_data_layer: {dequantize_data_layer.to_type=} {dequantize_data_layer.axis=} {dequantize_data_layer.get_input(0).shape=} {dequantize_data_layer.get_input(1).shape=}"
    )
    dequantized_data = dequantize_data_layer.get_output(0)
    if enable_transpose:
-        dequantized_data = impl.permutation.permute(ctx, target, source_ir, name + "_dequantized_data_transposed", dequantized_data, (-1, -2))
+        dequantized_data = impl.permutation.permute(
+            ctx,
+            target,
+            source_ir,
+            name + "_dequantized_data_transposed",
+            dequantized_data,
+            (-1, -2),
+        )
    return dequantized_data


def _calculate_global_scale(
    ctx: ConversionContext,
@@ -304,18 +322,23 @@
    global_scale = torch.divide(amax, 6 * 448)
    if global_scale == 0:
        global_scale = 1.0
    return global_scale

+
def _get_weights_scaling_factor_transposed(
    weights_tensor: torch.Tensor,
    global_scale: torch.Tensor,
    keep_high_precision: bool = False,
) -> torch.Tensor:
    [k, n] = weights_tensor.shape[-2:]
-    assert k % 16 == 0, "Weight shape is not divisible for block size for block quantiation."
-    weights_tensor = weights_tensor.reshape(tuple(weights_tensor.shape[:-2]) + (k // 16, n, 16))
+    assert (
+        k % 16 == 0
+    ), "Weight shape is not divisible for block size for block quantiation."
+    weights_tensor = weights_tensor.reshape(
+        tuple(weights_tensor.shape[:-2]) + (k // 16, n, 16)
+    )
    per_block_amax = weights_tensor.abs().amax(dim=-1).float()
    per_block_scale = per_block_amax / 6.0
    q_per_block_scale = per_block_scale / global_scale
    q_per_block_scale[per_block_scale == 0] = 1.0
    if not keep_high_precision:
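Conceptually, the two add_dequantize layers in the diff implement a two-level scheme: the FP8 block scales are first rescaled by the scalar global scale, and the FP4 codes are then multiplied by the resulting per-block scales. A plain-PyTorch sketch of that arithmetic, assuming one row of 32 values split into blocks of 16 (all names and numbers here are illustrative):

import torch

# Level 1: per-block scales stored in FP8 are dequantized by a scalar global scale.
# Level 2: the FP4 codes are dequantized by the per-block scales.
fp4_codes = torch.tensor([[0.5, 1.0, 6.0, -4.0] * 8])                  # stand-in for decoded E2M1 values, shape [1, 32]
block_scale_fp8 = torch.tensor([2.0, 0.5]).to(torch.float8_e4m3fn)     # one scale per block of 16
global_scale = torch.tensor(0.01)

block_scale = block_scale_fp8.to(torch.float32) * global_scale         # level 1
blocks = fp4_codes.reshape(1, 2, 16)
dequantized = (blocks * block_scale.view(1, 2, 1)).reshape(1, 32)      # level 2

print(dequantized[0, :4])  # first block scaled by 2.0 * 0.01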
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/permutation.py	2025-05-23 23:24:04.126539+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/permutation.py	2025-05-23 23:24:26.280124+00:00
@@ -13,10 +13,11 @@
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.fx.types import TRTTensor
import os

+
def permute(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-23 23:24:04.153539+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-23 23:24:31.794505+00:00
@@ -13,10 +13,11 @@
from packaging.version import Version

assertions = unittest.TestCase()
import os

+
@pytest.mark.unit
def test_resnet18(ir):
    model = models.resnet18(pretrained=True).eval().to("cuda")
    input = torch.randn((1, 3, 224, 224)).to("cuda")

@@ -208,10 +209,11 @@
)
@pytest.mark.unit
def test_base_fp4(ir):
    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.utils import export_torch_mode
+
    dtype = torch.float16

    class SimpleNetwork(torch.nn.Module):
        def __init__(self):
            super(SimpleNetwork, self).__init__()
@@ -227,21 +229,20 @@
        """Simple calibration function for testing."""
        model(input_tensor)

    input_tensor = torch.ones(128, 64, dtype=dtype).cuda()

-    
    model = SimpleNetwork().eval().cuda()
    model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda())
    model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda())
    print(f"lan added amax: {input_tensor.abs().amax()=}")
    print(f"lan added amax: {model.linear1.weight.abs().amax()=}")
    expected_output = model(input_tensor)
-    print(f"lan added model input: {input_tensor=}")    
+    print(f"lan added model input: {input_tensor=}")
    print(f"lan added model weight: {model.linear1.weight=}")
    print(f"lan added model bias: {model.linear1.bias=}")
-    
+
    quant_cfg = mtq.NVFP4_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has qdq nodes at this point
    with torch.no_grad():
        with export_torch_mode():
@@ -269,15 +270,21 @@
                print("lan added disable_gemm is set, compring result with weights")
                expected_output = model.linear1.weight
            else:
                print("lan added disable_gemm is not set, compring result with pytorch")

-            print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}")
-            print(f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}")
+            print(
+                f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}"
+            )
+            print(
+                f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}"
+            )

            abs_diff = torch.abs(expected_output - outputs_trt)
-            print(f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
+            print(
+                f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}"
+            )
            print(f"lan added abs_diff: {abs_diff=}")
            assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8)


@unittest.skipIf(
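The end of the test diff compares the TensorRT output against the PyTorch reference by printing absolute-difference statistics and then asserting a loose torch.allclose. A minimal sketch of that comparison pattern, with the rtol/atol values copied from the test and purely illustrative tensors:

import torch

def report_and_check(expected: torch.Tensor, actual: torch.Tensor,
                     rtol: float = 0.8, atol: float = 0.8) -> None:
    # Same diagnostics the test prints before asserting closeness.
    abs_diff = torch.abs(expected - actual).float()
    print(f"max abs diff: {abs_diff.max().item():.4f}, mean abs diff: {abs_diff.mean().item():.4f}")
    assert torch.allclose(expected, actual, rtol=rtol, atol=atol)

expected = torch.full((128, 32), 64.0)
actual = expected + 0.25  # pretend TRT output with a small quantization error
report_and_check(expected, actual)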


@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/addmm.py	2025-05-25 16:30:16.756896+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/addmm.py	2025-05-25 16:30:38.965265+00:00
@@ -6,10 +6,11 @@
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.types import TRTTensor
import os
+

def addmm(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2025-05-25 16:30:16.755896+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2025-05-25 16:30:39.537693+00:00
@@ -272,17 +272,23 @@
            builder_config.set_memory_pool_limit(
                trt.MemoryPoolType.DLA_GLOBAL_DRAM,
                self.compilation_settings.dla_global_dram_size,
            )

-        if not self.compilation_settings.use_explicit_typing and dtype.float16 in self.compilation_settings.enabled_precisions:
+        if (
+            not self.compilation_settings.use_explicit_typing
+            and dtype.float16 in self.compilation_settings.enabled_precisions
+        ):
            builder_config.set_flag(trt.BuilderFlag.FP16)

        if dtype.int8 in self.compilation_settings.enabled_precisions:
            builder_config.set_flag(trt.BuilderFlag.INT8)

-        if not self.compilation_settings.use_explicit_typing and dtype.fp8 in self.compilation_settings.enabled_precisions:
+        if (
+            not self.compilation_settings.use_explicit_typing
+            and dtype.fp8 in self.compilation_settings.enabled_precisions
+        ):
            builder_config.set_flag(trt.BuilderFlag.FP8)

        if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
            builder_config.set_flag(trt.BuilderFlag.BF16)

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/permutation.py	2025-05-25 16:30:16.757896+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/permutation.py	2025-05-25 16:30:39.599915+00:00
@@ -13,10 +13,11 @@
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.fx.types import TRTTensor
import os

+
def permute(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-25 16:30:16.783896+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py	2025-05-25 16:30:44.850706+00:00
@@ -13,10 +13,11 @@
from packaging.version import Version

assertions = unittest.TestCase()
import os

+
@pytest.mark.unit
def test_resnet18(ir):
    model = models.resnet18(pretrained=True).eval().to("cuda")
    input = torch.randn((1, 3, 224, 224)).to("cuda")

@@ -208,10 +209,11 @@
)
@pytest.mark.unit
def test_base_fp4(ir):
    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.utils import export_torch_mode
+
    dtype = torch.float16

    class SimpleNetwork(torch.nn.Module):
        def __init__(self):
            super(SimpleNetwork, self).__init__()
@@ -227,21 +229,20 @@
        """Simple calibration function for testing."""
        model(input_tensor)

    input_tensor = torch.ones(128, 64, dtype=dtype).cuda()

-    
    model = SimpleNetwork().eval().cuda()
    model.linear1.weight = torch.nn.Parameter(torch.ones(32, 64, dtype=dtype).cuda())
    model.linear1.bias = torch.nn.Parameter(torch.zeros(128, 32, dtype=dtype).cuda())
    print(f"lan added amax: {input_tensor.abs().amax()=}")
    print(f"lan added amax: {model.linear1.weight.abs().amax()=}")
    expected_output = model(input_tensor)
-    print(f"lan added model input: {input_tensor=}")    
+    print(f"lan added model input: {input_tensor=}")
    print(f"lan added model weight: {model.linear1.weight=}")
    print(f"lan added model bias: {model.linear1.bias=}")
-    
+
    quant_cfg = mtq.NVFP4_DEFAULT_CFG
    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    # model has qdq nodes at this point
    with torch.no_grad():
        with export_torch_mode():
@@ -269,15 +270,21 @@
                print("lan added disable_gemm is set, compring result with weights")
                expected_output = model.linear1.weight
            else:
                print("lan added disable_gemm is not set, compring result with pytorch")

-            print(f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}")
-            print(f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}")
+            print(
+                f"lan added torch_tensorrt outputs_trt: {outputs_trt=} {outputs_trt.dtype=} {outputs_trt.shape=} {outputs_trt.abs().amax()=}"
+            )
+            print(
+                f"lan added expected output_pyt: {expected_output=} {expected_output.dtype=} {expected_output.shape=} {expected_output.abs().amax()=}"
+            )

            abs_diff = torch.abs(expected_output - outputs_trt)
-            print(f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
+            print(
+                f"lan added max /mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}"
+            )
            print(f"lan added abs_diff: {abs_diff=}")
            assert torch.allclose(expected_output, outputs_trt, rtol=0.8, atol=0.8)


@unittest.skipIf(

