Commit 9309725

chore: updates
1 parent 600e363 commit 9309725

File tree

6 files changed: +436 -85 lines changed


examples/dynamo/llama_benchmark.py

Lines changed: 0 additions & 77 deletions
This file was deleted.

examples/dynamo/llm/run_llm.py

Lines changed: 3 additions & 3 deletions
@@ -41,7 +41,7 @@ def get_model(args):
            args.model,
            use_cache=False,
            attn_implementation="sdpa",
-           # num_hidden_layers=1
+           num_hidden_layers=2
        )
        .eval()
        .cuda()
@@ -59,7 +59,7 @@ def get_model(args):
def compile_torchtrt(model, input_ids, args):
    max_seq_len = input_ids.shape[1] + args.num_tokens
    ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
-
+   position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
    # Set precision specific flags
    use_fp32_acc = False
    use_explicit_typing = False
@@ -76,7 +76,7 @@ def compile_torchtrt(model, input_ids, args):
    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
        trt_model = torch_tensorrt.dynamo.compile(
            ep,
-           inputs=[input_ids],
+           inputs=[input_ids, position_ids],
            enabled_precisions=enabled_precisions,
            # truncate_double=True,
            use_explicit_typing=use_explicit_typing,
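
The explicit position_ids input added above follows the usual decoder-only convention: indices 0..seq_len-1 for the prompt at prefill time, then one more index per generated token. A minimal sketch of that convention, assuming a CUDA device and dummy token ids (the names below are illustrative and not part of this commit):

import torch

DEVICE = torch.device("cuda:0")                            # assumed device, mirrors DEVICE in run_llm.py
input_ids = torch.randint(0, 32000, (1, 16)).to(DEVICE)    # dummy prompt tokens

# Prefill: one position index per prompt token, shape (1, seq_len).
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)

# Decode: append the next position index for each newly generated token.
next_position = position_ids[:, -1:] + 1
position_ids = torch.cat((position_ids, next_position), dim=1)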

examples/dynamo/llm/test_gemma.py

Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
import torch

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import TestCase
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention, Gemma3DecoderLayer
from transformers.models.gemma3.configuration_gemma3 import Gemma3Config
from transformers import AutoModelForCausalLM
import torch_tensorrt
from contextlib import nullcontext
import argparse
import sys
import os

# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from register_sdpa import *


ATOL = 1e-5
RTOL = 1e-5


gemma3_model_name = "google/gemma-3-1b-it"
gemma3_model = AutoModelForCausalLM.from_pretrained(
    gemma3_model_name,
    use_cache=False,
    attn_implementation="sdpa",
    num_hidden_layers=1,
).eval().cuda()
GEMMA3_CONFIG = gemma3_model.config

def print_diff(tensor1, tensor2, prefix=""):
    """
    Print the diff between two tensors
    """
    print(f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}")


def test_gemma3_attention(args):

    DTYPE = torch.float32
    if args.precision == "FP16":
        DTYPE = torch.float16
    elif args.precision == "BF16":
        DTYPE = torch.bfloat16

    # Set precision specific flags
    use_fp32_acc = False
    use_explicit_typing = False
    if args.precision == "FP16":
        enabled_precisions = {torch.float32}
        use_fp32_acc = True
        use_explicit_typing = True
    elif args.precision == "BF16":
        enabled_precisions = {torch.bfloat16}
        use_fp32_acc = False
    else:
        enabled_precisions = {torch.float32}

    model = gemma3_model.model.layers[0].self_attn.to(DTYPE)

    # gemma3
    hidden_states = torch.randn((1, 5, 1152), dtype=DTYPE).cuda()
    position_embeddings = (torch.randn((1, 5, 256), dtype=DTYPE).cuda(), torch.randn((1, 5, 256), dtype=DTYPE).cuda())

    pyt_output = model(hidden_states, position_embeddings, None)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
    ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)

    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
        trt_model = torch_tensorrt.dynamo.compile(ep,
                                                  inputs=[hidden_states, position_embeddings, None],
                                                  enabled_precisions=enabled_precisions,
                                                  disable_tf32=True,
                                                  use_fp32_acc=use_fp32_acc,
                                                  use_explicit_typing=use_explicit_typing,
                                                  debug=args.debug)
    trt_output = trt_model(hidden_states, position_embeddings, None)

    if isinstance(pyt_output, tuple):
        print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt")
        assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
    else:
        print_diff(pyt_output, trt_output, "Diff b/w pyt and trt")
        assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)

def test_gemma3_attention_with_static_cache(args):

    import static_cache_v2
    DTYPE = torch.float32
    model = gemma3_model.model.layers[0].self_attn.to(DTYPE)

    # Inputs
    ISL = 2048
    NUM_TOKENS = 128
    OSL = ISL + NUM_TOKENS
    hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda()
    position_embeddings = (torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), torch.randn((1, ISL, 256), dtype=DTYPE).cuda())
    key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE)
    value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE)
    start_idx = 0
    end_idx = ISL
    is_causal = True

    pyt_output = model(hidden_states, position_embeddings, None)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
    ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)
    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
        trt_model = torch_tensorrt.dynamo.compile(ep,
                                                  inputs=[hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal],
                                                  enabled_precisions={torch.float32},
                                                  disable_tf32=True,
                                                  debug=args.debug,
                                                  # offload_module_to_cpu=True,
                                                  use_python_runtime=True)

    # Test Prefill
    trt_output, _, key_cache, value_cache = trt_model(hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal)
    print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]")

    # Test Generate
    for start_idx in range(2048, 2176):
        end_idx = start_idx + 1
        hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda()
        position_embeddings_curr = (torch.randn((1, 1, 256), dtype=DTYPE).cuda(), torch.randn((1, 1, 256), dtype=DTYPE).cuda())
        # Concatenate the current hidden_states with the previous ones
        hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
        position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1))

        is_causal = False
        out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None)
        out_trt, _, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, None, key_cache, value_cache, start_idx, end_idx, is_causal)
        out_pyt = out_no_cache[:, -1:, :]
        print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")

        hidden_states = hidden_states_full
        position_embeddings = position_embeddings_full

def test_gemma3_decoder(args):

    DTYPE = torch.float32
    if args.precision == "FP16":
        DTYPE = torch.float16
    elif args.precision == "BF16":
        DTYPE = torch.bfloat16
    model = gemma3_model.model.layers[0].to(DTYPE)
    # model.self_attn.is_sliding = False

    # gemma3
    hidden_states = torch.randn((1, 6, 1152), dtype=DTYPE).cuda()
    position_embeddings_global = (torch.randn((1, 6, 256), dtype=DTYPE).cuda(), torch.randn((1, 6, 256), dtype=DTYPE).cuda())
    position_embeddings_local = (torch.randn((1, 6, 256), dtype=DTYPE).cuda(), torch.randn((1, 6, 256), dtype=DTYPE).cuda())

    pyt_output = model(hidden_states, position_embeddings_global, position_embeddings_local)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), ({1: seq_len}, {1: seq_len}))
    ep = torch.export.export(model, (hidden_states, position_embeddings_global, position_embeddings_local), dynamic_shapes=dynamic_shapes)

    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
        trt_model = torch_tensorrt.dynamo.compile(ep,
                                                  inputs=[hidden_states, position_embeddings_global, position_embeddings_local],
                                                  enabled_precisions={torch.float32},
                                                  debug=args.debug)
    trt_output = trt_model(hidden_states, position_embeddings_global, position_embeddings_local)

    print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}")
    # breakpoint()
    assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)

def test_gemma3_decoder_with_static_cache(args):

    class Gemma3DecoderLayerBlock(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.config = GEMMA3_CONFIG
            self.decoder = Gemma3DecoderLayer(
                config=self.config,
                layer_idx=0)
            self.model = model
        def forward(self, hidden_states, position_embeddings):
            return self.model(hidden_states, position_embeddings=position_embeddings)

    DTYPE = torch.float32
    model = Gemma3DecoderLayerBlock(gemma3_model.model.layers[0].to(DTYPE))

    import static_cache_v2
    # Inputs
    ISL = 2048
    NUM_TOKENS = 128
    OSL = ISL + NUM_TOKENS
    hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda()
    position_embeddings_global = (torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), torch.randn((1, ISL, 256), dtype=DTYPE).cuda())
    position_embeddings_local = (torch.randn((1, NUM_TOKENS, 256), dtype=DTYPE).cuda(), torch.randn((1, NUM_TOKENS, 256), dtype=DTYPE).cuda())
    key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE)
    value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE)
    start_idx = 0
    end_idx = ISL
    is_causal = True

    pyt_output = model(hidden_states, position_embeddings_global, position_embeddings_local)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}))
    ep = torch.export.export(model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes)

    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
        trt_model = torch_tensorrt.dynamo.compile(ep,
                                                  arg_inputs=[hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal],
                                                  enabled_precisions={torch.float32},
                                                  disable_tf32=True,
                                                  debug=args.debug,
                                                  # offload_module_to_cpu=True,
                                                  use_python_runtime=True)

    # Test Prefill
    trt_output, key_cache, value_cache = trt_model(hidden_states, position_embeddings, key_cache, value_cache, start_idx, end_idx, is_causal)
    print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]")

    # Test Generate
    for start_idx in range(2048, 2176):
        end_idx = start_idx + 1
        hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda()
        position_embeddings_curr = (torch.randn((1, 1, 256), dtype=DTYPE).cuda(), torch.randn((1, 1, 256), dtype=DTYPE).cuda())
        # Concatenate the current hidden_states with the previous ones
        hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
        position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1))

        is_causal = False
        out_no_cache = model(hidden_states_full, position_embeddings_full)

        out_trt, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, key_cache, value_cache, start_idx, end_idx, is_causal)
        out_pyt = out_no_cache[0][:, -1:, :]
        print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
        hidden_states = hidden_states_full
        position_embeddings = position_embeddings_full


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description="Run test cases for llama attention and decoder"
    )
    arg_parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug (default: False)"
    )
    arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32")
    args = arg_parser.parse_args()
    with torch.inference_mode():
        # test_gemma3_attention(args)
        # test_gemma3_attention_with_static_cache(args)
        test_gemma3_decoder(args)
        # test_gemma3_decoder_with_static_cache(args)
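
The static-cache tests in this file share one compiled-module calling convention: pre-allocated key/value cache tensors plus start_idx/end_idx bounds go in, and the updated caches come back alongside the output, so prefill runs once over the prompt and each generate step advances the window by a single token. A condensed sketch of that pattern, assuming a trt_model compiled as in test_gemma3_attention_with_static_cache (tensor shapes and the extra None attention-mask argument follow that test; this is an illustration, not part of the diff):

import torch

# Assumed: trt_model compiled as in test_gemma3_attention_with_static_cache.
ISL, OSL = 2048, 2176  # prompt length and total cache length, as in the test
hidden_states = torch.randn((1, ISL, 1152)).cuda()
position_embeddings = (torch.randn((1, ISL, 256)).cuda(), torch.randn((1, ISL, 256)).cuda())
key_cache = torch.zeros(1, 4, OSL, 64).cuda()
value_cache = torch.zeros(1, 4, OSL, 64).cuda()

# Prefill: process the whole prompt once, filling cache slots [0, ISL).
out, _, key_cache, value_cache = trt_model(
    hidden_states, position_embeddings, None, key_cache, value_cache, 0, ISL, True
)

# Generate: each step feeds a single new token and advances the cache window by one.
for start_idx in range(ISL, OSL):
    hidden_states_curr = torch.randn((1, 1, 1152)).cuda()
    position_embeddings_curr = (torch.randn((1, 1, 256)).cuda(), torch.randn((1, 1, 256)).cuda())
    out, _, key_cache, value_cache = trt_model(
        hidden_states_curr, position_embeddings_curr, None,
        key_cache, value_cache, start_idx, start_idx + 1, False
    )

To exercise the currently enabled path, the script can be run as python test_gemma.py --precision FP16 (optionally with --debug); the sys.path line at the top of the file adds the parent directory so register_sdpa can be imported, and static_cache_v2 must be importable for the static-cache tests.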

examples/dynamo/llm/test_qwen2.5_components.py

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import TestCase
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers import AutoModelForCausalLM
import torch_tensorrt
