13
13
NUM_CLASSES = 2
14
14
EXPECTED_SEGMENTATION_OUTPUT_SHAPE = (1 , NUM_CLASSES , 224 , 224 )
15
15
EXPECTED_REGRESSION_OUTPUT_SHAPE = (1 , 224 , 224 )
16
+ EXPECTED_CLASSIFICATION_OUTPUT_SHAPE = (1 , NUM_CLASSES )
16
17
17
18
18
19
@pytest .fixture (scope = "session" )
@@ -24,6 +25,37 @@ def model_factory() -> PrithviModelFactory:
24
25
def model_input () -> torch .Tensor :
25
26
return torch .ones ((1 , NUM_CHANNELS , 224 , 224 ))
26
27
28
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
29
+ def test_create_classification_model (backbone , model_factory : PrithviModelFactory , model_input ):
30
+ model = model_factory .build_model (
31
+ "classification" ,
32
+ backbone = backbone ,
33
+ decoder = "IdentityDecoder" ,
34
+ in_channels = NUM_CHANNELS ,
35
+ bands = PRETRAINED_BANDS ,
36
+ pretrained = False ,
37
+ num_classes = NUM_CLASSES ,
38
+ )
39
+ model .eval ()
40
+
41
+ with torch .no_grad ():
42
+ assert model (model_input ).output .shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
43
+
44
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
45
+ def test_create_classification_model_no_in_channels (backbone , model_factory : PrithviModelFactory , model_input ):
46
+ model = model_factory .build_model (
47
+ "classification" ,
48
+ backbone = backbone ,
49
+ decoder = "IdentityDecoder" ,
50
+ bands = PRETRAINED_BANDS ,
51
+ pretrained = False ,
52
+ num_classes = NUM_CLASSES ,
53
+ )
54
+ model .eval ()
55
+
56
+ with torch .no_grad ():
57
+ assert model (model_input ).output .shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
58
+
27
59
@pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
28
60
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
29
61
def test_create_segmentation_model (backbone , decoder , model_factory : PrithviModelFactory , model_input ):
@@ -41,6 +73,22 @@ def test_create_segmentation_model(backbone, decoder, model_factory: PrithviMode
41
73
with torch .no_grad ():
42
74
assert model (model_input ).output .shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
43
75
76
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
77
+ @pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
78
+ def test_create_segmentation_model_no_in_channels (backbone , decoder , model_factory : PrithviModelFactory , model_input ):
79
+ model = model_factory .build_model (
80
+ "segmentation" ,
81
+ backbone = backbone ,
82
+ decoder = decoder ,
83
+ bands = PRETRAINED_BANDS ,
84
+ pretrained = False ,
85
+ num_classes = NUM_CLASSES ,
86
+ )
87
+ model .eval ()
88
+
89
+ with torch .no_grad ():
90
+ assert model (model_input ).output .shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
91
+
44
92
45
93
@pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
46
94
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
@@ -83,6 +131,20 @@ def test_create_regression_model(backbone, decoder, model_factory: PrithviModelF
83
131
with torch .no_grad ():
84
132
assert model (model_input ).output .shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
85
133
134
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
135
+ @pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
136
+ def test_create_regression_model_no_in_channels (backbone , decoder , model_factory : PrithviModelFactory , model_input ):
137
+ model = model_factory .build_model (
138
+ "regression" ,
139
+ backbone = backbone ,
140
+ decoder = decoder ,
141
+ bands = PRETRAINED_BANDS ,
142
+ pretrained = False ,
143
+ )
144
+ model .eval ()
145
+
146
+ with torch .no_grad ():
147
+ assert model (model_input ).output .shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
86
148
87
149
@pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
88
150
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
@@ -115,7 +177,7 @@ def test_create_model_with_extra_bands(backbone, decoder, model_factory: Prithvi
115
177
"segmentation" ,
116
178
backbone = backbone ,
117
179
decoder = decoder ,
118
- in_channels = NUM_CHANNELS ,
180
+ in_channels = NUM_CHANNELS + 1 ,
119
181
bands = [* PRETRAINED_BANDS , 7 ], # add an extra band
120
182
pretrained = False ,
121
183
num_classes = NUM_CLASSES ,
0 commit comments