Skip to content

Commit 2866275

Browse files
Merge pull request #13 from IBM/fix/model_factory
Fix/model factory
2 parents c7d9fec + 46e7b9a commit 2866275

File tree

2 files changed

+68
-6
lines changed

2 files changed

+68
-6
lines changed

src/terratorch/models/prithvi_model_factory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def build_model(
3232
backbone: str | nn.Module,
3333
decoder: str | nn.Module,
3434
bands: list[HLSBands | int],
35-
in_channels: int = int | None, # this should be removed, can be derived from bands. But it is a breaking change
35+
in_channels: int | None = None, # this should be removed, can be derived from bands. But it is a breaking change
3636
num_classes: int | None = None,
3737
pretrained: bool = True, # noqa: FBT001, FBT002
3838
num_frames: int = 1,

tests/test_prithvi_model_factory.py

+67-5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
NUM_CLASSES = 2
1414
EXPECTED_SEGMENTATION_OUTPUT_SHAPE = (1, NUM_CLASSES, 224, 224)
1515
EXPECTED_REGRESSION_OUTPUT_SHAPE = (1, 224, 224)
16+
EXPECTED_CLASSIFICATION_OUTPUT_SHAPE = (1, NUM_CLASSES)
1617

1718

1819
@pytest.fixture(scope="session")
@@ -25,6 +26,37 @@ def model_input() -> torch.Tensor:
2526
return torch.ones((1, NUM_CHANNELS, 224, 224))
2627

2728
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
29+
def test_create_classification_model(backbone, model_factory: PrithviModelFactory, model_input):
30+
model = model_factory.build_model(
31+
"classification",
32+
backbone=backbone,
33+
decoder="IdentityDecoder",
34+
in_channels=NUM_CHANNELS,
35+
bands=PRETRAINED_BANDS,
36+
pretrained=False,
37+
num_classes=NUM_CLASSES,
38+
)
39+
model.eval()
40+
41+
with torch.no_grad():
42+
assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
43+
44+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
45+
def test_create_classification_model_no_in_channels(backbone, model_factory: PrithviModelFactory, model_input):
46+
model = model_factory.build_model(
47+
"classification",
48+
backbone=backbone,
49+
decoder="IdentityDecoder",
50+
bands=PRETRAINED_BANDS,
51+
pretrained=False,
52+
num_classes=NUM_CLASSES,
53+
)
54+
model.eval()
55+
56+
with torch.no_grad():
57+
assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
58+
59+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
2860
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
2961
def test_create_segmentation_model(backbone, decoder, model_factory: PrithviModelFactory, model_input):
3062
model = model_factory.build_model(
@@ -41,8 +73,24 @@ def test_create_segmentation_model(backbone, decoder, model_factory: PrithviMode
4173
with torch.no_grad():
4274
assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
4375

76+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
77+
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
78+
def test_create_segmentation_model_no_in_channels(backbone, decoder, model_factory: PrithviModelFactory, model_input):
79+
model = model_factory.build_model(
80+
"segmentation",
81+
backbone=backbone,
82+
decoder=decoder,
83+
bands=PRETRAINED_BANDS,
84+
pretrained=False,
85+
num_classes=NUM_CLASSES,
86+
)
87+
model.eval()
4488

45-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
89+
with torch.no_grad():
90+
assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
91+
92+
93+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
4694
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
4795
def test_create_segmentation_model_with_aux_heads(backbone, decoder, model_factory: PrithviModelFactory, model_input):
4896
aux_heads_name = ["first_aux", "second_aux"]
@@ -67,7 +115,7 @@ def test_create_segmentation_model_with_aux_heads(backbone, decoder, model_facto
67115
assert output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
68116

69117

70-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
118+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
71119
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
72120
def test_create_regression_model(backbone, decoder, model_factory: PrithviModelFactory, model_input):
73121
model = model_factory.build_model(
@@ -83,8 +131,22 @@ def test_create_regression_model(backbone, decoder, model_factory: PrithviModelF
83131
with torch.no_grad():
84132
assert model(model_input).output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
85133

134+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
135+
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
136+
def test_create_regression_model_no_in_channels(backbone, decoder, model_factory: PrithviModelFactory, model_input):
137+
model = model_factory.build_model(
138+
"regression",
139+
backbone=backbone,
140+
decoder=decoder,
141+
bands=PRETRAINED_BANDS,
142+
pretrained=False,
143+
)
144+
model.eval()
145+
146+
with torch.no_grad():
147+
assert model(model_input).output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
86148

87-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
149+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
88150
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
89151
def test_create_regression_model_with_aux_heads(backbone, decoder, model_factory: PrithviModelFactory, model_input):
90152
aux_heads_name = ["first_aux", "second_aux"]
@@ -108,14 +170,14 @@ def test_create_regression_model_with_aux_heads(backbone, decoder, model_factory
108170
assert output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
109171

110172

111-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
173+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
112174
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
113175
def test_create_model_with_extra_bands(backbone, decoder, model_factory: PrithviModelFactory):
114176
model = model_factory.build_model(
115177
"segmentation",
116178
backbone=backbone,
117179
decoder=decoder,
118-
in_channels=NUM_CHANNELS,
180+
in_channels=NUM_CHANNELS + 1,
119181
bands=[*PRETRAINED_BANDS, 7], # add an extra band
120182
pretrained=False,
121183
num_classes=NUM_CLASSES,

0 commit comments

Comments
 (0)