Skip to content

Commit 7edbbf8

Browse files
committed
adds SMPModelFactory tests and SMPModelFactory to model.md
Signed-off-by: Pedro Henrique Conrado <[email protected]>
1 parent e544105 commit 7edbbf8

File tree

4 files changed

+91
-8
lines changed

4 files changed

+91
-8
lines changed

docs/models.md

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ By passing a list of bands being used to the constructor, we automatically filte
6161

6262
## Model Factory
6363
### :::terratorch.models.PrithviModelFactory
64+
### :::terratorch.models.SMPModelFactory
6465

6566
# Adding new model types
6667
Adding new model types is as simple as creating a new factory that produces models. See for instance the example below for a potential `SMPModelFactory`

terratorch/models/prithvi_model_factory.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ def build_model(
9999
msg = f"Task {task} not supported. Please choose one of {SUPPORTED_TASKS}"
100100
raise NotImplementedError(msg)
101101

102+
backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_")
102103
# These params are used in case we need a SMP decoder
103104
# but should not be used for timm encoder
104-
backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_")
105105
output_stride = backbone_kwargs.pop("output_stride", None)
106106
out_channels = backbone_kwargs.pop("out_channels", None)
107107

@@ -179,6 +179,7 @@ class SMPDecoderForPrithviWrapper(nn.Module):
179179
forward(x: List[torch.Tensor]) -> torch.Tensor:
180180
Forward pass for embeddings with specified indices.
181181
"""
182+
182183
def __init__(self, decoder, num_channels, in_index=-1) -> None:
183184
"""
184185
Args:

terratorch/models/smp_model_factory.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,9 @@ def build_model(
9696
as auxiliary decoders or modified encoders.
9797
9898
Attributes:
99-
task (str): Specifies the task for which the model is being built. Supported tasks include
100-
"segmentation" and "regression".
101-
backbone (str, nn.Module): Specifies the backbone model to be used. If a string, it should be
102-
recognized by the model factory and be able to be parsed appropriately.
99+
task (str): Specifies the task for which the model is being built. Supported tasks are
100+
"segmentation".
101+
backbone (str): Specifies the backbone model to be used.
103102
decoder (str): Specifies the decoder to be used for constructing the
104103
segmentation model.
105104
bands (list[terratorch.datasets.HLSBands | int]): A list specifying the bands that the model
@@ -116,12 +115,15 @@ def build_model(
116115
117116
Raises:
118117
ValueError: If the specified decoder is not supported by SMP.
119-
Exception: If the specified task is not "segmentation" or "regression".
118+
Exception: If the specified task is not "segmentation"
120119
121120
Returns:
122121
nn.Module: A model instance wrapped in SMPModelWrapper configured according to the specified
123122
parameters and tasks.
124123
"""
124+
if task != "segmentation":
125+
msg = f"SMP models can only perform segmentatio, but got task {task}"
126+
raise Exception(msg)
125127

126128
bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands]
127129
if in_channels is None:
@@ -197,8 +199,7 @@ def register_custom_encoder(encoder, params, pretrained):
197199
}
198200

199201

200-
# Gets class either from string or from Module reference.
201-
def make_smp_encoder(encoder = None):
202+
def make_smp_encoder(encoder=None):
202203
if isinstance(encoder, str):
203204
base_class = _get_class_from_string(encoder)
204205
else:

tests/test_smp_model_factory.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright contributors to the Terratorch project
2+
3+
import os
4+
5+
import pytest
6+
import torch
7+
8+
from terratorch.models import SMPModelFactory
9+
from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS
10+
11+
# from terratorch.models.backbones.prithvi_vit import default_cfgs as vit_default_cfgs
12+
13+
NUM_CHANNELS = 6
14+
NUM_CLASSES = 2
15+
EXPECTED_SEGMENTATION_OUTPUT_SHAPE = (1, NUM_CLASSES, 224, 224)
16+
EXPECTED_REGRESSION_OUTPUT_SHAPE = (1, 224, 224)
17+
EXPECTED_CLASSIFICATION_OUTPUT_SHAPE = (1, NUM_CLASSES)
18+
19+
20+
@pytest.fixture(scope="session")
21+
def model_factory() -> SMPModelFactory:
22+
return SMPModelFactory()
23+
24+
25+
@pytest.fixture(scope="session")
26+
def model_input() -> torch.Tensor:
27+
return torch.ones((1, NUM_CHANNELS, 224, 224))
28+
29+
30+
@pytest.mark.parametrize("backbone", ["timm-regnetx_002"])
31+
@pytest.mark.parametrize("model", ["Unet", "DeepLabV3"])
32+
def test_create_segmentation_model(backbone, model, model_factory: SMPModelFactory, model_input):
33+
model = model_factory.build_model(
34+
"segmentation",
35+
backbone=backbone,
36+
model=model,
37+
in_channels=NUM_CHANNELS,
38+
bands=PRETRAINED_BANDS,
39+
pretrained=False,
40+
num_classes=NUM_CLASSES,
41+
)
42+
model.eval()
43+
44+
with torch.no_grad():
45+
assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
46+
47+
48+
@pytest.mark.parametrize("backbone", ["timm-regnetx_002"])
49+
@pytest.mark.parametrize("model", ["Unet", "DeepLabV3"])
50+
def test_create_segmentation_model_no_in_channels(backbone, model, model_factory: SMPModelFactory, model_input):
51+
model = model_factory.build_model(
52+
"segmentation",
53+
backbone=backbone,
54+
model=model,
55+
bands=PRETRAINED_BANDS,
56+
pretrained=False,
57+
num_classes=NUM_CLASSES,
58+
)
59+
model.eval()
60+
61+
with torch.no_grad():
62+
assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
63+
64+
65+
@pytest.mark.parametrize("backbone", ["timm-regnetx_002"])
66+
@pytest.mark.parametrize("model", ["Unet", "DeepLabV3"])
67+
def test_create_model_with_extra_bands(backbone, model, model_factory: SMPModelFactory):
68+
model = model_factory.build_model(
69+
"segmentation",
70+
backbone=backbone,
71+
model=model,
72+
in_channels=NUM_CHANNELS + 1,
73+
bands=[*PRETRAINED_BANDS, 7], # add an extra band
74+
pretrained=False,
75+
num_classes=NUM_CLASSES,
76+
)
77+
model.eval()
78+
model_input = torch.ones((1, NUM_CHANNELS + 1, 224, 224))
79+
with torch.no_grad():
80+
assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE

0 commit comments

Comments
 (0)