 
 Hacked together by / Copyright 2020 Ross Wightman
 """
 
-import functools
+from typing import Any, Dict, Optional, Type
+
 from torch import nn as nn
 
+from .typing import LayerType, PadType
+from .blur_pool import create_aa
 from .create_conv2d import create_conv2d
 from .create_norm_act import get_norm_act_layer
 
 
 class ConvNormAct(nn.Module):
     def __init__(
             self,
-            in_channels,
-            out_channels,
-            kernel_size=1,
-            stride=1,
-            padding='',
-            dilation=1,
-            groups=1,
-            bias=False,
-            apply_act=True,
-            norm_layer=nn.BatchNorm2d,
-            norm_kwargs=None,
-            act_layer=nn.ReLU,
-            act_kwargs=None,
-            drop_layer=None,
+            in_channels: int,
+            out_channels: int,
+            kernel_size: int = 1,
+            stride: int = 1,
+            padding: PadType = '',
+            dilation: int = 1,
+            groups: int = 1,
+            bias: bool = False,
+            apply_act: bool = True,
+            norm_layer: LayerType = nn.BatchNorm2d,
+            act_layer: LayerType = nn.ReLU,
+            drop_layer: Optional[Type[nn.Module]] = None,
+            conv_kwargs: Optional[Dict[str, Any]] = None,
+            norm_kwargs: Optional[Dict[str, Any]] = None,
+            act_kwargs: Optional[Dict[str, Any]] = None,
     ):
         super(ConvNormAct, self).__init__()
+        conv_kwargs = conv_kwargs or {}
         norm_kwargs = norm_kwargs or {}
         act_kwargs = act_kwargs or {}
 
         self.conv = create_conv2d(
-            in_channels, out_channels, kernel_size, stride=stride,
-            padding=padding, dilation=dilation, groups=groups, bias=bias)
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+            **conv_kwargs,
+        )
 
         # NOTE for backwards compatibility with models that use separate norm and act layer definitions
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
@@ -64,54 +77,53 @@ def forward(self, x):
 ConvBnAct = ConvNormAct
 
 
-def create_aa(aa_layer, channels, stride=2, enable=True):
-    if not aa_layer or not enable:
-        return nn.Identity()
-    if isinstance(aa_layer, functools.partial):
-        if issubclass(aa_layer.func, nn.AvgPool2d):
-            return aa_layer()
-        else:
-            return aa_layer(channels)
-    elif issubclass(aa_layer, nn.AvgPool2d):
-        return aa_layer(stride)
-    else:
-        return aa_layer(channels=channels, stride=stride)
-
-
 class ConvNormActAa(nn.Module):
     def __init__(
             self,
-            in_channels,
-            out_channels,
-            kernel_size=1,
-            stride=1,
-            padding='',
-            dilation=1,
-            groups=1,
-            bias=False,
-            apply_act=True,
-            norm_layer=nn.BatchNorm2d,
-            norm_kwargs=None,
-            act_layer=nn.ReLU,
-            act_kwargs=None,
-            aa_layer=None,
-            drop_layer=None,
+            in_channels: int,
+            out_channels: int,
+            kernel_size: int = 1,
+            stride: int = 1,
+            padding: PadType = '',
+            dilation: int = 1,
+            groups: int = 1,
+            bias: bool = False,
+            apply_act: bool = True,
+            norm_layer: LayerType = nn.BatchNorm2d,
+            act_layer: LayerType = nn.ReLU,
+            aa_layer: Optional[LayerType] = None,
+            drop_layer: Optional[Type[nn.Module]] = None,
+            conv_kwargs: Optional[Dict[str, Any]] = None,
+            norm_kwargs: Optional[Dict[str, Any]] = None,
+            act_kwargs: Optional[Dict[str, Any]] = None,
     ):
         super(ConvNormActAa, self).__init__()
         use_aa = aa_layer is not None and stride == 2
+        conv_kwargs = conv_kwargs or {}
         norm_kwargs = norm_kwargs or {}
         act_kwargs = act_kwargs or {}
 
         self.conv = create_conv2d(
-            in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
-            padding=padding, dilation=dilation, groups=groups, bias=bias)
+            in_channels, out_channels, kernel_size,
+            stride=1 if use_aa else stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+            **conv_kwargs,
+        )
 
         # NOTE for backwards compatibility with models that use separate norm and act layer definitions
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
         if drop_layer:
             norm_kwargs['drop_layer'] = drop_layer
-        self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs)
+        self.bn = norm_act_layer(
+            out_channels,
+            apply_act=apply_act,
+            act_kwargs=act_kwargs,
+            **norm_kwargs,
+        )
         self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
 
     @property
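For reference, a minimal usage sketch of the two constructors after this change. It is not part of the diff; it assumes ConvNormAct, ConvNormActAa and BlurPool2d are importable from timm.layers, and all channel counts, shapes and kwargs below are illustrative only.

import torch
from torch import nn
from timm.layers import ConvNormAct, ConvNormActAa, BlurPool2d

# Conv -> BatchNorm -> ReLU; extra constructor options can now be routed
# through the conv_kwargs / norm_kwargs / act_kwargs dicts (illustrative values).
block = ConvNormAct(
    in_channels=32,
    out_channels=64,
    kernel_size=3,
    stride=2,
    norm_layer=nn.BatchNorm2d,
    act_layer=nn.ReLU,
    act_kwargs={'inplace': True},   # forwarded to the activation constructor
)

# Anti-aliased variant: the conv runs at stride 1 and create_aa appends the
# stride-2 blur-pool downsample after norm + act.
aa_block = ConvNormActAa(
    in_channels=64,
    out_channels=128,
    kernel_size=3,
    stride=2,
    aa_layer=BlurPool2d,
)

x = torch.randn(1, 32, 56, 56)
y = aa_block(block(x))   # expected shape: (1, 128, 14, 14)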