Skip to content

Commit 52699a9

Browse files
committed
extends smp model factory and adds functionalities in prithvi model factory
Signed-off-by: Pedro Henrique Conrado <[email protected]>
1 parent fb68d27 commit 52699a9

File tree

3 files changed

+290
-212
lines changed

3 files changed

+290
-212
lines changed

terratorch/models/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from terratorch.models.smp_model_factory import SMPModelFactory
77
from terratorch.models.timm_model_factory import TimmModelFactory
88

9+
from terratorch.models.smp_model_factory import get_smp_decoder
10+
911
__all__ = (
1012
"PrithviModelFactory",
1113
"ClayModelFactory",
@@ -15,4 +17,5 @@
1517
"TimmModelFactory",
1618
"AuxiliaryHead",
1719
"AuxiliaryHeadWithDecoderWithoutInstantiatedHead",
20+
"get_smp_decoder",
1821
)

terratorch/models/prithvi_model_factory.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from terratorch.models.pixel_wise_model import PixelWiseModel
1919
from terratorch.models.scalar_output_model import ScalarOutputModel
20+
from terratorch.models.smp_model_factory import get_smp_decoder
2021

2122
PIXEL_WISE_TASKS = ["segmentation", "regression"]
2223
SCALAR_TASKS = ["classification"]
@@ -95,7 +96,13 @@ def build_model(
9596
msg = f"Task {task} not supported. Please choose one of {SUPPORTED_TASKS}"
9697
raise NotImplementedError(msg)
9798

99+
# These params are used in case we need a SMP decoder
100+
# but should not be used for timm encoder
98101
backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_")
102+
smp_kwargs, kwargs = _extract_prefix_keys(kwargs, "smp_")
103+
aux_kwargs, kwargs = _extract_prefix_keys(kwargs, "aux_")
104+
output_stride = backbone_kwargs.pop('output_stride', None)
105+
out_channels = backbone_kwargs.pop('out_channels', None)
99106

100107
backbone: nn.Module = timm.create_model(
101108
backbone,
@@ -106,13 +113,16 @@ def build_model(
106113
features_only=True,
107114
**backbone_kwargs,
108115
)
109-
# allow decoder to be a module passed directly
110-
decoder_cls = _get_decoder(decoder)
111116

112117
decoder_kwargs, kwargs = _extract_prefix_keys(kwargs, "decoder_")
113-
118+
args = kwargs.copy()
114119
# TODO: remove this
115-
decoder: nn.Module = decoder_cls(backbone.feature_info.channels(), **decoder_kwargs)
120+
if decoder.startswith("smp_"):
121+
decoder: nn.Module = get_smp_decoder(decoder, backbone_kwargs, smp_kwargs, aux_kwargs, args, out_channels, in_channels, num_classes, output_stride)
122+
else:
123+
# allow decoder to be a module passed directly
124+
decoder_cls = _get_decoder(decoder)
125+
decoder: nn.Module = decoder_cls(backbone.feature_info.channels(), **decoder_kwargs)
116126
# decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs)
117127

118128
head_kwargs, kwargs = _extract_prefix_keys(kwargs, "head_")

0 commit comments

Comments
 (0)