@@ -20,7 +20,7 @@ def model_input() -> torch.Tensor:
20
20
return torch .ones ((1 , NUM_CHANNELS , 224 , 224 ))
21
21
22
22
23
- @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
23
+ @pytest .mark .parametrize ("backbone" ,["prithvi_vit_100" , "prithvi_vit_300" , "prithvi_swin_B " ])
24
24
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
25
25
@pytest .mark .parametrize ("loss" , ["ce" , "jaccard" , "focal" , "dice" ])
26
26
def test_create_segmentation_task (backbone , decoder , loss , model_factory : PrithviModelFactory ):
@@ -38,7 +38,7 @@ def test_create_segmentation_task(backbone, decoder, loss, model_factory: Prithv
38
38
)
39
39
40
40
41
- @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
41
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" , "prithvi_swin_B" ])
42
42
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
43
43
@pytest .mark .parametrize ("loss" , ["mae" , "rmse" , "huber" ])
44
44
def test_create_regression_task (backbone , decoder , loss , model_factory : PrithviModelFactory ):
@@ -55,7 +55,7 @@ def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviM
55
55
)
56
56
57
57
58
- @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
58
+ @pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" , "prithvi_swin_B" ])
59
59
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
60
60
@pytest .mark .parametrize ("loss" , ["ce" , "bce" , "jaccard" , "focal" ])
61
61
def test_create_classification_task (backbone , decoder , loss , model_factory : PrithviModelFactory ):
0 commit comments