import torch

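# TF32 is disabled so FP32 matmuls run in true FP32; with TF32 enabled, the tight
# ATOL/RTOL comparisons against the TensorRT outputs below can fail spuriously.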
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import TestCase
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention, Gemma3DecoderLayer
from transformers.models.gemma3.configuration_gemma3 import Gemma3Config
from transformers import AutoModelForCausalLM
import torch_tensorrt
from contextlib import nullcontext
import argparse
import sys
import os

# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from register_sdpa import *


ATOL = 1e-5
RTOL = 1e-5


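# Load only one hidden layer of the model: the tests below exercise a single
# attention block / decoder layer, which keeps load time and memory small.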
gemma3_model_name = "google/gemma-3-1b-it"
gemma3_model = AutoModelForCausalLM.from_pretrained(
    gemma3_model_name,
    use_cache=False,
    attn_implementation="sdpa",
    num_hidden_layers=1,
).eval().cuda()
GEMMA3_CONFIG = gemma3_model.config

def print_diff(tensor1, tensor2, prefix=""):
    """
    Print the mean absolute difference between two tensors
    """
    print(f"[{prefix}] Mean absolute difference between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}")


def test_gemma3_attention(args):

    DTYPE = torch.float32
    if args.precision == "FP16":
        DTYPE = torch.float16
    elif args.precision == "BF16":
        DTYPE = torch.bfloat16

    # Set precision-specific flags
    use_fp32_acc = False
    use_explicit_typing = False
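    # For FP16 we rely on explicit typing plus FP32 accumulation rather than adding
    # torch.float16 to enabled_precisions: the FP16 casts come from the model weights
    # while matmul accumulation stays in FP32, which preserves accuracy.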
    if args.precision == "FP16":
        enabled_precisions = {torch.float32}
        use_fp32_acc = True
        use_explicit_typing = True
    elif args.precision == "BF16":
        enabled_precisions = {torch.bfloat16}
        use_fp32_acc = False
    else:
        enabled_precisions = {torch.float32}

    model = gemma3_model.model.layers[0].self_attn.to(DTYPE)

    # Gemma3 attention inputs: hidden_states plus the rotary embeddings (cos, sin)
    hidden_states = torch.randn((1, 5, 1152), dtype=DTYPE).cuda()
    position_embeddings = (torch.randn((1, 5, 256), dtype=DTYPE).cuda(), torch.randn((1, 5, 256), dtype=DTYPE).cuda())

    pyt_output = model(hidden_states, position_embeddings, None)
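    # Dim 1 is the sequence dimension of hidden_states and of both rotary tensors;
    # marking it dynamic lets one engine serve every sequence length up to 2176.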
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
    ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)

    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
        trt_model = torch_tensorrt.dynamo.compile(ep,
                                                  inputs=[hidden_states, position_embeddings, None],
                                                  enabled_precisions=enabled_precisions,
                                                  disable_tf32=True,
                                                  use_fp32_acc=use_fp32_acc,
                                                  use_explicit_typing=use_explicit_typing,
                                                  debug=args.debug)
    trt_output = trt_model(hidden_states, position_embeddings, None)

    if isinstance(pyt_output, tuple):
        print_diff(pyt_output[0], trt_output[0], "pyt vs trt")
        assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
    else:
        print_diff(pyt_output, trt_output, "pyt vs trt")
        assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)

def test_gemma3_attention_with_static_cache(args):

    import static_cache_v2
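    # Imported for its side effects: like register_sdpa, static_cache_v2 registers a
    # lowering pass that rewrites the graph so the engine takes and returns the KV-cache
    # tensors plus start/end indices as explicit inputs/outputs (see the calls below).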
    DTYPE = torch.float32
    model = gemma3_model.model.layers[0].self_attn.to(DTYPE)

    # Inputs: ISL = prefill sequence length, OSL = total length after generation
    ISL = 2048
    NUM_TOKENS = 128
    OSL = ISL + NUM_TOKENS
    hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda()
    position_embeddings = (torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), torch.randn((1, ISL, 256), dtype=DTYPE).cuda())
    key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE)
    value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE)
    start_idx = 0
    end_idx = ISL
    is_causal = True

    pyt_output = model(hidden_states, position_embeddings, None)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
    ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes)
    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
        trt_model = torch_tensorrt.dynamo.compile(ep,
                                                  inputs=[hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal],
                                                  enabled_precisions={torch.float32},
                                                  disable_tf32=True,
                                                  debug=args.debug,
                                                  # offload_module_to_cpu=True,
                                                  use_python_runtime=True)

    # Test Prefill
    trt_output, _, key_cache, value_cache = trt_model(hidden_states, position_embeddings, None, key_cache, value_cache, start_idx, end_idx, is_causal)
    print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]")

    # Test Generate: decode one token at a time through the TRT engine with the static
    # KV cache, and compare each step against PyTorch recomputing the full sequence.
    for start_idx in range(2048, 2176):
        end_idx = start_idx + 1
        hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda()
        position_embeddings_curr = (torch.randn((1, 1, 256), dtype=DTYPE).cuda(), torch.randn((1, 1, 256), dtype=DTYPE).cuda())
        # Concatenate the current hidden_states with the previous ones
        hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
        position_embeddings_full = (torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1))

        is_causal = False
        out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None)
        out_trt, _, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, None, key_cache, value_cache, start_idx, end_idx, is_causal)
        out_pyt = out_no_cache[:, -1:, :]
        print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")

        hidden_states = hidden_states_full
        position_embeddings = position_embeddings_full

def test_gemma3_decoder(args):

    DTYPE = torch.float32
    if args.precision == "FP16":
        DTYPE = torch.float16
    elif args.precision == "BF16":
        DTYPE = torch.bfloat16
    model = gemma3_model.model.layers[0].to(DTYPE)
    # model.self_attn.is_sliding = False

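    # Gemma3 decoder layers consume two rotary-embedding tuples: one for global
    # attention and one for local sliding-window attention, which use different
    # RoPE frequencies.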
    hidden_states = torch.randn((1, 6, 1152), dtype=DTYPE).cuda()
    position_embeddings_global = (torch.randn((1, 6, 256), dtype=DTYPE).cuda(), torch.randn((1, 6, 256), dtype=DTYPE).cuda())
    position_embeddings_local = (torch.randn((1, 6, 256), dtype=DTYPE).cuda(), torch.randn((1, 6, 256), dtype=DTYPE).cuda())

    pyt_output = model(hidden_states, position_embeddings_global, position_embeddings_local)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), ({1: seq_len}, {1: seq_len}))
    ep = torch.export.export(model, (hidden_states, position_embeddings_global, position_embeddings_local), dynamic_shapes=dynamic_shapes)

    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
        trt_model = torch_tensorrt.dynamo.compile(ep,
                                                  inputs=[hidden_states, position_embeddings_global, position_embeddings_local],
                                                  enabled_precisions={torch.float32},
                                                  debug=args.debug)
    trt_output = trt_model(hidden_states, position_embeddings_global, position_embeddings_local)

    print_diff(pyt_output[0], trt_output[0], "pyt vs trt")
    assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)

def test_gemma3_decoder_with_static_cache(args):

    class Gemma3DecoderLayerBlock(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.config = GEMMA3_CONFIG
            self.decoder = Gemma3DecoderLayer(
                config=self.config,
                layer_idx=0)
            self.model = model

        def forward(self, hidden_states, position_embeddings_global, position_embeddings_local):
            return self.model(hidden_states, position_embeddings_global, position_embeddings_local)

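    # The wrapper above pins the decoder layer to a fixed positional signature
    # (hidden_states, global embeddings, local embeddings) so torch.export traces a
    # stable interface; self.decoder is constructed but the pre-trained layer held in
    # self.model is the module actually under test.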
    DTYPE = torch.float32
    model = Gemma3DecoderLayerBlock(gemma3_model.model.layers[0].to(DTYPE))

    import static_cache_v2

    # Inputs: ISL = prefill sequence length, OSL = total length after generation
    ISL = 2048
    NUM_TOKENS = 128
    OSL = ISL + NUM_TOKENS
    hidden_states = torch.randn((1, ISL, 1152), dtype=DTYPE).cuda()
    position_embeddings_global = (torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), torch.randn((1, ISL, 256), dtype=DTYPE).cuda())
    position_embeddings_local = (torch.randn((1, ISL, 256), dtype=DTYPE).cuda(), torch.randn((1, ISL, 256), dtype=DTYPE).cuda())
    key_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE)
    value_cache = torch.zeros(1, 4, OSL, 64).cuda().to(DTYPE)
    start_idx = 0
    end_idx = ISL
    is_causal = True

    pyt_output = model(hidden_states, position_embeddings_global, position_embeddings_local)
    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), ({1: seq_len}, {1: seq_len}))
    ep = torch.export.export(model, args=(hidden_states, position_embeddings_global, position_embeddings_local), dynamic_shapes=dynamic_shapes)

    with (torch_tensorrt.logging.debug() if args.debug else nullcontext()):
        trt_model = torch_tensorrt.dynamo.compile(ep,
                                                  arg_inputs=[hidden_states, position_embeddings_global, position_embeddings_local, key_cache, value_cache, start_idx, end_idx, is_causal],
                                                  enabled_precisions={torch.float32},
                                                  disable_tf32=True,
                                                  debug=args.debug,
                                                  # offload_module_to_cpu=True,
                                                  use_python_runtime=True)

    # Test Prefill
    trt_output, key_cache, value_cache = trt_model(hidden_states, position_embeddings_global, position_embeddings_local, key_cache, value_cache, start_idx, end_idx, is_causal)
    print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]")

    # Test Generate: decode one token at a time with the static cache and compare
    # against PyTorch recomputing the full sequence without a cache.
    for start_idx in range(2048, 2176):
        end_idx = start_idx + 1
        hidden_states_curr = torch.randn((1, 1, 1152), dtype=DTYPE).cuda()
        position_embeddings_curr = (torch.randn((1, 1, 256), dtype=DTYPE).cuda(), torch.randn((1, 1, 256), dtype=DTYPE).cuda())
        # Concatenate the current hidden_states and rotary embeddings with the previous ones
        hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
        position_embeddings_global_full = (torch.cat((position_embeddings_global[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings_global[1], position_embeddings_curr[1]), dim=1))
        position_embeddings_local_full = (torch.cat((position_embeddings_local[0], position_embeddings_curr[0]), dim=1), torch.cat((position_embeddings_local[1], position_embeddings_curr[1]), dim=1))

        is_causal = False
        out_no_cache = model(hidden_states_full, position_embeddings_global_full, position_embeddings_local_full)

        # The random embeddings are placeholders, so the same current-step tuple is
        # reused for the global and local rotary inputs here.
        out_trt, key_cache, value_cache = trt_model(hidden_states_curr, position_embeddings_curr, position_embeddings_curr, key_cache, value_cache, start_idx, end_idx, is_causal)
        out_pyt = out_no_cache[0][:, -1:, :]
        print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
        hidden_states = hidden_states_full
        position_embeddings_global = position_embeddings_global_full
        position_embeddings_local = position_embeddings_local_full


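# Example invocation (script name is illustrative; use this file's actual name):
#   python test_gemma3.py --precision BF16 --debug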
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(
        description="Run test cases for gemma3 attention and decoder"
    )
    arg_parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug (default: False)"
    )
    arg_parser.add_argument("--precision", type=str, default="FP16", help="Precision to use in the model. Options: FP16, BF16, FP32")
    args = arg_parser.parse_args()
    with torch.inference_mode():
        # test_gemma3_attention(args)
        # test_gemma3_attention_with_static_cache(args)
        test_gemma3_decoder(args)
        # test_gemma3_decoder_with_static_cache(args)