|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT License. See License.txt in the project root for |
| 3 | +# license information. |
| 4 | +# -------------------------------------------------------------------------- |
| 5 | + |
| 6 | +import os |
| 7 | +import unittest |
| 8 | + |
| 9 | +import onnx |
| 10 | +import torch |
| 11 | +from parameterized import parameterized |
| 12 | +from parity_utilities import find_transformers_source |
| 13 | + |
| 14 | +if find_transformers_source(): |
| 15 | + from fusion_options import FusionOptions |
| 16 | + from onnx_model import OnnxModel |
| 17 | + from optimizer import optimize_model |
| 18 | +else: |
| 19 | + from onnxruntime.transformers.fusion_options import FusionOptions |
| 20 | + from onnxruntime.transformers.onnx_model import OnnxModel |
| 21 | + from onnxruntime.transformers.optimizer import optimize_model |
| 22 | + |
| 23 | + |
| 24 | +# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/siglip/modeling_siglip.py#L363 |
| 25 | +class SiglipAttention(torch.nn.Module): |
| 26 | + """Multi-headed attention from 'Attention Is All You Need' paper""" |
| 27 | + |
| 28 | + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ |
| 29 | + def __init__(self): |
| 30 | + super().__init__() |
| 31 | + self.embed_dim = 20 |
| 32 | + self.num_heads = 2 |
| 33 | + self.head_dim = self.embed_dim // self.num_heads |
| 34 | + if self.head_dim * self.num_heads != self.embed_dim: |
| 35 | + raise ValueError( |
| 36 | + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" |
| 37 | + f" {self.num_heads})." |
| 38 | + ) |
| 39 | + self.scale = self.head_dim**-0.5 |
| 40 | + |
| 41 | + self.k_proj = torch.nn.Linear(self.embed_dim, self.embed_dim) |
| 42 | + self.v_proj = torch.nn.Linear(self.embed_dim, self.embed_dim) |
| 43 | + self.q_proj = torch.nn.Linear(self.embed_dim, self.embed_dim) |
| 44 | + self.out_proj = torch.nn.Linear(self.embed_dim, self.embed_dim) |
| 45 | + |
| 46 | + self.k_proj.weight.data.fill_(1) |
| 47 | + self.v_proj.weight.data.fill_(1) |
| 48 | + self.q_proj.weight.data.fill_(1) |
| 49 | + self.out_proj.weight.data.fill_(1) |
| 50 | + self.k_proj.bias.data.fill_(1) |
| 51 | + self.v_proj.bias.data.fill_(1) |
| 52 | + self.q_proj.bias.data.fill_(1) |
| 53 | + self.out_proj.bias.data.fill_(1) |
| 54 | + |
| 55 | + def forward( |
| 56 | + self, |
| 57 | + hidden_states: torch.Tensor, |
| 58 | + attention_mask: torch.Tensor | None = None, |
| 59 | + output_attentions: bool | None = False, |
| 60 | + ) -> tuple[torch.Tensor, torch.Tensor | None]: |
| 61 | + """Input shape: Batch x Time x Channel""" |
| 62 | + |
| 63 | + batch_size, q_len, _ = hidden_states.size() |
| 64 | + |
| 65 | + query_states = self.q_proj(hidden_states) |
| 66 | + key_states = self.k_proj(hidden_states) |
| 67 | + value_states = self.v_proj(hidden_states) |
| 68 | + |
| 69 | + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 70 | + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 71 | + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 72 | + |
| 73 | + k_v_seq_len = key_states.shape[-2] |
| 74 | + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale |
| 75 | + |
| 76 | + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): |
| 77 | + raise ValueError( |
| 78 | + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" |
| 79 | + f" {attn_weights.size()}" |
| 80 | + ) |
| 81 | + |
| 82 | + if attention_mask is not None: |
| 83 | + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): |
| 84 | + raise ValueError( |
| 85 | + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" |
| 86 | + ) |
| 87 | + attn_weights = attn_weights + attention_mask |
| 88 | + |
| 89 | + # upcast attention to fp32 |
| 90 | + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) |
| 91 | + attn_output = torch.matmul(attn_weights, value_states) |
| 92 | + |
| 93 | + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): |
| 94 | + raise ValueError( |
| 95 | + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" |
| 96 | + f" {attn_output.size()}" |
| 97 | + ) |
| 98 | + |
| 99 | + attn_output = attn_output.transpose(1, 2).contiguous() |
| 100 | + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) |
| 101 | + |
| 102 | + attn_output = self.out_proj(attn_output) |
| 103 | + |
| 104 | + return attn_output, attn_weights |
| 105 | + |
| 106 | + |
| 107 | +class Gemma3VSIGLIPAttentionAndLayerNorm(torch.nn.Module): |
| 108 | + def __init__(self): |
| 109 | + super().__init__() |
| 110 | + self.attn = SiglipAttention() |
| 111 | + self.ln = torch.nn.LayerNorm(20, eps=1e-05) |
| 112 | + |
| 113 | + def forward(self, x): |
| 114 | + # SkipLayerNorm ------+ |
| 115 | + # | | |
| 116 | + # Attention | |
| 117 | + # | | |
| 118 | + # MatMul | |
| 119 | + # | | |
| 120 | + # SkipLayerNorm ------+ |
| 121 | + |
| 122 | + # SkipLayerNorm |
| 123 | + x = x + x |
| 124 | + x = self.ln(x) |
| 125 | + residual = x |
| 126 | + |
| 127 | + # Attention + MatMul |
| 128 | + x, _ = self.attn(x) |
| 129 | + |
| 130 | + # SkipLayerNorm |
| 131 | + x = residual + x |
| 132 | + x = self.ln(x) |
| 133 | + return x |
| 134 | + |
| 135 | + |
| 136 | +class TestFusion(unittest.TestCase): |
| 137 | + def verify_fusion(self, optimized_model, expected_model_filename): |
| 138 | + optimized_model.topological_sort(is_deterministic=True) |
| 139 | + |
| 140 | + expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename) |
| 141 | + expected_model = OnnxModel(onnx.load(expected_model_path)) |
| 142 | + expected_model.topological_sort(is_deterministic=True) |
| 143 | + |
| 144 | + nodes = optimized_model.model.graph.node |
| 145 | + self.assertEqual(len(nodes), len(expected_model.model.graph.node)) |
| 146 | + |
| 147 | + for i in range(len(nodes)): |
| 148 | + self.assertEqual(nodes[i], expected_model.model.graph.node[i]) |
| 149 | + |
| 150 | + for expected_initializer in expected_model.model.graph.initializer: |
| 151 | + self.assertTrue( |
| 152 | + OnnxModel.has_same_value( |
| 153 | + optimized_model.get_initializer(expected_initializer.name), |
| 154 | + expected_initializer, |
| 155 | + ) |
| 156 | + ) |
| 157 | + |
| 158 | + def export(self, model, inputs) -> onnx.ModelProto: |
| 159 | + with torch.no_grad(): |
| 160 | + onnx_program = torch.onnx.export( |
| 161 | + model, |
| 162 | + args=inputs, |
| 163 | + f=os.path.join(os.path.dirname(__file__), "export.onnx"), |
| 164 | + dynamo=True, |
| 165 | + ) |
| 166 | + return onnx_program.model_proto |
| 167 | + |
| 168 | + def tearDown(self): |
| 169 | + path = os.path.join(os.path.dirname(__file__), "export.onnx") |
| 170 | + if os.path.exists(path): |
| 171 | + os.remove(path) |
| 172 | + os.remove(path + ".data") |
| 173 | + |
| 174 | + @parameterized.expand( |
| 175 | + [ |
| 176 | + (torch.float32, "gemma3-vision-attention_fp32.onnx"), |
| 177 | + # (torch.float16, "gemma3-vision-attention_fp16.onnx"), |
| 178 | + ] |
| 179 | + ) |
| 180 | + def test_gemma3_vision_attention(self, dtype, model_name): |
| 181 | + model = Gemma3VSIGLIPAttentionAndLayerNorm().eval().to(dtype) |
| 182 | + inputs = (torch.randn(1, 2, 20, dtype=dtype),) |
| 183 | + original_model = self.export(model, inputs) |
| 184 | + |
| 185 | + options = FusionOptions("clip") |
| 186 | + optimized_model = optimize_model( |
| 187 | + original_model, |
| 188 | + model_type="clip", |
| 189 | + num_heads=2, |
| 190 | + hidden_size=20, |
| 191 | + optimization_options=options, |
| 192 | + opt_level=0, |
| 193 | + ) |
| 194 | + # onnx.save(optimized_model.model, model_name) |
| 195 | + self.verify_fusion(optimized_model, model_name) |
| 196 | + |
| 197 | + |
| 198 | +if __name__ == "__main__": |
| 199 | + unittest.main() |
0 commit comments