13
13
NUM_CLASSES = 2
14
14
EXPECTED_SEGMENTATION_OUTPUT_SHAPE = (1 , NUM_CLASSES , 224 , 224 )
15
15
EXPECTED_REGRESSION_OUTPUT_SHAPE = (1 , 224 , 224 )
16
+ EXPECTED_CLASSIFICATION_OUTPUT_SHAPE = (1 , NUM_CLASSES )
16
17
17
18
18
19
@pytest .fixture (scope = "session" )
@@ -25,6 +26,37 @@ def model_input() -> torch.Tensor:
25
26
return torch .ones ((1 , NUM_CHANNELS , 224 , 224 ))
26
27
27
28
@pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
29
+ def test_create_classification_model (backbone , model_factory : PrithviModelFactory , model_input ):
30
+ model = model_factory .build_model (
31
+ "classification" ,
32
+ backbone = backbone ,
33
+ decoder = "IdentityDecoder" ,
34
+ in_channels = NUM_CHANNELS ,
35
+ bands = PRETRAINED_BANDS ,
36
+ pretrained = False ,
37
+ num_classes = NUM_CLASSES ,
38
+ )
39
+ model .eval ()
40
+
41
+ with torch .no_grad ():
42
+ assert model (model_input ).output .shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
43
+
44
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
45
+ def test_create_classification_model_no_in_channels (backbone , model_factory : PrithviModelFactory , model_input ):
46
+ model = model_factory .build_model (
47
+ "classification" ,
48
+ backbone = backbone ,
49
+ decoder = "IdentityDecoder" ,
50
+ bands = PRETRAINED_BANDS ,
51
+ pretrained = False ,
52
+ num_classes = NUM_CLASSES ,
53
+ )
54
+ model .eval ()
55
+
56
+ with torch .no_grad ():
57
+ assert model (model_input ).output .shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
58
+
59
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" ])
28
60
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
29
61
def test_create_segmentation_model (backbone , decoder , model_factory : PrithviModelFactory , model_input ):
30
62
model = model_factory .build_model (
@@ -41,8 +73,24 @@ def test_create_segmentation_model(backbone, decoder, model_factory: PrithviMode
41
73
with torch .no_grad ():
42
74
assert model (model_input ).output .shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
43
75
76
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" ])
77
+ @pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
78
+ def test_create_segmentation_model_no_in_channels (backbone , decoder , model_factory : PrithviModelFactory , model_input ):
79
+ model = model_factory .build_model (
80
+ "segmentation" ,
81
+ backbone = backbone ,
82
+ decoder = decoder ,
83
+ bands = PRETRAINED_BANDS ,
84
+ pretrained = False ,
85
+ num_classes = NUM_CLASSES ,
86
+ )
87
+ model .eval ()
44
88
45
- @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
89
+ with torch .no_grad ():
90
+ assert model (model_input ).output .shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
91
+
92
+
93
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" ])
46
94
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
47
95
def test_create_segmentation_model_with_aux_heads (backbone , decoder , model_factory : PrithviModelFactory , model_input ):
48
96
aux_heads_name = ["first_aux" , "second_aux" ]
@@ -67,7 +115,7 @@ def test_create_segmentation_model_with_aux_heads(backbone, decoder, model_facto
67
115
assert output .shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
68
116
69
117
70
- @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
118
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" ])
71
119
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
72
120
def test_create_regression_model (backbone , decoder , model_factory : PrithviModelFactory , model_input ):
73
121
model = model_factory .build_model (
@@ -83,8 +131,22 @@ def test_create_regression_model(backbone, decoder, model_factory: PrithviModelF
83
131
with torch .no_grad ():
84
132
assert model (model_input ).output .shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
85
133
134
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" ])
135
+ @pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
136
+ def test_create_regression_model_no_in_channels (backbone , decoder , model_factory : PrithviModelFactory , model_input ):
137
+ model = model_factory .build_model (
138
+ "regression" ,
139
+ backbone = backbone ,
140
+ decoder = decoder ,
141
+ bands = PRETRAINED_BANDS ,
142
+ pretrained = False ,
143
+ )
144
+ model .eval ()
145
+
146
+ with torch .no_grad ():
147
+ assert model (model_input ).output .shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
86
148
87
- @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
149
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" ])
88
150
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
89
151
def test_create_regression_model_with_aux_heads (backbone , decoder , model_factory : PrithviModelFactory , model_input ):
90
152
aux_heads_name = ["first_aux" , "second_aux" ]
@@ -108,14 +170,14 @@ def test_create_regression_model_with_aux_heads(backbone, decoder, model_factory
108
170
assert output .shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
109
171
110
172
111
- @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
173
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" ])
112
174
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
113
175
def test_create_model_with_extra_bands (backbone , decoder , model_factory : PrithviModelFactory ):
114
176
model = model_factory .build_model (
115
177
"segmentation" ,
116
178
backbone = backbone ,
117
179
decoder = decoder ,
118
- in_channels = NUM_CHANNELS ,
180
+ in_channels = NUM_CHANNELS + 1 ,
119
181
bands = [* PRETRAINED_BANDS , 7 ], # add an extra band
120
182
pretrained = False ,
121
183
num_classes = NUM_CLASSES ,
0 commit comments