Skip to content

Commit 0401bf8

Browse files
Testing Swin backbones
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 0342fa1 commit 0401bf8

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

tests/test_backbones.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
import timm
33
import torch
4+
import importlib
45

56
import terratorch # noqa: F401
67

@@ -49,7 +50,16 @@ def test_vit_models_accept_multitemporal(model_name, input_224_multitemporal):
4950
backbone = timm.create_model(model_name, pretrained=False, num_frames=NUM_FRAMES)
5051
backbone(input_224_multitemporal)
5152

53+
# Swin IS NOT on HuggingFace
54+
@pytest.mark.parametrize("model_name", ["prithvi_swin_L", "prithvi_swin_B"])
55+
def test_swin_instantiation(model_name):
56+
base_module = "terratorch.models.backbones.prithvi_swin"
57+
module = importlib.import_module(base_module)
58+
model_class = getattr(module, model_name)
5259

60+
model = model_class(pretrained=False, pretrained_bands=[0,1,2,3,4,5,6,7,8,9],
61+
bands=[1,2,3,4,5,6])
62+
5363
#def test_swin_models_accept_non_divisible_by_patch_size(input_386):
5464
# backbone = timm.create_model("prithvi_swin_90_us", pretrained=False, num_frames=NUM_FRAMES)
5565
# backbone(input_386)

0 commit comments

Comments
 (0)