Skip to content

Commit 8b9111c

Browse files
The backbone instantiation can be simplified with timm
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 52e15f8 commit 8b9111c

File tree

1 file changed

+3
-17
lines changed

1 file changed

+3
-17
lines changed

tests/test_finetune.py

+3-17
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,13 @@
1212
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
1313
def test_finetune_multiple_backbones(model_name):
1414

15-
# Instantiating and creating a manufactured
16-
# checkpoint just to test the finetuning pipeline
17-
if "vit" in model_name :
18-
module_str = "terratorch.models.backbones.prithvi_vit"
19-
ckpt_filter = checkpoint_filter_fn_vit
20-
elif "swin" in model_name:
21-
module_str = "terratorch.models.backbones.prithvi_swin"
22-
ckpt_filter = checkpoint_filter_fn_swin
23-
24-
module_instance = importlib.import_module(module_str)
25-
26-
model_template = getattr(module_instance, model_name)
27-
28-
model_instance = model_template()
29-
15+
model_instance = timm.create_model(model_name)
3016
pretrained_bands = [0, 1, 2, 3, 4, 5]
3117
model_bands = [0, 1, 2, 3, 4, 5]
3218

33-
filtered_state_dict = ckpt_filter(model_instance.state_dict(), model_instance, pretrained_bands, model_bands)
19+
state_dict = model_instance.state_dict()
3420

35-
torch.save(filtered_state_dict, os.path.join("tests/", model_name + ".pt"))
21+
torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
3622

3723
# Running the terratorch CLI
3824
command_str = f"terratorch fit -c tests/manufactured-finetune_{model_name}.yaml"

0 commit comments

Comments
 (0)