Skip to content

Commit 9f75c42

Browse files
committed
Fixed a memory overflow bug
1 parent caf1968 commit 9f75c42

File tree

3 files changed

+11
-15
lines changed

3 files changed

+11
-15
lines changed

examples/apps/flux_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def forward_loop(mod):
8282
pipe.transformer = mod
8383
do_calibrate(
8484
pipe=pipe,
85-
prompt="test",
85+
prompt="a dog running in a park",
8686
)
8787

8888
if args.dtype != "fp16":

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,20 +98,21 @@ def replace_node_with_constant(
9898
class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc]
9999
def __init__(self, *args: Any, **kwargs: Any) -> None:
100100
super().__init__(*args, **kwargs)
101-
102-
def is_impure(self, node: torch.fx.node.Node) -> bool:
103101
# Set of known quantization ops to be excluded from constant folding.
104102
# Currently, we exclude all quantization ops coming from modelopt library.
105-
quantization_ops: Set[torch._ops.OpOverload] = set()
103+
self.quantization_ops = set()
106104
try:
107105
# modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
108-
import modelopt.torch.quantization as mtq # noqa: F401
106+
import modelopt.torch.quantization as mtq
109107

110108
assert torch.ops.tensorrt.quantize_op.default
111-
quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
112-
quantization_ops.add(torch.ops.tensorrt.dynamic_block_quantize_op.default)
109+
self.quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
113110
except Exception as e:
114111
pass
115-
if quantization_ops and node.target in quantization_ops:
112+
113+
# TODO: Update this function when quantization is added
114+
def is_impure(self, node: torch.fx.node.Node) -> bool:
115+
116+
if node.target in self.quantization_ops:
116117
return True
117118
return False

tools/perf/Flux/flux_perf.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from time import time
55

66
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../examples/apps"))
7-
from flux_demo import compile_model
7+
from flux_demo import compile_model, parse_args
88

99

1010
def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
@@ -56,16 +56,11 @@ def main(args):
5656
action="store_true",
5757
help="Use dynamic shapes",
5858
)
59-
parser.add_argument(
60-
"--max_batch_size",
61-
type=int,
62-
default=1,
63-
help="Maximum batch size to use",
64-
)
6559
parser.add_argument(
6660
"--debug",
6761
action="store_true",
6862
help="Use debug mode",
6963
)
64+
parser.add_argument("--max_batch_size", type=int, default=1)
7065
args = parser.parse_args()
7166
main(args)

0 commit comments

Comments (0)