 import torch
 import torch.nn as nn

-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d, MultiQueryAttentionV2

 import importlib
 import os
@@ -121,6 +121,23 @@ def test_get_act_fn_none():
     assert get_act_fn('') is None


+@pytest.mark.parametrize("dim", [128])
+@pytest.mark.parametrize("dim_out", [128, 256])
+@pytest.mark.parametrize("use_m", [True, False])
+def test_mqa_v2(dim, dim_out, use_m):
+    mqa = MultiQueryAttentionV2(dim, dim_out)
+
+    x = torch.randn(1, dim, 32, 48)
+    if use_m:
+        m = torch.randn(1, dim, 16, 24)
+    else:
+        m = None
+
+    y = mqa(x, m=m)
+
+    assert (y.shape) == (1, dim_out, 32, 48)
+
+
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("expand_first", [True, False])
 @pytest.mark.parametrize("head_first", [True, False])
@@ -141,6 +158,3 @@ def test_attn2d(bias, expand_first, head_first, attn_mask):
     o2 = attn(x, mask)

     assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
-
-
-
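For context outside the pytest harness, here is a minimal usage sketch of the layer the new test covers. It assumes only what the test above exercises: MultiQueryAttentionV2(dim, dim_out) constructed from timm.layers, a forward call taking an NCHW tensor plus an optional memory tensor via the m keyword, and an output that keeps the input's spatial size with dim_out channels.

import torch
from timm.layers import MultiQueryAttentionV2

# Shapes mirror test_mqa_v2: NCHW query features and an optional
# lower-resolution memory tensor passed as m (m=None is also accepted).
mqa = MultiQueryAttentionV2(128, 256)   # dim=128 in, dim_out=256 out
x = torch.randn(1, 128, 32, 48)         # query feature map
m = torch.randn(1, 128, 16, 24)         # memory feature map
y = mqa(x, m=m)
print(y.shape)                          # expected per the test assertion: torch.Size([1, 256, 32, 48])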