@@ -28,11 +28,13 @@
 )
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
-from ._manipulate import checkpoint_seq, checkpoint
+from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['MobileNetV5', 'MobileNetV5Encoder']
 
+_GELU = partial(nn.GELU, approximate='tanh')
+
 
 @register_notrace_module
 class MobileNetV5MultiScaleFusionAdapter(nn.Module):
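For reference, a quick illustrative sketch (not part of the diff) of what the new `_GELU` alias does: the `partial` is callable like an activation class, so it can be passed anywhere an `act_layer` class is expected, and it selects the tanh approximation of GELU rather than the exact erf form. The two differ by a small amount, so weights trained against one variant should be run with the same one.

```python
from functools import partial

import torch
import torch.nn as nn

# Same alias as in the diff: a class-like factory for tanh-approximate GELU.
_GELU = partial(nn.GELU, approximate='tanh')

act = _GELU()        # equivalent to nn.GELU(approximate='tanh')
exact = nn.GELU()    # default exact (erf-based) GELU

x = torch.linspace(-3.0, 3.0, steps=7)
# The outputs differ slightly; the tanh form is a close, cheaper approximation.
print((act(x) - exact(x)).abs().max())
```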
@@ -68,7 +70,7 @@ def __init__(
         self.layer_scale_init_value = layer_scale_init_value
         self.noskip = noskip
 
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = norm_layer or RmsNorm2d
         self.ffn = UniversalInvertedResidual(
             in_chs=self.in_channels,
@@ -167,7 +169,7 @@ def __init__(
             global_pool: Type of pooling to use for global pooling features of the FC head.
         """
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         se_layer = se_layer or SqueezeExcite
@@ -410,7 +412,7 @@ def __init__(
             block_args: BlockArgs,
             in_chans: int = 3,
             stem_size: int = 64,
-            stem_bias: bool = False,
+            stem_bias: bool = True,
             fix_stem: bool = False,
             pad_type: str = '',
             msfa_indices: Sequence[int] = (-2, -1),
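A plausible reading of the `stem_bias` flip (my inference, not stated in the diff): with `RmsNorm2d` as the default norm, a stem conv bias is no longer redundant. The sketch below is illustrative, with a hand-rolled channel-wise RMS standing in for `RmsNorm2d` (an assumption about its normalization axis).

```python
import torch
import torch.nn as nn

# Illustrative stand-in for the encoder stem with the new bias=True default.
stem = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=True)

x = torch.randn(2, 3, 224, 224)
y = stem(x)

# RMS norm only rescales by the root-mean-square and never subtracts a mean,
# so the conv bias survives, unlike under BatchNorm where mean-subtraction
# would cancel a constant per-channel offset.
rms = y.pow(2).mean(dim=1, keepdim=True).add(1e-6).sqrt()
print((y / rms).shape)
```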
@@ -426,7 +428,7 @@ def __init__(
             layer_scale_init_value: Optional[float] = None,
     ):
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         se_layer = se_layer or SqueezeExcite
         self.num_classes = 0  # Exists to satisfy ._hub module APIs.
@@ -777,7 +780,7 @@ def _gen_mobilenet_v5(
         fix_stem=channel_multiplier < 1.0,
         round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
         norm_layer=RmsNorm2d,
-        act_layer=nn.GELU,
+        act_layer=_GELU,
         layer_scale_init_value=1e-5,
     )
     model_kwargs = dict(model_kwargs, **kwargs)
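After this change, every default activation in the model family resolves to the tanh-approximate GELU. A hedged usage sketch follows; the variant name is an assumption for illustration, not taken from this diff.

```python
import timm
import torch.nn as nn

# Variant name assumed; substitute any registered MobileNetV5 model name.
model = timm.create_model('mobilenetv5_300m', pretrained=False)

# All GELU modules built from the new default should carry the tanh flag.
gelus = [m for m in model.modules() if isinstance(m, nn.GELU)]
print(len(gelus), all(m.approximate == 'tanh' for m in gelus))
```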