Skip to content

Commit 710f79e

Browse files
Basic test for segmentation fine-tuning tasks
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 10f82fa commit 710f79e

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

tests/test_finetune.py

+13
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,16 @@ def test_finetune_bands_str(model_name):
6262
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_metrics_from_file.yaml"]
6363
_ = build_lightning_cli(command_list)
6464

65+
@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
66+
def test_finetune_segmentation(model_name):
67+
68+
model_instance = timm.create_model(model_name)
69+
70+
state_dict = model_instance.state_dict()
71+
72+
torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
73+
74+
# Running the terratorch CLI
75+
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_segmentation.yaml"]
76+
_ = build_lightning_cli(command_list)
77+

0 commit comments

Comments
 (0)