Skip to content

Commit 3481674

Browse files
author
Carlos Gomes
committed
add tests for swin backbone
Signed-off-by: Carlos Gomes <[email protected]>
1 parent cefb8ef commit 3481674

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

tests/test_backbones.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,23 @@ def input_386():
2828
return torch.ones((1, NUM_CHANNELS, 386, 386))
2929

3030

31-
@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"]) #["prithvi_swin_90_us", "prithvi_vit_100", "prithvi_vit_300"])
31+
@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
3232
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
3333
def test_can_create_backbones_from_timm(model_name, test_input, request):
3434
backbone = timm.create_model(model_name, pretrained=False)
3535
input_tensor = request.getfixturevalue(test_input)
3636
backbone(input_tensor)
3737

3838

39-
@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"])
39+
@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
4040
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
4141
def test_can_create_backbones_from_timm_features_only(model_name, test_input, request):
4242
backbone = timm.create_model(model_name, pretrained=False, features_only=True)
4343
input_tensor = request.getfixturevalue(test_input)
4444
backbone(input_tensor)
4545

4646

47-
@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"])
47+
@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
4848
def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal):
4949
backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES)
5050
backbone(input_224_multitemporal)

tests/test_prithvi_tasks.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def model_input() -> torch.Tensor:
2020
return torch.ones((1, NUM_CHANNELS, 224, 224))
2121

2222

23-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
23+
@pytest.mark.parametrize("backbone",["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
2424
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
2525
@pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"])
2626
def test_create_segmentation_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
@@ -38,7 +38,7 @@ def test_create_segmentation_task(backbone, decoder, loss, model_factory: Prithv
3838
)
3939

4040

41-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
41+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
4242
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
4343
@pytest.mark.parametrize("loss", ["mae", "rmse", "huber"])
4444
def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
@@ -55,7 +55,7 @@ def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviM
5555
)
5656

5757

58-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
58+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
5959
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
6060
@pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"])
6161
def test_create_classification_task(backbone, decoder, loss, model_factory: PrithviModelFactory):

0 commit comments

Comments
 (0)