7
7
import os
8
8
9
9
from terratorch .cli_tools import build_lightning_cli
10
-
10
+ """
11
11
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
12
12
def test_finetune_multiple_backbones(model_name):
13
13
@@ -22,5 +22,22 @@ def test_finetune_multiple_backbones(model_name):
22
22
# Running the terratorch CLI
23
23
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"]
24
24
_ = build_lightning_cli(command_list)
25
+ """
26
+
27
+ @pytest .mark .parametrize ("model_name" , ["prithvi_swin_B" , "prithvi_swin_L" , "prithvi_vit_100" , "prithvi_vit_300" ])
28
+ def test_finetune_multiple_backbones (model_name ):
29
+
30
+ model_instance = timm .create_model (model_name )
31
+ pretrained_bands = [0 , 1 , 2 , 3 , 4 , 5 ]
32
+ model_bands = [0 , 1 , 2 , 3 , 4 , 5 ]
33
+
34
+ state_dict = model_instance .state_dict ()
25
35
36
+ torch .save (state_dict , os .path .join ("tests/" , model_name + ".pt" ))
37
+
38
+ # Running the terratorch CLI
39
+ command_str = f"python terratorch/__main__.py fit -c tests/manufactured-finetune_{ model_name } .yaml"
40
+ command_out = subprocess .run (command_str , shell = True )
41
+
42
+ assert not command_out .returncode
26
43
0 commit comments