@@ -22,7 +22,8 @@ def model_input() -> torch.Tensor:
22
22
23
23
@pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
24
24
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
25
- def test_create_segmentation_task (backbone , decoder , model_factory : PrithviModelFactory ):
25
+ @pytest .mark .parametrize ("loss" , ["ce" , "jaccard" , "focal" , "dice" ])
26
+ def test_create_segmentation_task (backbone , decoder , loss , model_factory : PrithviModelFactory ):
26
27
SemanticSegmentationTask (
27
28
{
28
29
"backbone" : backbone ,
@@ -33,12 +34,14 @@ def test_create_segmentation_task(backbone, decoder, model_factory: PrithviModel
33
34
"num_classes" : NUM_CLASSES ,
34
35
},
35
36
model_factory ,
37
+ loss = loss
36
38
)
37
39
38
40
39
41
@pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
40
42
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
41
- def test_create_regression_task (backbone , decoder , model_factory : PrithviModelFactory ):
43
+ @pytest .mark .parametrize ("loss" , ["mae" , "rmse" , "huber" ])
44
+ def test_create_regression_task (backbone , decoder , loss , model_factory : PrithviModelFactory ):
42
45
PixelwiseRegressionTask (
43
46
{
44
47
"backbone" : backbone ,
@@ -48,12 +51,14 @@ def test_create_regression_task(backbone, decoder, model_factory: PrithviModelFa
48
51
"pretrained" : False ,
49
52
},
50
53
model_factory ,
54
+ loss = loss
51
55
)
52
56
53
57
54
58
@pytest .mark .parametrize ("backbone" , ["prithvi_vit_100" , "prithvi_vit_300" ])
55
59
@pytest .mark .parametrize ("decoder" , ["FCNDecoder" , "UperNetDecoder" , "IdentityDecoder" ])
56
- def test_create_classification_task (backbone , decoder , model_factory : PrithviModelFactory ):
60
+ @pytest .mark .parametrize ("loss" , ["ce" , "bce" , "jaccard" , "focal" ])
61
+ def test_create_classification_task (backbone , decoder , loss , model_factory : PrithviModelFactory ):
57
62
ClassificationTask (
58
63
{
59
64
"backbone" : backbone ,
@@ -64,4 +69,5 @@ def test_create_classification_task(backbone, decoder, model_factory: PrithviMod
64
69
"num_classes" : NUM_CLASSES ,
65
70
},
66
71
model_factory ,
72
+ loss = loss
67
73
)
0 commit comments