6
6
import segmentation_models_pytorch as smp
7
7
import torch
8
8
import torch .nn .functional as F # noqa: N812
9
- from segmentation_models_pytorch .encoders import encoders as ENCODERS
9
+ from segmentation_models_pytorch .encoders import encoders as smp_encoders
10
10
from torch import nn
11
11
12
12
from terratorch .datasets import HLSBands
@@ -27,6 +27,7 @@ class SMPDecoderForPrithviWrapper(nn.Module):
27
27
forward_single_embed(x) -> torch.Tensor:
28
28
Forward pass for a single embedding.
29
29
"""
30
+
30
31
def __init__ (self , decoder , num_channels ) -> None :
31
32
"""
32
33
Args:
@@ -48,7 +49,6 @@ def forward_single_embed(self, x):
48
49
return self .decoder (x [- 1 ])
49
50
50
51
51
-
52
52
class SMPModelWrapper (Model , nn .Module ):
53
53
"""
54
54
Wrapper class for SMP models.
@@ -69,21 +69,17 @@ class SMPModelWrapper(Model, nn.Module):
69
69
freeze_decoder() -> None:
70
70
Freezes the parameters of the decoder part of the model.
71
71
"""
72
- def __init__ (
73
- self ,
74
- smp_model ,
75
- rescale = True ,
76
- relu = False ,
77
- squeeze_single_class = False
78
- ) -> None :
79
72
73
+ def __init__ (self , smp_model , rescale = True , relu = False , squeeze_single_class = False ) -> None : # noqa: FBT002
80
74
super ().__init__ ()
81
75
"""
82
76
Args:
83
77
smp_model (nn.Module): The base SMP model to be wrapped.
84
78
rescale (bool, optional): Whether to rescale the output to match the input dimensions. Defaults to True.
85
- relu (bool, optional): Whether to apply ReLU activation on the output. If False, Identity activation is used. Defaults to False.
86
- squeeze_single_class (bool, optional): Whether to squeeze the output if there is a single output class. Defaults to False.
79
+ relu (bool, optional): Whether to apply ReLU activation on the output.
80
+ If False, Identity activation is used. Defaults to False.
81
+ squeeze_single_class (bool, optional): Whether to squeeze the output if there is a single output class.
82
+ Defaults to False.
87
83
"""
88
84
self .rescale = rescale
89
85
self .smp_model = smp_model
@@ -95,7 +91,7 @@ def forward(self, x):
95
91
smp_output = self .smp_model (x )
96
92
smp_output = self .final_act (smp_output )
97
93
98
- #TODO: support auxiliary head labels
94
+ # TODO: support auxiliary head labels
99
95
if isinstance (smp_output , tuple ):
100
96
smp_output , labels = smp_output
101
97
@@ -123,9 +119,9 @@ def build_model(
123
119
bands : list [HLSBands | int ],
124
120
in_channels : int | None = None ,
125
121
num_classes : int = 1 ,
126
- pretrained : str | bool | None = True ,
122
+ pretrained : str | bool | None = True , # noqa: FBT002
127
123
prepare_features_for_image_model : Callable | None = None ,
128
- regression_relu : bool = False ,
124
+ regression_relu : bool = False , # noqa: FBT001, FBT002
129
125
** kwargs ,
130
126
) -> Model :
131
127
"""
@@ -173,9 +169,9 @@ def build_model(
173
169
msg = f"Decoder { model } is not supported in SMP."
174
170
raise ValueError (msg )
175
171
176
- backbone_kwargs = _extract_prefix_keys (kwargs , "backbone_" ) # Encoder params should be prefixed backbone_
177
- smp_kwargs = _extract_prefix_keys (backbone_kwargs , "smp_" ) # Smp model params should be prefixed smp_
178
- aux_params = _extract_prefix_keys (backbone_kwargs , "aux_" ) # Auxiliary head params should be prefixed aux_
172
+ backbone_kwargs = _extract_prefix_keys (kwargs , "backbone_" ) # Encoder params should be prefixed backbone_
173
+ smp_kwargs = _extract_prefix_keys (backbone_kwargs , "smp_" ) # Smp model params should be prefixed smp_
174
+ aux_params = _extract_prefix_keys (backbone_kwargs , "aux_" ) # Auxiliary head params should be prefixed aux_
179
175
aux_params = None if aux_params == {} else aux_params
180
176
181
177
if isinstance (pretrained , bool ):
@@ -185,12 +181,12 @@ def build_model(
185
181
pretrained = None
186
182
187
183
# If encoder not currently supported by SMP (custom encoder).
188
- if backbone not in ENCODERS :
184
+ if backbone not in smp_encoders :
189
185
# These params must be included in the config file with appropriate prefix.
190
186
required_params = {
191
187
"encoder_depth" : smp_kwargs ,
192
188
"out_channels" : backbone_kwargs ,
193
- "output_stride" : backbone_kwargs
189
+ "output_stride" : backbone_kwargs ,
194
190
}
195
191
196
192
for param , config_dict in required_params .items ():
@@ -209,7 +205,7 @@ def build_model(
209
205
"encoder_weights" : pretrained ,
210
206
"in_channels" : in_channels ,
211
207
"classes" : num_classes ,
212
- ** smp_kwargs
208
+ ** smp_kwargs ,
213
209
}
214
210
# Using SMP encoder.
215
211
else :
@@ -218,15 +214,13 @@ def build_model(
218
214
"encoder_weights" : pretrained ,
219
215
"in_channels" : in_channels ,
220
216
"classes" : num_classes ,
221
- ** smp_kwargs
217
+ ** smp_kwargs ,
222
218
}
223
219
224
220
model = model_module (** model_args , aux_params = aux_params )
225
221
226
222
return SMPModelWrapper (
227
- model ,
228
- relu = task == "regression" and regression_relu ,
229
- squeeze_single_class = task == "regression"
223
+ model , relu = task == "regression" and regression_relu , squeeze_single_class = task == "regression"
230
224
)
231
225
232
226
@@ -240,7 +234,7 @@ def get_smp_decoder(
240
234
in_channels : int ,
241
235
num_classes : int ,
242
236
output_stride : int ,
243
- ) :
237
+ ):
244
238
"""
245
239
Creates and configures a decoder from the Segmentation Models Pytorch (SMP) library.
246
240
@@ -279,8 +273,8 @@ def get_smp_decoder(
279
273
# Little hack to make SMP model accept our encoder.
280
274
# passes a dummy encoder to be changed later.
281
275
# this is needed to pass encoder params.
282
- backbone_kwargs [' out_channels' ] = out_channels
283
- backbone_kwargs [' output_stride' ] = output_stride
276
+ backbone_kwargs [" out_channels" ] = out_channels
277
+ backbone_kwargs [" output_stride" ] = output_stride
284
278
aux_kwargs = None if aux_kwargs == {} else aux_kwargs
285
279
286
280
dummy_encoder = _make_smp_encoder ()
@@ -298,34 +292,33 @@ def get_smp_decoder(
298
292
"encoder_weights" : None ,
299
293
"in_channels" : in_channels ,
300
294
"classes" : num_classes ,
301
- ** smp_kwargs
295
+ ** smp_kwargs ,
302
296
}
303
-
297
+
304
298
# Creates model with dummy encoder and decoder.
305
299
model = decoder_module (** model_args , aux_params = aux_kwargs )
306
300
307
301
# Wrapper for SMP Decoder.
308
- smp_decoder = SMPDecoderForPrithviWrapper (
309
- decoder = model .decoder ,
310
- num_channels = out_channels [- 1 ]
311
- )
302
+ smp_decoder = SMPDecoderForPrithviWrapper (decoder = model .decoder , num_channels = out_channels [- 1 ])
312
303
if "multiple_embed" in head_kwargs :
313
304
smp_decoder .forward = smp_decoder .forward_multiple_embeds
314
305
else :
315
306
smp_decoder .forward = smp_decoder .forward_single_embed
316
307
317
308
return smp_decoder
318
309
310
+
319
311
# Registers a custom encoder into SMP.
320
- def _register_custom_encoder ( encoder , params , pretrained ):
321
- ENCODERS ["SMPEncoderWrapperWithPFFIM" ] = {
312
+ def _register_custom_encoder (encoder , params , pretrained ):
313
+ smp_encoders ["SMPEncoderWrapperWithPFFIM" ] = {
322
314
"encoder" : encoder ,
323
315
"params" : params ,
324
316
"pretrained_settings" : pretrained
325
317
}
326
318
319
+
327
320
# Gets class either from string or from Module reference.
328
- def _make_smp_encoder (encoder = None ):
321
+ def _make_smp_encoder (encoder = None ):
329
322
if isinstance (encoder , str ):
330
323
base_class = _get_class_from_string (encoder )
331
324
else :
@@ -334,14 +327,14 @@ def _make_smp_encoder(encoder = None):
334
327
# Wrapper needed to include SMP params and PFFIM
335
328
class SMPEncoderWrapperWithPFFIM (base_class ):
336
329
def __init__ (
337
- self ,
338
- depth : int ,
339
- output_stride : int ,
340
- out_channels : list [int ],
341
- prepare_features_for_image_model : Callable | None = None ,
342
- * args ,
343
- ** kwargs
344
- ) -> None :
330
+ self ,
331
+ depth : int ,
332
+ output_stride : int ,
333
+ out_channels : list [int ],
334
+ prepare_features_for_image_model : Callable | None = None ,
335
+ * args ,
336
+ ** kwargs ,
337
+ ) -> None :
345
338
super ().__init__ (* args , ** kwargs )
346
339
self ._depth = depth
347
340
self ._output_stride = output_stride
@@ -362,7 +355,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
362
355
features = super ().forward (x )
363
356
return self .prepare_features_for_image_model (features )
364
357
365
-
366
358
@property
367
359
def out_channels (self ):
368
360
if hasattr (super (), "out_channels" ):
@@ -409,24 +401,23 @@ def _extract_prefix_keys(d: dict, prefix: str) -> dict:
409
401
def _get_class_from_string (class_path ):
410
402
try :
411
403
module_path , name = class_path .rsplit ("." , 1 )
412
- except ValueError :
404
+ except ValueError as vr :
413
405
msg = "Path must contain a '.' separating module from the class name"
414
- raise ValueError (msg )
406
+ raise ValueError (msg ) from vr
415
407
416
408
try :
417
409
module = importlib .import_module (module_path )
418
- except ImportError :
410
+ except ImportError as ie :
419
411
msg = f"Could not import module '{ module_path } '."
420
- raise ImportError (msg )
412
+ raise ImportError (msg ) from ie
421
413
422
414
try :
423
415
return getattr (module , name )
424
- except AttributeError :
416
+ except AttributeError as ae :
425
417
msg = f"The class '{ name } ' was not found in the module '{ module_path } '."
426
- raise AttributeError (msg )
418
+ raise AttributeError (msg ) from ae
427
419
428
420
429
421
def freeze_module (module : nn .Module ):
430
422
for param in module .parameters ():
431
423
param .requires_grad_ (False )
432
-
0 commit comments