Commit 7fe96e7

More MobileNet-v4 fixes
* missed final norm after post pooling 1x1 PW head conv
* improve repr of model by flipping a few modules to None when not used, nn.Sequential for MultiQueryAttention query/key/value/output
* allow layer scaling to be enabled/disabled at model variant level, conv variants don't use it
1 parent 28d76a9 commit 7fe96e7

4 files changed: +102 -102 lines changed

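The first bullet in the commit message (the missed final norm after the post-pooling 1x1 PW head conv) concerns the model definition file, whose diff is not expanded in this view. Below is a rough, hypothetical sketch of the head ordering that bullet describes: global pooling, then a 1x1 pointwise conv, then the norm and activation ahead of the classifier. Layer names and structure are stand-ins, not timm's actual head.

```python
import torch
import torch.nn as nn

# Hypothetical head sketch only: illustrates "norm after the post-pooling
# 1x1 PW head conv" from the commit message. Layer names and structure are
# stand-ins, not the actual timm MobileNetV3/V4 head.
class PooledConvHead(nn.Module):
    def __init__(self, in_chs: int, head_chs: int, num_classes: int):
        super().__init__()
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_head = nn.Conv2d(in_chs, head_chs, kernel_size=1)  # post-pooling 1x1 PW conv
        self.norm_head = nn.BatchNorm2d(head_chs)  # the norm the commit message says was missed
        self.act = nn.ReLU(inplace=True)
        self.classifier = nn.Linear(head_chs, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.global_pool(x)
        x = self.act(self.norm_head(self.conv_head(x)))
        return self.classifier(x.flatten(1))


if __name__ == '__main__':
    head = PooledConvHead(in_chs=960, head_chs=1280, num_classes=1000)
    print(head(torch.randn(2, 960, 7, 7)).shape)  # torch.Size([2, 1000])
```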

timm/layers/attention2d.py

Lines changed: 35 additions & 50 deletions

```diff
@@ -107,6 +107,7 @@ def __init__(
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.BatchNorm2d,
+            use_bias: bool = False,
    ):
        """Initializer.

@@ -130,81 +131,74 @@ def __init__(
        self.fused_attn = use_fused_attn()
        self.drop = attn_drop

+        self.query = nn.Sequential()
        if self.has_query_strides:
            # FIXME dilation
-            self.query_down_pool = create_pool2d(
-                'avg',
-                kernel_size=self.query_strides,
-                padding=padding,
-            )
-            self.query_down_norm = norm_layer(dim)
-        else:
-            self.query_down_pool = nn.Identity()
-            self.query_down_norm = nn.Identity()
-
-        self.query_proj = create_conv2d(
+            self.query.add_module('down_pool', create_pool2d(
+                'avg',
+                kernel_size=self.query_strides,
+                padding=padding,
+            ))
+            self.query.add_module('norm', norm_layer(dim))
+        self.query.add_module('proj', create_conv2d(
            dim,
            self.num_heads * self.key_dim,
            kernel_size=1,
-        )
+            bias=use_bias,
+        ))

+        self.key = nn.Sequential()
        if kv_stride > 1:
-            self.key_down_conv = create_conv2d(
+            self.key.add_module('down_conv', create_conv2d(
                dim,
                dim,
                kernel_size=dw_kernel_size,
                stride=kv_stride,
                dilation=dilation,
                padding=padding,
                depthwise=True,
-            )
-            self.key_down_norm = norm_layer(dim)
-        else:
-            self.key_down_conv = nn.Identity()
-            self.key_down_norm = nn.Identity()
-
-        self.key_proj = create_conv2d(
+            ))
+            self.key.add_module('norm', norm_layer(dim))
+        self.key.add_module('proj', create_conv2d(
            dim,
            self.key_dim,
            kernel_size=1,
            padding=padding,
-        )
+            bias=use_bias,
+        ))

+        self.value = nn.Sequential()
        if kv_stride > 1:
-            self.value_down_conv = create_conv2d(
+            self.value.add_module('down_conv', create_conv2d(
                dim,
                dim,
                kernel_size=dw_kernel_size,
                stride=kv_stride,
                dilation=dilation,
                padding=padding,
                depthwise=True,
-            )
-            self.value_down_norm = norm_layer(dim)
-        else:
-            self.value_down_conv = nn.Identity()
-            self.value_down_norm = nn.Identity()
-
-        self.value_proj = create_conv2d(
+            ))
+            self.value.add_module('norm', norm_layer(dim))
+        self.value.add_module('proj', create_conv2d(
            dim,
            self.value_dim,
            kernel_size=1,
-        )
+            bias=use_bias,
+        ))

        self.attn_drop = nn.Dropout(attn_drop)

+        self.output = nn.Sequential()
        if self.has_query_strides:
-            self.upsampling = nn.Upsample(self.query_strides, mode='bilinear', align_corners=False)
-        else:
-            self.upsampling = nn.Identity()
-
-        self.out_proj = create_conv2d(
+            self.output.add_module('upsample', nn.Upsample(self.query_strides, mode='bilinear', align_corners=False))
+        self.output.add_module('proj', create_conv2d(
            self.value_dim * self.num_heads,
            dim_out,
            kernel_size=1,
-        )
+            bias=use_bias,
+        ))
+        self.output.add_module('drop', nn.Dropout(proj_drop))

-        self.proj_drop = nn.Dropout(proj_drop)
        self.einsum = False

    def _reshape_input(self, t: torch.Tensor):
@@ -237,21 +231,15 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Run layer computation."""
        B, C, H, W = s = x.shape

-        q = self.query_down_pool(x)
-        q = self.query_down_norm(q)
-        q = self.query_proj(q)
+        q = self.query(x)
        # desired q shape: [b, h, k, n x n] - [b, l, h, k]
        q = self._reshape_projected_query(q, self.num_heads, self.key_dim)

-        k = self.key_down_conv(x)
-        k = self.key_down_norm(k)
-        k = self.key_proj(k)
+        k = self.key(x)
        # output shape of k: [b, k, p], p = m x m
        k = self._reshape_input(k)

-        v = self.value_down_conv(x)
-        v = self.value_down_norm(v)
-        v = self.value_proj(v)
+        v = self.value(x)
        # output shape of v: [ b, p, k], p = m x m
        v = self._reshape_input(v)

@@ -285,10 +273,7 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):

        # reshape o into [b, hk, n, n,]
        o = self._reshape_output(o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1])
-        o = self.upsampling(o)
-
-        x = self.out_proj(o)
-        x = self.proj_drop(x)
+        x = self.output(o)
        return x
```
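
The refactor above folds each of the query/key/value/output pipelines into an nn.Sequential: the optional down-pool/down-conv and norm stages are only added when actually used, the projection gains a bias option, and the forward pass collapses to one call per branch. Below is a minimal sketch of that pattern, using plain torch.nn layers as stand-ins for timm's create_pool2d/create_conv2d helpers; the TinyQueryBranch class is illustrative only, not part of timm.

```python
import torch
import torch.nn as nn

# Minimal sketch of the nn.Sequential grouping used above, with plain torch.nn
# layers standing in for timm's helpers. Stages that are not needed are simply
# never added, so no nn.Identity placeholders clutter the repr, and forward()
# collapses to a single call.
class TinyQueryBranch(nn.Module):
    def __init__(self, dim: int, num_heads: int, key_dim: int, downsample: bool = False):
        super().__init__()
        self.query = nn.Sequential()
        if downsample:
            self.query.add_module('down_pool', nn.AvgPool2d(kernel_size=2))
            self.query.add_module('norm', nn.BatchNorm2d(dim))
        self.query.add_module('proj', nn.Conv2d(dim, num_heads * key_dim, kernel_size=1, bias=False))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # replaces the old q = pool(x); q = norm(q); q = proj(q) sequence
        return self.query(x)


if __name__ == '__main__':
    branch = TinyQueryBranch(dim=64, num_heads=4, key_dim=16)
    print(branch)  # repr shows only (proj), no Identity placeholders
    print(branch(torch.randn(1, 64, 8, 8)).shape)  # torch.Size([1, 64, 8, 8])
```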

timm/models/_efficientnet_blocks.py

Lines changed: 28 additions & 37 deletions

```diff
@@ -174,13 +174,12 @@ def feature_info(self, location):

    def forward(self, x):
        shortcut = x
-        #print('ii', x.shape)
+        #print('ii', x.shape)  # FIXME debug s2d
        if self.conv_s2d is not None:
            x = self.conv_s2d(x)
            x = self.bn_s2d(x)
-            #print('id', x.shape)
+            #print('id', x.shape)  # FIXME debug s2d
        x = self.conv_dw(x)
-        #print('od', x.shape)
        x = self.bn1(x)
        x = self.se(x)
        x = self.conv_pw(x)
@@ -296,7 +295,8 @@ def forward(self, x):
class UniversalInvertedResidual(nn.Module):
    """ Universal Inverted Residual Block

-    For MobileNetV4 - https://arxiv.org/abs/
+    For MobileNetV4 - https://arxiv.org/abs/, referenced from
+    https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L778
    """

    def __init__(
@@ -338,8 +338,9 @@ def __init__(
            )
            self.norm_dw_start = dw_norm_act_layer(in_chs, apply_act=False)
        else:
-            self.conv_dw_start = nn.Identity()
-            self.norm_dw_start = nn.Identity()
+            # start is None when not used for cleaner repr
+            self.conv_dw_start = None
+            self.norm_dw_start = None

        # Point-wise expansion
        mid_chs = make_divisible(in_chs * exp_ratio)
@@ -359,6 +360,7 @@ def __init__(
            )
            self.norm_dw_mid = dw_norm_act_layer(mid_chs, inplace=True)
        else:
+            # keeping mid as identity so it can be hooked more easily for features
            self.conv_dw_mid = nn.Identity()
            self.norm_dw_mid = nn.Identity()

@@ -379,7 +381,7 @@ def __init__(
            )
            self.norm_dw_end = dw_norm_act_layer(out_chs, apply_act=False)
        else:
-            # dw_end rarely used so keeping it out of repr by not using None instead of nn.Identitty()
+            # end is None when not in use for cleaner repr
            self.conv_dw_end = None
            self.norm_dw_end = None

@@ -397,8 +399,9 @@ def feature_info(self, location):

    def forward(self, x):
        shortcut = x
-        x = self.conv_dw_start(x)
-        x = self.norm_dw_start(x)
+        if self.conv_dw_start is not None:
+            x = self.conv_dw_start(x)
+            x = self.norm_dw_start(x)
        x = self.conv_pw(x)
        x = self.norm_pw(x)
        x = self.conv_dw_mid(x)
@@ -418,7 +421,8 @@ def forward(self, x):
class MobileAttention(nn.Module):
    """ Mobile Attention Block

-    For MobileNetV4 - https://arxiv.org/abs/
+    For MobileNetV4 - https://arxiv.org/abs/, referenced from
+    https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L1504
    """
    def __init__(
            self,
@@ -476,34 +480,21 @@ def __init__(
            num_heads = in_chs // key_dim

        if use_multi_query:
-            #if self.has_query_stride or self.kv_stride > 1:
-            self.attn = (
-                MultiQueryAttention2d(
-                    in_chs,
-                    dim_out=out_chs,
-                    num_heads=num_heads,
-                    key_dim=key_dim,
-                    value_dim=value_dim,
-                    query_strides=query_strides,
-                    kv_stride=kv_stride,
-                    dilation=dilation,
-                    padding=pad_type,
-                    dw_kernel_size=dw_kernel_size,
-                    attn_drop=attn_drop,
-                    proj_drop=proj_drop,
-                    #bias=use_bias, # why not here if used w/ mhsa?
-                )
+            self.attn = MultiQueryAttention2d(
+                in_chs,
+                dim_out=out_chs,
+                num_heads=num_heads,
+                key_dim=key_dim,
+                value_dim=value_dim,
+                query_strides=query_strides,
+                kv_stride=kv_stride,
+                dilation=dilation,
+                padding=pad_type,
+                dw_kernel_size=dw_kernel_size,
+                attn_drop=attn_drop,
+                proj_drop=proj_drop,
+                #bias=use_bias, # why not here if used w/ mhsa?
            )
-            # else:
-            #     self.attn = MultiQueryAttentionV2(
-            #         in_chs,
-            #         dim_out=out_chs,
-            #         num_heads=num_heads,
-            #         key_dim=key_dim,
-            #         value_dim=value_dim,
-            #         attn_drop=attn_drop,
-            #         proj_drop=proj_drop,
-            #     )
        else:
            self.attn = Attention2d(
                in_chs,
```
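
Two repr conventions show up in this file: depthwise start/end layers that are rarely used become None (and are guarded in forward) so they disappear from the printed module tree, while the mid depthwise stays nn.Identity so it remains easy to hook for feature extraction. Below is a small sketch of the None-plus-guard pattern with a hypothetical block, not timm's UniversalInvertedResidual.

```python
import torch
import torch.nn as nn

# Sketch of the None-vs-Identity trade-off used above (hypothetical block):
# optional layers that are usually absent are stored as None and guarded in
# forward(), so they drop out of the repr entirely.
class OptionalDWBlock(nn.Module):
    def __init__(self, chs: int, use_dw_start: bool = False):
        super().__init__()
        if use_dw_start:
            self.conv_dw_start = nn.Conv2d(chs, chs, 3, padding=1, groups=chs)
            self.norm_dw_start = nn.BatchNorm2d(chs)
        else:
            # None keeps the printed module tree clean when the branch is unused
            self.conv_dw_start = None
            self.norm_dw_start = None
        self.conv_pw = nn.Conv2d(chs, chs, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.conv_dw_start is not None:
            x = self.conv_dw_start(x)
            x = self.norm_dw_start(x)
        return self.conv_pw(x)


if __name__ == '__main__':
    print(OptionalDWBlock(32, use_dw_start=False))  # no dw_start entries in repr
    print(OptionalDWBlock(32, use_dw_start=True)(torch.randn(1, 32, 8, 8)).shape)
```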

timm/models/_efficientnet_builder.py

Lines changed: 15 additions & 12 deletions

```diff
@@ -5,6 +5,7 @@

Hacked together by / Copyright 2019, Ross Wightman
"""
+from typing import Callable, Optional

import logging
import math
@@ -321,15 +322,16 @@ class EfficientNetBuilder:
    """
    def __init__(
            self,
-            output_stride=32,
-            pad_type='',
-            round_chs_fn=round_channels,
-            se_from_exp=False,
-            act_layer=None,
-            norm_layer=None,
-            se_layer=None,
-            drop_path_rate=0.,
-            feature_location='',
+            output_stride: int = 32,
+            pad_type: str = '',
+            round_chs_fn: Callable = round_channels,
+            se_from_exp: bool = False,
+            act_layer: Optional[Callable] = None,
+            norm_layer: Optional[Callable] = None,
+            se_layer: Optional[Callable] = None,
+            drop_path_rate: float = 0.,
+            layer_scale_init_value: Optional[float] = None,
+            feature_location: str = '',
    ):
        self.output_stride = output_stride
        self.pad_type = pad_type
@@ -344,6 +346,7 @@ def __init__(
        except TypeError:
            self.se_has_ratio = False
        self.drop_path_rate = drop_path_rate
+        self.layer_scale_init_value = layer_scale_init_value
        if feature_location == 'depthwise':
            # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
            _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
@@ -402,13 +405,13 @@ def _make_block(self, ba, block_idx, block_count):
            block = ConvBnAct(**ba)
        elif bt == 'uir':
            _log_info_if(' UniversalInvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
-            block = UniversalInvertedResidual(**ba)
+            block = UniversalInvertedResidual(**ba, layer_scale_init_value=self.layer_scale_init_value)
        elif bt == 'mqa':
            _log_info_if(' MobileMultiQueryAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
-            block = MobileAttention(**ba, use_multi_query=True)
+            block = MobileAttention(**ba, use_multi_query=True, layer_scale_init_value=self.layer_scale_init_value)
        elif bt == 'mha':
            _log_info_if(' MobileMultiHeadAttention {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
-            block = MobileAttention(**ba)
+            block = MobileAttention(**ba, layer_scale_init_value=self.layer_scale_init_value)
        else:
            assert False, 'Unknown block type (%s) while building model.' % bt
```
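
The builder now accepts layer_scale_init_value and forwards it to the 'uir', 'mqa', and 'mha' blocks, with None meaning layer scaling is disabled (per the commit message, the conv-only variants leave it off). Below is a generic illustration of that None-to-disable convention; MaybeLayerScale2d is a stand-in, not timm's layer-scale implementation.

```python
import torch
import torch.nn as nn
from typing import Optional

# Generic illustration of the layer_scale_init_value convention: None disables
# layer scaling, a float (e.g. 1e-5) enables a learnable per-channel scale.
# This is a sketch, not timm's actual layer-scale module.
class MaybeLayerScale2d(nn.Module):
    def __init__(self, dim: int, init_value: Optional[float] = None):
        super().__init__()
        if init_value is not None:
            # per-channel scale, broadcast over the spatial dims
            self.gamma = nn.Parameter(init_value * torch.ones(dim, 1, 1))
        else:
            self.gamma = None  # disabled: forward is a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x if self.gamma is None else x * self.gamma


if __name__ == '__main__':
    x = torch.randn(2, 16, 8, 8)
    disabled = MaybeLayerScale2d(16, init_value=None)  # e.g. conv-only variants
    enabled = MaybeLayerScale2d(16, init_value=1e-5)   # e.g. hybrid/attention variants
    print(torch.equal(disabled(x), x))          # True, pass-through when disabled
    print(enabled(x).abs().max().item())        # tiny values after scaling
```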
