17
17
)
18
18
from terratorch .models .pixel_wise_model import PixelWiseModel
19
19
from terratorch .models .scalar_output_model import ScalarOutputModel
20
+ from terratorch .models .smp_model_factory import get_smp_decoder
20
21
21
22
PIXEL_WISE_TASKS = ["segmentation" , "regression" ]
22
23
SCALAR_TASKS = ["classification" ]
@@ -95,7 +96,13 @@ def build_model(
95
96
msg = f"Task { task } not supported. Please choose one of { SUPPORTED_TASKS } "
96
97
raise NotImplementedError (msg )
97
98
99
+ # These params are used in case we need a SMP decoder
100
+ # but should not be used for timm encoder
98
101
backbone_kwargs , kwargs = _extract_prefix_keys (kwargs , "backbone_" )
102
+ smp_kwargs , kwargs = _extract_prefix_keys (kwargs , "smp_" )
103
+ aux_kwargs , kwargs = _extract_prefix_keys (kwargs , "aux_" )
104
+ output_stride = backbone_kwargs .pop ('output_stride' , None )
105
+ out_channels = backbone_kwargs .pop ('out_channels' , None )
99
106
100
107
backbone : nn .Module = timm .create_model (
101
108
backbone ,
@@ -106,13 +113,16 @@ def build_model(
106
113
features_only = True ,
107
114
** backbone_kwargs ,
108
115
)
109
- # allow decoder to be a module passed directly
110
- decoder_cls = _get_decoder (decoder )
111
116
112
117
decoder_kwargs , kwargs = _extract_prefix_keys (kwargs , "decoder_" )
113
-
118
+ args = kwargs . copy ()
114
119
# TODO: remove this
115
- decoder : nn .Module = decoder_cls (backbone .feature_info .channels (), ** decoder_kwargs )
120
+ if decoder .startswith ("smp_" ):
121
+ decoder : nn .Module = get_smp_decoder (decoder , backbone_kwargs , smp_kwargs , aux_kwargs , args , out_channels , in_channels , num_classes , output_stride )
122
+ else :
123
+ # allow decoder to be a module passed directly
124
+ decoder_cls = _get_decoder (decoder )
125
+ decoder : nn .Module = decoder_cls (backbone .feature_info .channels (), ** decoder_kwargs )
116
126
# decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs)
117
127
118
128
head_kwargs , kwargs = _extract_prefix_keys (kwargs , "head_" )
0 commit comments