Skip to content

Commit 0c1fad8

Browse files
Alternative tests
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent eb815d1 commit 0c1fad8

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

.github/workflows/test.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,4 @@ jobs:
3232
run: pip list
3333
- name: Test with pytest
3434
run: |
35-
export PYTHONPATH=.
3635
pytest -s tests

tests/test_finetune.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import os
88

99
from terratorch.cli_tools import build_lightning_cli
10-
10+
"""
1111
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
1212
def test_finetune_multiple_backbones(model_name):
1313
@@ -22,5 +22,22 @@ def test_finetune_multiple_backbones(model_name):
2222
# Running the terratorch CLI
2323
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"]
2424
_ = build_lightning_cli(command_list)
25+
"""
26+
27+
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
28+
def test_finetune_multiple_backbones(model_name):
29+
30+
model_instance = timm.create_model(model_name)
31+
pretrained_bands = [0, 1, 2, 3, 4, 5]
32+
model_bands = [0, 1, 2, 3, 4, 5]
33+
34+
state_dict = model_instance.state_dict()
2535

36+
torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
37+
38+
# Running the terratorch CLI
39+
command_str = f"python terratorch/__main__.py fit -c tests/manufactured-finetune_{model_name}.yaml"
40+
command_out = subprocess.run(command_str, shell=True)
41+
42+
assert not command_out.returncode
2643

0 commit comments

Comments
 (0)