Commit 646cd34

update optimize=True
1 parent 7e8afff commit 646cd34

1 file changed (+199 -0 lines)
@@ -0,0 +1,199 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import os
import unittest

import onnx
import torch
from parameterized import parameterized
from parity_utilities import find_transformers_source

if find_transformers_source():
    from fusion_options import FusionOptions
    from onnx_model import OnnxModel
    from optimizer import optimize_model
else:
    from onnxruntime.transformers.fusion_options import FusionOptions
    from onnxruntime.transformers.onnx_model import OnnxModel
    from onnxruntime.transformers.optimizer import optimize_model


# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/siglip/modeling_siglip.py#L363
class SiglipAttention(torch.nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self):
        super().__init__()
        self.embed_dim = 20
        self.num_heads = 2
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5

        self.k_proj = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = torch.nn.Linear(self.embed_dim, self.embed_dim)

        self.k_proj.weight.data.fill_(1)
        self.v_proj.weight.data.fill_(1)
        self.q_proj.weight.data.fill_(1)
        self.out_proj.weight.data.fill_(1)
        self.k_proj.bias.data.fill_(1)
        self.v_proj.bias.data.fill_(1)
        self.q_proj.bias.data.fill_(1)
        self.out_proj.bias.data.fill_(1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Input shape: Batch x Time x Channel"""

        batch_size, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        k_v_seq_len = key_states.shape[-2]
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights


class Gemma3VSIGLIPAttentionAndLayerNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = SiglipAttention()
        self.ln = torch.nn.LayerNorm(20, eps=1e-05)

    def forward(self, x):
        # SkipLayerNorm ------+
        #       |             |
        #   Attention         |
        #       |             |
        #     MatMul          |
        #       |             |
        # SkipLayerNorm ------+

        # SkipLayerNorm
        x = x + x
        x = self.ln(x)
        residual = x

        # Attention + MatMul
        x, _ = self.attn(x)

        # SkipLayerNorm
        x = residual + x
        x = self.ln(x)
        return x


class TestFusion(unittest.TestCase):
    def verify_fusion(self, optimized_model, expected_model_filename):
        optimized_model.topological_sort(is_deterministic=True)

        expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename)
        expected_model = OnnxModel(onnx.load(expected_model_path))
        expected_model.topological_sort(is_deterministic=True)

        nodes = optimized_model.model.graph.node
        self.assertEqual(len(nodes), len(expected_model.model.graph.node))

        for i in range(len(nodes)):
            self.assertEqual(nodes[i], expected_model.model.graph.node[i])

        for expected_initializer in expected_model.model.graph.initializer:
            self.assertTrue(
                OnnxModel.has_same_value(
                    optimized_model.get_initializer(expected_initializer.name),
                    expected_initializer,
                )
            )

    def export(self, model, inputs) -> onnx.ModelProto:
        with torch.no_grad():
            onnx_program = torch.onnx.export(
                model,
                args=inputs,
                f=os.path.join(os.path.dirname(__file__), "export.onnx"),
                dynamo=True,
            )
        return onnx_program.model_proto

    def tearDown(self):
        path = os.path.join(os.path.dirname(__file__), "export.onnx")
        if os.path.exists(path):
            os.remove(path)
            os.remove(path + ".data")

    @parameterized.expand(
        [
            (torch.float32, "gemma3-vision-attention_fp32.onnx"),
            # (torch.float16, "gemma3-vision-attention_fp16.onnx"),
        ]
    )
    def test_gemma3_vision_attention(self, dtype, model_name):
        model = Gemma3VSIGLIPAttentionAndLayerNorm().eval().to(dtype)
        inputs = (torch.randn(1, 2, 20, dtype=dtype),)
        original_model = self.export(model, inputs)

        options = FusionOptions("clip")
        optimized_model = optimize_model(
            original_model,
            model_type="clip",
            num_heads=2,
            hidden_size=20,
            optimization_options=options,
            opt_level=0,
        )
        # onnx.save(optimized_model.model, model_name)
        self.verify_fusion(optimized_model, model_name)


if __name__ == "__main__":
    unittest.main()
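
A quick way to sanity-check what the optimizer produced, beyond the node-by-node comparison in verify_fusion, is to histogram the operator types of the exported and optimized graphs. The sketch below is not part of the commit; it reuses the original_model ModelProto and the optimized_model OnnxModel from the test above, and with model_type "clip" one would typically expect the pattern in Gemma3VSIGLIPAttentionAndLayerNorm to collapse into fused nodes such as Attention/MultiHeadAttention and SkipLayerNormalization, though the exact op types depend on the fusion options.

from collections import Counter

import onnx


def op_type_histogram(model_proto: onnx.ModelProto) -> Counter:
    # Count how many nodes of each operator type the graph contains.
    return Counter(node.op_type for node in model_proto.graph.node)


# Hypothetical usage inside test_gemma3_vision_attention:
#   print(op_type_histogram(original_model))          # exported graph
#   print(op_type_histogram(optimized_model.model))   # fused graph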

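Graph equality alone does not show that the fused model is numerically equivalent to the export. The sketch below is an assumed local check, not part of the test: it runs two saved ONNX files through onnxruntime on the same random input and compares the outputs. The file paths are placeholders and the tolerances are arbitrary.

import numpy as np
import onnxruntime as ort


def check_parity(original_path: str, optimized_path: str, rtol: float = 1e-4, atol: float = 1e-4) -> None:
    # Same input shape as the test: batch=1, sequence=2, hidden=20.
    x = np.random.randn(1, 2, 20).astype(np.float32)

    sess_orig = ort.InferenceSession(original_path, providers=["CPUExecutionProvider"])
    sess_opt = ort.InferenceSession(optimized_path, providers=["CPUExecutionProvider"])

    out_orig = sess_orig.run(None, {sess_orig.get_inputs()[0].name: x})[0]
    out_opt = sess_opt.run(None, {sess_opt.get_inputs()[0].name: x})[0]

    # Element-wise comparison; raises AssertionError if the fused graph diverges.
    np.testing.assert_allclose(out_orig, out_opt, rtol=rtol, atol=atol)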