Skip to content

Commit 6430c3c

Browse files
committed
add test for default in_channels
Signed-off-by: Carlos Gomes <[email protected]>
1 parent 1e6e1c8 commit 6430c3c

File tree

1 file changed

+63
-1
lines changed

1 file changed

+63
-1
lines changed

tests/test_prithvi_model_factory.py

+63-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
NUM_CLASSES = 2
1414
EXPECTED_SEGMENTATION_OUTPUT_SHAPE = (1, NUM_CLASSES, 224, 224)
1515
EXPECTED_REGRESSION_OUTPUT_SHAPE = (1, 224, 224)
16+
EXPECTED_CLASSIFICATION_OUTPUT_SHAPE = (1, NUM_CLASSES)
1617

1718

1819
@pytest.fixture(scope="session")
@@ -24,6 +25,37 @@ def model_factory() -> PrithviModelFactory:
2425
def model_input() -> torch.Tensor:
2526
return torch.ones((1, NUM_CHANNELS, 224, 224))
2627

28+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
29+
def test_create_classification_model(backbone, model_factory: PrithviModelFactory, model_input):
30+
model = model_factory.build_model(
31+
"classification",
32+
backbone=backbone,
33+
decoder="IdentityDecoder",
34+
in_channels=NUM_CHANNELS,
35+
bands=PRETRAINED_BANDS,
36+
pretrained=False,
37+
num_classes=NUM_CLASSES,
38+
)
39+
model.eval()
40+
41+
with torch.no_grad():
42+
assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
43+
44+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
45+
def test_create_classification_model_no_in_channels(backbone, model_factory: PrithviModelFactory, model_input):
46+
model = model_factory.build_model(
47+
"classification",
48+
backbone=backbone,
49+
decoder="IdentityDecoder",
50+
bands=PRETRAINED_BANDS,
51+
pretrained=False,
52+
num_classes=NUM_CLASSES,
53+
)
54+
model.eval()
55+
56+
with torch.no_grad():
57+
assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
58+
2759
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
2860
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
2961
def test_create_segmentation_model(backbone, decoder, model_factory: PrithviModelFactory, model_input):
@@ -41,6 +73,22 @@ def test_create_segmentation_model(backbone, decoder, model_factory: PrithviMode
4173
with torch.no_grad():
4274
assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
4375

76+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
77+
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
78+
def test_create_segmentation_model_no_in_channels(backbone, decoder, model_factory: PrithviModelFactory, model_input):
79+
model = model_factory.build_model(
80+
"segmentation",
81+
backbone=backbone,
82+
decoder=decoder,
83+
bands=PRETRAINED_BANDS,
84+
pretrained=False,
85+
num_classes=NUM_CLASSES,
86+
)
87+
model.eval()
88+
89+
with torch.no_grad():
90+
assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
91+
4492

4593
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
4694
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
@@ -83,6 +131,20 @@ def test_create_regression_model(backbone, decoder, model_factory: PrithviModelF
83131
with torch.no_grad():
84132
assert model(model_input).output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
85133

134+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
135+
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
136+
def test_create_regression_model_no_in_channels(backbone, decoder, model_factory: PrithviModelFactory, model_input):
137+
model = model_factory.build_model(
138+
"regression",
139+
backbone=backbone,
140+
decoder=decoder,
141+
bands=PRETRAINED_BANDS,
142+
pretrained=False,
143+
)
144+
model.eval()
145+
146+
with torch.no_grad():
147+
assert model(model_input).output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
86148

87149
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
88150
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
@@ -115,7 +177,7 @@ def test_create_model_with_extra_bands(backbone, decoder, model_factory: Prithvi
115177
"segmentation",
116178
backbone=backbone,
117179
decoder=decoder,
118-
in_channels=NUM_CHANNELS,
180+
in_channels=NUM_CHANNELS + 1,
119181
bands=[*PRETRAINED_BANDS, 7], # add an extra band
120182
pretrained=False,
121183
num_classes=NUM_CLASSES,

0 commit comments

Comments
 (0)