Skip to content

Commit 131518c

Browse files
committed
Add comments to MLP layers re expected layouts
1 parent d23facd commit 131518c

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

timm/layers/mlp.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
class Mlp(nn.Module):
1414
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
15+
16+
NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
1517
"""
1618
def __init__(
1719
self,
@@ -51,6 +53,8 @@ def forward(self, x):
5153
class GluMlp(nn.Module):
5254
""" MLP w/ GLU style gating
5355
See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
56+
57+
NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
5458
"""
5559
def __init__(
5660
self,
@@ -192,7 +196,7 @@ def forward(self, x):
192196

193197

194198
class ConvMlp(nn.Module):
195-
""" MLP using 1x1 convs that keeps spatial dims
199+
""" MLP using 1x1 convs that keeps spatial dims (for 2D NCHW tensors)
196200
"""
197201
def __init__(
198202
self,
@@ -226,6 +230,8 @@ def forward(self, x):
226230

227231
class GlobalResponseNormMlp(nn.Module):
228232
""" MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
233+
234+
NOTE: Intended for '2D' NCHW (use_conv=True) or NHWC (use_conv=False, channels-last) tensor layouts
229235
"""
230236
def __init__(
231237
self,

0 commit comments

Comments
 (0)