Commit 93c523e

extends SMPModelFactory
Signed-off-by: Pedro Henrique Conrado <[email protected]>
1 parent 52699a9 commit 93c523e

2 files changed: +59 -56 lines changed

terratorch/models/prithvi_model_factory.py

+16 -4
@@ -27,6 +27,7 @@
 class DecoderNotFoundError(Exception):
     pass
 
+
 @register_factory
 class PrithviModelFactory(ModelFactory):
     def build_model(
@@ -35,7 +36,8 @@ def build_model(
         backbone: str | nn.Module,
         decoder: str | nn.Module,
         bands: list[HLSBands | int],
-        in_channels: int | None = None,  # this should be removed, can be derived from bands. But it is a breaking change
+        in_channels: int
+        | None = None,  # this should be removed, can be derived from bands. But it is a breaking change
         num_classes: int | None = None,
         pretrained: bool = True,  # noqa: FBT001, FBT002
         num_frames: int = 1,
@@ -101,8 +103,8 @@ def build_model(
         backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_")
         smp_kwargs, kwargs = _extract_prefix_keys(kwargs, "smp_")
         aux_kwargs, kwargs = _extract_prefix_keys(kwargs, "aux_")
-        output_stride = backbone_kwargs.pop('output_stride', None)
-        out_channels = backbone_kwargs.pop('out_channels', None)
+        output_stride = backbone_kwargs.pop("output_stride", None)
+        out_channels = backbone_kwargs.pop("out_channels", None)
 
         backbone: nn.Module = timm.create_model(
             backbone,
@@ -118,7 +120,17 @@ def build_model(
         args = kwargs.copy()
         # TODO: remove this
         if decoder.startswith("smp_"):
-            decoder: nn.Module = get_smp_decoder(decoder, backbone_kwargs, smp_kwargs, aux_kwargs, args, out_channels, in_channels, num_classes, output_stride)
+            decoder: nn.Module = get_smp_decoder(
+                decoder,
+                backbone_kwargs,
+                smp_kwargs,
+                aux_kwargs,
+                args,
+                out_channels,
+                in_channels,
+                num_classes,
+                output_stride,
+            )
         else:
             # allow decoder to be a module passed directly
             decoder_cls = _get_decoder(decoder)
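
The backbone_/smp_/aux_ prefixes above are how build_model routes user kwargs to the encoder, the SMP model, and the auxiliary head. A minimal sketch of what a prefix-splitting helper along the lines of _extract_prefix_keys might do, assuming it returns the stripped prefixed entries plus the remaining kwargs (the repository's actual helper may differ):

def extract_prefix_keys_sketch(d: dict, prefix: str) -> tuple[dict, dict]:
    # Split d into (entries whose keys start with prefix, with the prefix stripped)
    # and (everything else, untouched).
    extracted, remaining = {}, {}
    for key, value in d.items():
        if key.startswith(prefix):
            extracted[key[len(prefix):]] = value
        else:
            remaining[key] = value
    return extracted, remaining

# e.g. a kwarg passed as backbone_output_stride=32 becomes
# backbone_kwargs["output_stride"] = 32 before the pop() calls above.
backbone_kwargs, kwargs = extract_prefix_keys_sketch({"backbone_output_stride": 32, "num_frames": 1}, "backbone_")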

terratorch/models/smp_model_factory.py

+43 -52
@@ -6,7 +6,7 @@
 import segmentation_models_pytorch as smp
 import torch
 import torch.nn.functional as F  # noqa: N812
-from segmentation_models_pytorch.encoders import encoders as ENCODERS
+from segmentation_models_pytorch.encoders import encoders as smp_encoders
 from torch import nn
 
 from terratorch.datasets import HLSBands
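
For context, the object renamed here is segmentation_models_pytorch's encoder registry, a plain dict keyed by encoder name; the commit only changes the local alias to a lower-case name. A quick check, assuming an environment with segmentation_models_pytorch installed:

from segmentation_models_pytorch.encoders import encoders as smp_encoders

# The registry maps encoder names to their class, init params and pretrained settings.
print(type(smp_encoders))           # <class 'dict'>
print("resnet34" in smp_encoders)   # True for the built-in encoders
print(sorted(smp_encoders["resnet34"].keys()))  # typically ['encoder', 'params', 'pretrained_settings']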
@@ -27,6 +27,7 @@ class SMPDecoderForPrithviWrapper(nn.Module):
     forward_single_embed(x) -> torch.Tensor:
         Forward pass for a single embedding.
     """
+
     def __init__(self, decoder, num_channels) -> None:
         """
         Args:
@@ -48,7 +49,6 @@ def forward_single_embed(self, x):
         return self.decoder(x[-1])
 
 
-
 class SMPModelWrapper(Model, nn.Module):
     """
     Wrapper class for SMP models.
@@ -69,21 +69,17 @@ class SMPModelWrapper(Model, nn.Module):
     freeze_decoder() -> None:
         Freezes the parameters of the decoder part of the model.
     """
-    def __init__(
-        self,
-        smp_model,
-        rescale = True,
-        relu=False,
-        squeeze_single_class=False
-    ) -> None:
 
+    def __init__(self, smp_model, rescale=True, relu=False, squeeze_single_class=False) -> None:  # noqa: FBT002
         super().__init__()
         """
         Args:
             smp_model (nn.Module): The base SMP model to be wrapped.
             rescale (bool, optional): Whether to rescale the output to match the input dimensions. Defaults to True.
-            relu (bool, optional): Whether to apply ReLU activation on the output. If False, Identity activation is used. Defaults to False.
-            squeeze_single_class (bool, optional): Whether to squeeze the output if there is a single output class. Defaults to False.
+            relu (bool, optional): Whether to apply ReLU activation on the output.
+                If False, Identity activation is used. Defaults to False.
+            squeeze_single_class (bool, optional): Whether to squeeze the output if there is a single output class.
+                Defaults to False.
         """
         self.rescale = rescale
         self.smp_model = smp_model
@@ -95,7 +91,7 @@ def forward(self, x):
         smp_output = self.smp_model(x)
         smp_output = self.final_act(smp_output)
 
-        #TODO: support auxiliary head labels
+        # TODO: support auxiliary head labels
         if isinstance(smp_output, tuple):
             smp_output, labels = smp_output
 
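
The final_act applied above is presumably chosen from the relu flag passed to __init__. A self-contained sketch of how the relu and squeeze_single_class options are understood to behave (illustrative stand-in, not the wrapper's exact code):

import torch
from torch import nn

class FinalActivationSketch(nn.Module):
    # Hypothetical stand-in for the post-processing done by SMPModelWrapper.
    def __init__(self, relu: bool = False, squeeze_single_class: bool = False) -> None:
        super().__init__()
        self.final_act = nn.ReLU() if relu else nn.Identity()
        self.squeeze_single_class = squeeze_single_class

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.final_act(x)
        if self.squeeze_single_class:
            out = out.squeeze(1)  # drop the single-class channel dimension
        return out

head = FinalActivationSketch(relu=True, squeeze_single_class=True)
print(head(torch.randn(2, 1, 8, 8)).shape)  # torch.Size([2, 8, 8])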
@@ -123,9 +119,9 @@ def build_model(
         bands: list[HLSBands | int],
         in_channels: int | None = None,
         num_classes: int = 1,
-        pretrained: str | bool | None = True,
+        pretrained: str | bool | None = True,  # noqa: FBT002
         prepare_features_for_image_model: Callable | None = None,
-        regression_relu: bool = False,
+        regression_relu: bool = False,  # noqa: FBT001, FBT002
         **kwargs,
     ) -> Model:
         """
@@ -173,9 +169,9 @@ def build_model(
             msg = f"Decoder {model} is not supported in SMP."
             raise ValueError(msg)
 
-        backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_") # Encoder params should be prefixed backbone_
-        smp_kwargs = _extract_prefix_keys(backbone_kwargs, "smp_") # Smp model params should be prefixed smp_
-        aux_params = _extract_prefix_keys(backbone_kwargs, "aux_") # Auxiliary head params should be prefixed aux_
+        backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_")  # Encoder params should be prefixed backbone_
+        smp_kwargs = _extract_prefix_keys(backbone_kwargs, "smp_")  # Smp model params should be prefixed smp_
+        aux_params = _extract_prefix_keys(backbone_kwargs, "aux_")  # Auxiliary head params should be prefixed aux_
         aux_params = None if aux_params == {} else aux_params
 
         if isinstance(pretrained, bool):
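
As a hypothetical illustration of the prefix convention spelled out in the comments above (parameter names chosen for the example, not taken from a shipped config), a kwargs dict might be split like this:

kwargs = {
    "backbone_out_channels": [64, 128, 256, 512],  # encoder param -> backbone_kwargs
    "backbone_output_stride": 32,                  # encoder param -> backbone_kwargs
    "smp_encoder_depth": 4,                        # SMP model param -> smp_kwargs
    "aux_classes": 1,                              # auxiliary head param -> aux_params
}
# After stripping the prefixes, the factory would see roughly:
#   backbone_kwargs = {"out_channels": [64, 128, 256, 512], "output_stride": 32}
#   smp_kwargs = {"encoder_depth": 4}
#   aux_params = {"classes": 1}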
@@ -185,12 +181,12 @@ def build_model(
             pretrained = None
 
         # If encoder not currently supported by SMP (custom encoder).
-        if backbone not in ENCODERS:
+        if backbone not in smp_encoders:
             # These params must be included in the config file with appropriate prefix.
             required_params = {
                 "encoder_depth": smp_kwargs,
                 "out_channels": backbone_kwargs,
-                "output_stride": backbone_kwargs
+                "output_stride": backbone_kwargs,
             }
 
             for param, config_dict in required_params.items():
@@ -209,7 +205,7 @@ def build_model(
                 "encoder_weights": pretrained,
                 "in_channels": in_channels,
                 "classes": num_classes,
-                **smp_kwargs
+                **smp_kwargs,
             }
         # Using SMP encoder.
         else:
@@ -218,15 +214,13 @@ def build_model(
                 "encoder_weights": pretrained,
                 "in_channels": in_channels,
                 "classes": num_classes,
-                **smp_kwargs
+                **smp_kwargs,
             }
 
         model = model_module(**model_args, aux_params=aux_params)
 
         return SMPModelWrapper(
-            model,
-            relu=task == "regression" and regression_relu,
-            squeeze_single_class=task == "regression"
+            model, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression"
         )
 
 
@@ -240,7 +234,7 @@ def get_smp_decoder(
     in_channels: int,
     num_classes: int,
     output_stride: int,
-) :
+):
     """
     Creates and configures a decoder from the Segmentation Models Pytorch (SMP) library.
 
@@ -279,8 +273,8 @@
     # Little hack to make SMP model accept our encoder.
     # passes a dummy encoder to be changed later.
     # this is needed to pass encoder params.
-    backbone_kwargs['out_channels'] = out_channels
-    backbone_kwargs['output_stride'] = output_stride
+    backbone_kwargs["out_channels"] = out_channels
+    backbone_kwargs["output_stride"] = output_stride
     aux_kwargs = None if aux_kwargs == {} else aux_kwargs
 
     dummy_encoder = _make_smp_encoder()
@@ -298,34 +292,33 @@
         "encoder_weights": None,
         "in_channels": in_channels,
         "classes": num_classes,
-        **smp_kwargs
+        **smp_kwargs,
     }
-
+
     # Creates model with dummy encoder and decoder.
     model = decoder_module(**model_args, aux_params=aux_kwargs)
 
     # Wrapper for SMP Decoder.
-    smp_decoder = SMPDecoderForPrithviWrapper(
-        decoder=model.decoder,
-        num_channels=out_channels[-1]
-    )
+    smp_decoder = SMPDecoderForPrithviWrapper(decoder=model.decoder, num_channels=out_channels[-1])
     if "multiple_embed" in head_kwargs:
         smp_decoder.forward = smp_decoder.forward_multiple_embeds
     else:
         smp_decoder.forward = smp_decoder.forward_single_embed
 
     return smp_decoder
 
+
 # Registers a custom encoder into SMP.
-def _register_custom_encoder( encoder, params, pretrained):
-    ENCODERS["SMPEncoderWrapperWithPFFIM"] = {
+def _register_custom_encoder(encoder, params, pretrained):
+    smp_encoders["SMPEncoderWrapperWithPFFIM"] = {
         "encoder": encoder,
         "params": params,
         "pretrained_settings": pretrained
     }
 
+
 # Gets class either from string or from Module reference.
-def _make_smp_encoder(encoder = None):
+def _make_smp_encoder(encoder=None):
     if isinstance(encoder, str):
         base_class = _get_class_from_string(encoder)
     else:
@@ -334,14 +327,14 @@ def _make_smp_encoder(encoder = None):
     # Wrapper needed to include SMP params and PFFIM
     class SMPEncoderWrapperWithPFFIM(base_class):
         def __init__(
-                self,
-                depth: int,
-                output_stride: int,
-                out_channels: list[int],
-                prepare_features_for_image_model: Callable | None = None,
-                *args,
-                **kwargs
-        ) -> None:
+            self,
+            depth: int,
+            output_stride: int,
+            out_channels: list[int],
+            prepare_features_for_image_model: Callable | None = None,
+            *args,
+            **kwargs,
+        ) -> None:
             super().__init__(*args, **kwargs)
             self._depth = depth
             self._output_stride = output_stride
@@ -362,7 +355,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             features = super().forward(x)
             return self.prepare_features_for_image_model(features)
 
-
         @property
         def out_channels(self):
             if hasattr(super(), "out_channels"):
@@ -409,24 +401,23 @@ def _extract_prefix_keys(d: dict, prefix: str) -> dict:
 def _get_class_from_string(class_path):
     try:
         module_path, name = class_path.rsplit(".", 1)
-    except ValueError:
+    except ValueError as vr:
         msg = "Path must contain a '.' separating module from the class name"
-        raise ValueError(msg)
+        raise ValueError(msg) from vr
 
     try:
         module = importlib.import_module(module_path)
-    except ImportError:
+    except ImportError as ie:
         msg = f"Could not import module '{module_path}'."
-        raise ImportError(msg)
+        raise ImportError(msg) from ie
 
     try:
         return getattr(module, name)
-    except AttributeError:
+    except AttributeError as ae:
         msg = f"The class '{name}' was not found in the module '{module_path}'."
-        raise AttributeError(msg)
+        raise AttributeError(msg) from ae
 
 
 def freeze_module(module: nn.Module):
     for param in module.parameters():
         param.requires_grad_(False)
-
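
The last hunk switches bare re-raises to exception chaining. A standalone illustration of what raise ... from ... buys, mirroring the pattern above with a hypothetical helper: the original error stays attached as __cause__ instead of being discarded.

def get_class_name(path: str) -> str:
    # Illustrative helper mirroring _get_class_from_string's first step.
    try:
        module_path, name = path.rsplit(".", 1)
    except ValueError as vr:
        msg = "Path must contain a '.' separating module from the class name"
        raise ValueError(msg) from vr
    return name

try:
    get_class_name("no_dot_here")
except ValueError as exc:
    print(exc)            # Path must contain a '.' separating module from the class name
    print(exc.__cause__)  # not enough values to unpack (expected 2, got 1)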