|
12 | 12 | @pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
|
13 | 13 | def test_finetune_multiple_backbones(model_name):
|
14 | 14 |
|
15 |
| - # Instantiating and creating a manufactured |
16 |
| - # checkpoint just to test the finetuning pipeline |
17 |
| - if "vit" in model_name : |
18 |
| - module_str = "terratorch.models.backbones.prithvi_vit" |
19 |
| - ckpt_filter = checkpoint_filter_fn_vit |
20 |
| - elif "swin" in model_name: |
21 |
| - module_str = "terratorch.models.backbones.prithvi_swin" |
22 |
| - ckpt_filter = checkpoint_filter_fn_swin |
23 |
| - |
24 |
| - module_instance = importlib.import_module(module_str) |
25 |
| - |
26 |
| - model_template = getattr(module_instance, model_name) |
27 |
| - |
28 |
| - model_instance = model_template() |
29 |
| - |
| 15 | + model_instance = timm.create_model(model_name) |
30 | 16 | pretrained_bands = [0, 1, 2, 3, 4, 5]
|
31 | 17 | model_bands = [0, 1, 2, 3, 4, 5]
|
32 | 18 |
|
33 |
| - filtered_state_dict = ckpt_filter(model_instance.state_dict(), model_instance, pretrained_bands, model_bands) |
| 19 | + state_dict = model_instance.state_dict() |
34 | 20 |
|
35 |
| - torch.save(filtered_state_dict, os.path.join("tests/", model_name + ".pt")) |
| 21 | + torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) |
36 | 22 |
|
37 | 23 | # Running the terratorch CLI
|
38 | 24 | command_str = f"terratorch fit -c tests/manufactured-finetune_{model_name}.yaml"
|
|
0 commit comments