Skip to content

Commit 308d540

Browse files
Merge pull request #33 from IBM/finetune_tests
Using the Lightning CLI interface to run tests
2 parents 1ebeeb5 + 11974fe commit 308d540

File tree

4 files changed

+44
-6
lines changed

4 files changed

+44
-6
lines changed

.github/dependabot.yaml

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
2+
# mostly from https://github.com/microsoft/torchgeo/blob/main/.github/dependabot.yml
3+
version: 2
4+
updates:
5+
- package-ecosystem: "github-actions"
6+
directory: "/"
7+
schedule:
8+
interval: "weekly"
9+
- package-ecosystem: "pip"
10+
directory: "/"
11+
schedule:
12+
interval: "daily"
13+
groups:
14+
# torchvision pins torch, must update in unison
15+
torch:
16+
patterns:
17+
- "torch"
18+
- "torchvision"
19+
ignore:
20+
# setuptools releases new versions almost daily
21+
- dependency-name: "setuptools"
22+
update-types: ["version-update:semver-patch"]
23+
# segmentation-models-pytorch pins timm, must update in unison
24+
- dependency-name: "timm"

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ dependencies = [
3838
"h5py>=3.10.0",
3939
"geobench>=1.0.0",
4040
"mlflow>=2.12.1",
41-
"lightning<=2.2.5"
41+
# broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977
42+
"lightning>=2, <=2.2.5"
4243
]
4344

4445
[project.optional-dependencies]

requirements/required.txt

+1
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ lightly==1.4.25
1010
h5py==3.10.0
1111
geobench==1.0.0
1212
mlflow==2.12.1
13+
lightning==2.2.5

tests/test_finetune.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
import subprocess
77
import os
88

9-
from terratorch.models.backbones.prithvi_vit import checkpoint_filter_fn as checkpoint_filter_fn_vit
10-
from terratorch.models.backbones.prithvi_swin import checkpoint_filter_fn as checkpoint_filter_fn_swin
9+
from terratorch.cli_tools import build_lightning_cli
1110

1211
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
1312
def test_finetune_multiple_backbones(model_name):
@@ -21,11 +20,24 @@ def test_finetune_multiple_backbones(model_name):
2120
torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
2221

2322
# Running the terratorch CLI
24-
command_str = f"terratorch fit -c tests/manufactured-finetune_{model_name}.yaml"
23+
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"]
24+
_ = build_lightning_cli(command_list)
2525

26-
command_out = subprocess.run(command_str, shell=True)
26+
"""
27+
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
28+
def test_finetune_multiple_backbones(model_name):
2729
28-
assert not command_out.returncode
30+
model_instance = timm.create_model(model_name)
31+
pretrained_bands = [0, 1, 2, 3, 4, 5]
32+
model_bands = [0, 1, 2, 3, 4, 5]
2933
34+
state_dict = model_instance.state_dict()
35+
36+
torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
3037
38+
# Running the terratorch CLI
39+
command_str = f"python terratorch/__main__.py fit -c tests/manufactured-finetune_{model_name}.yaml"
40+
command_out = subprocess.run(command_str, shell=True)
3141
42+
assert not command_out.returncode
43+
"""

0 commit comments

Comments
 (0)