flux fp4 example (WIP) #3537
Conversation
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/addmm.py 2025-05-28 16:10:39.268834+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/addmm.py 2025-05-28 16:11:01.327800+00:00
@@ -6,10 +6,11 @@
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.types import TRTTensor
import os
+
def addmm(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2025-05-28 16:10:39.267834+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2025-05-28 16:11:01.908540+00:00
@@ -272,17 +272,23 @@
builder_config.set_memory_pool_limit(
trt.MemoryPoolType.DLA_GLOBAL_DRAM,
self.compilation_settings.dla_global_dram_size,
)
- if not self.compilation_settings.use_explicit_typing and dtype.float16 in self.compilation_settings.enabled_precisions:
+ if (
+ not self.compilation_settings.use_explicit_typing
+ and dtype.float16 in self.compilation_settings.enabled_precisions
+ ):
builder_config.set_flag(trt.BuilderFlag.FP16)
if dtype.int8 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.INT8)
- if not self.compilation_settings.use_explicit_typing and dtype.fp8 in self.compilation_settings.enabled_precisions:
+ if (
+ not self.compilation_settings.use_explicit_typing
+ and dtype.fp8 in self.compilation_settings.enabled_precisions
+ ):
builder_config.set_flag(trt.BuilderFlag.FP8)
if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.BF16)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/permutation.py 2025-05-28 16:10:39.269834+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/permutation.py 2025-05-28 16:11:01.949586+00:00
@@ -13,10 +13,11 @@
)
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.fx.types import TRTTensor
import os
+
def permute(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-08 16:14:53.013799+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-08 16:15:15.490934+00:00
@@ -94,6 +94,6 @@
if axis is not None:
dequantize_layer.axis = axis
set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
dq_output = dequantize_layer.get_output(0)
- return dq_output
\ No newline at end of file
+ return dq_output
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-09 16:52:54.851163+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-09 16:53:17.995276+00:00
@@ -94,6 +94,6 @@
if axis is not None:
dequantize_layer.axis = axis
set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
dq_output = dequantize_layer.get_output(0)
- return dq_output
\ No newline at end of file
+ return dq_output
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-09 23:54:44.134169+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-09 23:55:10.573434+00:00
@@ -94,6 +94,6 @@
if axis is not None:
dequantize_layer.axis = axis
set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
dq_output = dequantize_layer.get_output(0)
- return dq_output
\ No newline at end of file
+ return dq_output
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-12 14:38:26.461575+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-12 14:38:49.017416+00:00
@@ -94,6 +94,6 @@
if axis is not None:
dequantize_layer.axis = axis
set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
dq_output = dequantize_layer.get_output(0)
- return dq_output
\ No newline at end of file
+ return dq_output
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py 2025-06-12 14:38:26.463575+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py 2025-06-12 14:38:49.337824+00:00
@@ -98,16 +98,17 @@
class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def is_impure(self, node: torch.fx.node.Node) -> bool:
- # Set of known quantization ops to be excluded from constant folding.
+ # Set of known quantization ops to be excluded from constant folding.
# Currently, we exclude all quantization ops coming from modelopt library.
quantization_ops = {}
try:
- # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
+ # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
import modelopt.torch.quantization as mtq
+
assert torch.ops.tensorrt.quantize_op.default
quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
quantization_ops.add(torch.ops.tensorrt.dynamic_block_quantize_op.default)
except Exception as e:
pass
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-12 14:44:47.754478+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-12 14:45:12.051449+00:00
@@ -94,6 +94,6 @@
if axis is not None:
dequantize_layer.axis = axis
set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
dq_output = dequantize_layer.get_output(0)
- return dq_output
\ No newline at end of file
+ return dq_output
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py 2025-06-12 14:44:47.756479+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py 2025-06-12 14:45:12.430297+00:00
@@ -98,16 +98,17 @@
class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def is_impure(self, node: torch.fx.node.Node) -> bool:
- # Set of known quantization ops to be excluded from constant folding.
+ # Set of known quantization ops to be excluded from constant folding.
# Currently, we exclude all quantization ops coming from modelopt library.
quantization_ops = {}
try:
- # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
+ # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
import modelopt.torch.quantization as mtq
+
assert torch.ops.tensorrt.quantize_op.default
quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
quantization_ops.add(torch.ops.tensorrt.dynamic_block_quantize_op.default)
except Exception as e:
pass
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-12 14:46:31.608373+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/quantize.py 2025-06-12 14:46:59.413246+00:00
@@ -94,6 +94,6 @@
if axis is not None:
dequantize_layer.axis = axis
set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
dq_output = dequantize_layer.get_output(0)
- return dq_output
\ No newline at end of file
+ return dq_output
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py 2025-06-12 14:46:31.610373+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py 2025-06-12 14:46:59.893008+00:00
@@ -98,16 +98,17 @@
class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def is_impure(self, node: torch.fx.node.Node) -> bool:
- # Set of known quantization ops to be excluded from constant folding.
+ # Set of known quantization ops to be excluded from constant folding.
# Currently, we exclude all quantization ops coming from modelopt library.
quantization_ops = {}
try:
- # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
+ # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
import modelopt.torch.quantization as mtq
+
assert torch.ops.tensorrt.quantize_op.default
quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
quantization_ops.add(torch.ops.tensorrt.dynamic_block_quantize_op.default)
except Exception as e:
pass
Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: