6
6
import subprocess
7
7
import os
8
8
9
- from terratorch .models .backbones .prithvi_vit import checkpoint_filter_fn as checkpoint_filter_fn_vit
10
- from terratorch .models .backbones .prithvi_swin import checkpoint_filter_fn as checkpoint_filter_fn_swin
9
+ from terratorch .cli_tools import build_lightning_cli
11
10
12
11
@pytest .mark .parametrize ("model_name" , ["prithvi_swin_B" , "prithvi_swin_L" , "prithvi_vit_100" , "prithvi_vit_300" ])
13
12
def test_finetune_multiple_backbones (model_name ):
@@ -21,11 +20,24 @@ def test_finetune_multiple_backbones(model_name):
21
20
torch .save (state_dict , os .path .join ("tests/" , model_name + ".pt" ))
22
21
23
22
# Running the terratorch CLI
24
- command_str = f"terratorch fit -c tests/manufactured-finetune_{ model_name } .yaml"
23
+ command_list = ["fit" , "-c" , f"tests/manufactured-finetune_{ model_name } .yaml" ]
24
+ _ = build_lightning_cli (command_list )
25
25
26
- command_out = subprocess .run (command_str , shell = True )
26
+ """
27
+ @pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
28
+ def test_finetune_multiple_backbones(model_name):
27
29
28
- assert not command_out .returncode
30
+ model_instance = timm.create_model(model_name)
31
+ pretrained_bands = [0, 1, 2, 3, 4, 5]
32
+ model_bands = [0, 1, 2, 3, 4, 5]
29
33
34
+ state_dict = model_instance.state_dict()
35
+
36
+ torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
30
37
38
+ # Running the terratorch CLI
39
+ command_str = f"python terratorch/__main__.py fit -c tests/manufactured-finetune_{model_name}.yaml"
40
+ command_out = subprocess.run(command_str, shell=True)
31
41
42
+ assert not command_out.returncode
43
+ """
0 commit comments