
Commit 5fa6efa

Add anti-aliasing support to mobilenetv3 and efficientnet family models. Update MobileNetV4 model defs, resolutions. Fix #599
* create_aa helper function centralized for all timm uses (resnet, conv_bn_act helpers)
* allow BlurPool w/o pre-defined channels (expand at forward time)
* mobilenetv4 UIB block using ConvNormAct layers for improved clarity, esp with AA added
* improve more mobilenetv3 and efficientnet related type annotations
1 parent 4ff7c25 commit 5fa6efa

8 files changed: +475 / -325 lines
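
As a quick orientation, anti-aliased downsampling replaces a strided conv with a stride-1 conv followed by a blur (or average-pool) downsample. The sketch below is illustrative only: it assumes the mobilenetv3/efficientnet entrypoints forward an `aa_layer` kwarg through `timm.create_model` to their constructors, which the hunks shown on this page do not themselves confirm.

import torch
import timm

# Illustrative only: assumes the mobilenetv3/efficientnet constructors accept an
# `aa_layer` argument after this commit; the create_aa helper added below resolves
# strings such as 'avg', 'blur', or 'blurpc' to the corresponding AA layer.
model = timm.create_model('mobilenetv3_large_100', pretrained=False, aa_layer='blur')
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])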

timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
-from .blur_pool import BlurPool2d
+from .blur_pool import BlurPool2d, create_aa
 from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \

timm/layers/blur_pool.py

Lines changed: 53 additions & 4 deletions
@@ -5,12 +5,16 @@
 
 Hacked together by Chris Ha and Ross Wightman
 """
+from functools import partial
+from typing import Optional, Type
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
+
 from .padding import get_padding
+from .typing import LayerType
 
 
 class BlurPool2d(nn.Module):
@@ -26,17 +30,62 @@ class BlurPool2d(nn.Module):
     Returns:
         torch.Tensor: the transformed tensor.
     """
-    def __init__(self, channels, filt_size=3, stride=2) -> None:
+    def __init__(
+            self,
+            channels: Optional[int] = None,
+            filt_size: int = 3,
+            stride: int = 2,
+            pad_mode: str = 'reflect',
+    ) -> None:
         super(BlurPool2d, self).__init__()
         assert filt_size > 1
         self.channels = channels
         self.filt_size = filt_size
         self.stride = stride
+        self.pad_mode = pad_mode
         self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
+
         coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
-        blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
+        blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :]
+        if channels is not None:
+            blur_filter = blur_filter.repeat(self.channels, 1, 1, 1)
         self.register_buffer('filt', blur_filter, persistent=False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = F.pad(x, self.padding, 'reflect')
-        return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)
+        x = F.pad(x, self.padding, mode=self.pad_mode)
+        if self.channels is None:
+            channels = x.shape[1]
+            weight = self.filt.expand(channels, 1, self.filt_size, self.filt_size)
+        else:
+            channels = self.channels
+            weight = self.filt
+        return F.conv2d(x, weight, stride=self.stride, groups=channels)
+
+
+def create_aa(
+        aa_layer: LayerType,
+        channels: Optional[int] = None,
+        stride: int = 2,
+        enable: bool = True,
+        noop: Optional[Type[nn.Module]] = nn.Identity
+) -> nn.Module:
+    """ Anti-aliasing """
+    if not aa_layer or not enable:
+        return noop() if noop is not None else None
+
+    if isinstance(aa_layer, str):
+        aa_layer = aa_layer.lower().replace('_', '').replace('-', '')
+        if aa_layer == 'avg' or aa_layer == 'avgpool':
+            aa_layer = nn.AvgPool2d
+        elif aa_layer == 'blur' or aa_layer == 'blurpool':
+            aa_layer = BlurPool2d
+        elif aa_layer == 'blurpc':
+            aa_layer = partial(BlurPool2d, pad_mode='constant')
+
+        else:
+            assert False, f"Unknown anti-aliasing layer ({aa_layer})."
+
+    try:
+        return aa_layer(channels=channels, stride=stride)
+    except TypeError as e:
+        return aa_layer(stride)
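
A usage note for the new signatures above: BlurPool2d can now be built without a channel count, in which case the single-channel binomial filter is expanded to the input's channel dimension at forward time, and create_aa accepts either a layer class/partial or a string spec. A small sketch against the code in this diff (names come from the hunks above; tensor shapes are just for illustration):

import torch
from timm.layers import BlurPool2d, create_aa

x = torch.randn(2, 64, 56, 56)

# Channel-agnostic: the filter buffer is expanded to x.shape[1] groups at forward time.
aa = BlurPool2d(filt_size=3, stride=2)
print(aa(x).shape)  # torch.Size([2, 64, 28, 28])

# Pre-defined channels: the filter is repeated once at construction (previous behaviour).
aa_fixed = BlurPool2d(channels=64, filt_size=3, stride=2)

# String specs resolve via create_aa; 'blurpc' is BlurPool2d with constant (zero) padding.
aa_from_str = create_aa('blurpc', channels=64, stride=2)
noop = create_aa(None, channels=64, stride=2)  # returns nn.Identity() when disabled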

timm/layers/conv_bn_act.py

Lines changed: 61 additions & 49 deletions
@@ -2,38 +2,51 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
-import functools
+from typing import Any, Dict, Optional, Type
+
 from torch import nn as nn
 
+from .typing import LayerType, PadType
+from .blur_pool import create_aa
 from .create_conv2d import create_conv2d
 from .create_norm_act import get_norm_act_layer
 
 
 class ConvNormAct(nn.Module):
     def __init__(
             self,
-            in_channels,
-            out_channels,
-            kernel_size=1,
-            stride=1,
-            padding='',
-            dilation=1,
-            groups=1,
-            bias=False,
-            apply_act=True,
-            norm_layer=nn.BatchNorm2d,
-            norm_kwargs=None,
-            act_layer=nn.ReLU,
-            act_kwargs=None,
-            drop_layer=None,
+            in_channels: int,
+            out_channels: int,
+            kernel_size: int = 1,
+            stride: int = 1,
+            padding: PadType = '',
+            dilation: int = 1,
+            groups: int = 1,
+            bias: bool = False,
+            apply_act: bool = True,
+            norm_layer: LayerType = nn.BatchNorm2d,
+            act_layer: LayerType = nn.ReLU,
+            drop_layer: Optional[Type[nn.Module]] = None,
+            conv_kwargs: Optional[Dict[str, Any]] = None,
+            norm_kwargs: Optional[Dict[str, Any]] = None,
+            act_kwargs: Optional[Dict[str, Any]] = None,
     ):
         super(ConvNormAct, self).__init__()
+        conv_kwargs = conv_kwargs or {}
         norm_kwargs = norm_kwargs or {}
         act_kwargs = act_kwargs or {}
 
         self.conv = create_conv2d(
-            in_channels, out_channels, kernel_size, stride=stride,
-            padding=padding, dilation=dilation, groups=groups, bias=bias)
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+            **conv_kwargs,
+        )
 
         # NOTE for backwards compatibility with models that use separate norm and act layer definitions
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
@@ -64,54 +77,53 @@ def forward(self, x):
 ConvBnAct = ConvNormAct
 
 
-def create_aa(aa_layer, channels, stride=2, enable=True):
-    if not aa_layer or not enable:
-        return nn.Identity()
-    if isinstance(aa_layer, functools.partial):
-        if issubclass(aa_layer.func, nn.AvgPool2d):
-            return aa_layer()
-        else:
-            return aa_layer(channels)
-    elif issubclass(aa_layer, nn.AvgPool2d):
-        return aa_layer(stride)
-    else:
-        return aa_layer(channels=channels, stride=stride)
-
-
 class ConvNormActAa(nn.Module):
     def __init__(
             self,
-            in_channels,
-            out_channels,
-            kernel_size=1,
-            stride=1,
-            padding='',
-            dilation=1,
-            groups=1,
-            bias=False,
-            apply_act=True,
-            norm_layer=nn.BatchNorm2d,
-            norm_kwargs=None,
-            act_layer=nn.ReLU,
-            act_kwargs=None,
-            aa_layer=None,
-            drop_layer=None,
+            in_channels: int,
+            out_channels: int,
+            kernel_size: int = 1,
+            stride: int = 1,
+            padding: PadType = '',
+            dilation: int = 1,
+            groups: int = 1,
+            bias: bool = False,
+            apply_act: bool = True,
+            norm_layer: LayerType = nn.BatchNorm2d,
+            act_layer: LayerType = nn.ReLU,
+            aa_layer: Optional[LayerType] = None,
+            drop_layer: Optional[Type[nn.Module]] = None,
+            conv_kwargs: Optional[Dict[str, Any]] = None,
+            norm_kwargs: Optional[Dict[str, Any]] = None,
+            act_kwargs: Optional[Dict[str, Any]] = None,
     ):
         super(ConvNormActAa, self).__init__()
         use_aa = aa_layer is not None and stride == 2
+        conv_kwargs = conv_kwargs or {}
         norm_kwargs = norm_kwargs or {}
         act_kwargs = act_kwargs or {}
 
         self.conv = create_conv2d(
-            in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
-            padding=padding, dilation=dilation, groups=groups, bias=bias)
+            in_channels, out_channels, kernel_size,
+            stride=1 if use_aa else stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+            **conv_kwargs,
+        )
 
         # NOTE for backwards compatibility with models that use separate norm and act layer definitions
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
         if drop_layer:
             norm_kwargs['drop_layer'] = drop_layer
-        self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs)
+        self.bn = norm_act_layer(
+            out_channels,
+            apply_act=apply_act,
+            act_kwargs=act_kwargs,
+            **norm_kwargs,
+        )
         self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
 
     @property
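
For the refactor above: the local create_aa helper is removed in favour of the centralized one in blur_pool.py, and ConvNormActAa keeps its conv at stride 1 when anti-aliasing is in use so the AA layer performs the stride-2 downsample. A brief sketch, assuming the usual timm.layers exports for these blocks (shapes are for illustration only):

import torch
from timm.layers import ConvNormActAa

x = torch.randn(1, 32, 56, 56)

# With aa_layer set and stride=2, the conv runs at stride 1 and BlurPool2d does
# the downsample, so spatial resolution is still halved.
block = ConvNormActAa(32, 64, kernel_size=3, stride=2, aa_layer='blur')
print(block(x).shape)  # torch.Size([1, 64, 28, 28])

# Without an aa_layer the block falls back to a plain strided conv + norm + act.
plain = ConvNormActAa(32, 64, kernel_size=3, stride=2)
print(plain(x).shape)  # torch.Size([1, 64, 28, 28])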
