File tree 1 file changed +7
-1
lines changed
1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change 12
12
13
13
class Mlp (nn .Module ):
14
14
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
15
+
16
+ NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
15
17
"""
16
18
def __init__ (
17
19
self ,
@@ -51,6 +53,8 @@ def forward(self, x):
51
53
class GluMlp (nn .Module ):
52
54
""" MLP w/ GLU style gating
53
55
See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
56
+
57
+ NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
54
58
"""
55
59
def __init__ (
56
60
self ,
@@ -192,7 +196,7 @@ def forward(self, x):
192
196
193
197
194
198
class ConvMlp (nn .Module ):
195
- """ MLP using 1x1 convs that keeps spatial dims
199
+ """ MLP using 1x1 convs that keeps spatial dims (for 2D NCHW tensors)
196
200
"""
197
201
def __init__ (
198
202
self ,
@@ -226,6 +230,8 @@ def forward(self, x):
226
230
227
231
class GlobalResponseNormMlp (nn .Module ):
228
232
""" MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
233
+
234
+ NOTE: Intended for '2D' NCHW (use_conv=True) or NHWC (use_conv=False, channels-last) tensor layouts
229
235
"""
230
236
def __init__ (
231
237
self ,
You can’t perform that action at this time.
0 commit comments