diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index 2c9a66d0..4bf67b80 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -20,7 +20,7 @@ from torch import Tensor from torchgeo.datasets import NonGeoDataset -from terratorch.datasets.utils import HLSBands, filter_valid_files, to_tensor +from terratorch.datasets.utils import HLSBands, default_transform, filter_valid_files class GenericPixelWiseDataset(NonGeoDataset, ABC): @@ -136,7 +136,7 @@ def __init__( self.filter_indices = None # If no transform is given, apply only to transform to torch tensor - self.transform = transform if transform else lambda **batch: to_tensor(batch) + self.transform = transform if transform else default_transform # self.transform = transform if transform else ToTensorV2() def __len__(self) -> int: @@ -186,10 +186,6 @@ def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | bands.extend(expanded_element) else: bands.append(element) - # check the expansion didnt result in duplicate elements - if len(set(bands)) != len(bands): - msg = "Duplicate indices detected. Indices must be unique." - raise Exception(msg) return bands diff --git a/terratorch/datasets/generic_scalar_label_dataset.py b/terratorch/datasets/generic_scalar_label_dataset.py index bd82e3b0..f3255fc2 100644 --- a/terratorch/datasets/generic_scalar_label_dataset.py +++ b/terratorch/datasets/generic_scalar_label_dataset.py @@ -26,7 +26,7 @@ from torchgeo.datasets.utils import rasterio_loader from torchvision.datasets import ImageFolder -from terratorch.datasets.utils import HLSBands, filter_valid_files, to_tensor +from terratorch.datasets.utils import HLSBands, default_transform, filter_valid_files class GenericScalarLabelDataset(NonGeoDataset, ImageFolder, ABC): @@ -128,7 +128,7 @@ def is_valid_file(x): else: self.filter_indices = None # If no transform is given, apply only to transform to torch tensor - self.transforms = transform if transform else lambda **batch: to_tensor(batch) + self.transforms = transform if transform else default_transform # self.transform = transform if transform else ToTensorV2() def __len__(self) -> int: diff --git a/terratorch/datasets/utils.py b/terratorch/datasets/utils.py index 0dee447e..0d4065a8 100644 --- a/terratorch/datasets/utils.py +++ b/terratorch/datasets/utils.py @@ -34,6 +34,9 @@ def try_convert_to_hls_bands_enum(cls, x: Any): except ValueError: return x +def default_transform(**batch): + return to_tensor(batch) + def filter_valid_files( files, valid_files: Iterator[str] | None = None, ignore_extensions: bool = False, allow_substring: bool = True diff --git a/tests/manufactured-finetune_prithvi_swin_B.yaml b/tests/manufactured-finetune_prithvi_swin_B.yaml index b7498d1b..03cd7ea7 100644 --- a/tests/manufactured-finetune_prithvi_swin_B.yaml +++ b/tests/manufactured-finetune_prithvi_swin_B.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 diff --git a/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml b/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml index 8697cd63..61685471 100644 --- a/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml +++ b/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 diff --git a/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml b/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml index 91a72a3c..edce7f91 100644 --- a/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml +++ b/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 diff --git a/tests/manufactured-finetune_prithvi_swin_B_string.yaml b/tests/manufactured-finetune_prithvi_swin_B_string.yaml index a7aa84c2..cac4d4f7 100644 --- a/tests/manufactured-finetune_prithvi_swin_B_string.yaml +++ b/tests/manufactured-finetune_prithvi_swin_B_string.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 diff --git a/tests/manufactured-finetune_prithvi_swin_L.yaml b/tests/manufactured-finetune_prithvi_swin_L.yaml index 8619ffbf..3908d55e 100644 --- a/tests/manufactured-finetune_prithvi_swin_L.yaml +++ b/tests/manufactured-finetune_prithvi_swin_L.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 diff --git a/tests/manufactured-finetune_prithvi_vit_100.yaml b/tests/manufactured-finetune_prithvi_vit_100.yaml index 8ee70a9c..7ebf0559 100644 --- a/tests/manufactured-finetune_prithvi_vit_100.yaml +++ b/tests/manufactured-finetune_prithvi_vit_100.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 @@ -111,7 +111,6 @@ model: - NIR_NARROW - SWIR_1 - SWIR_2 - num_frames: 1 head_dropout: 0.5708022831486758 head_final_act: torch.nn.ReLU head_learned_upscale_layers: 2 diff --git a/tests/manufactured-finetune_prithvi_vit_300.yaml b/tests/manufactured-finetune_prithvi_vit_300.yaml index 1994f0d1..cac7291d 100644 --- a/tests/manufactured-finetune_prithvi_vit_300.yaml +++ b/tests/manufactured-finetune_prithvi_vit_300.yaml @@ -1,7 +1,7 @@ # lightning.pytorch==2.1.1 seed_everything: 42 trainer: - accelerator: auto + accelerator: cpu strategy: auto devices: auto num_nodes: 1 @@ -111,7 +111,6 @@ model: - NIR_NARROW - SWIR_1 - SWIR_2 - num_frames: 1 head_dropout: 0.5708022831486758 head_final_act: torch.nn.ReLU head_learned_upscale_layers: 2 diff --git a/tests/test_finetune.py b/tests/test_finetune.py index 8592e639..ff279a4e 100644 --- a/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -1,64 +1,39 @@ +import os +import shutil + import pytest import timm import torch -import importlib -import terratorch -import subprocess -import os from terratorch.cli_tools import build_lightning_cli -@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"]) -def test_finetune_multiple_backbones(model_name): +@pytest.fixture(autouse=True) +def setup_and_cleanup(model_name): model_instance = timm.create_model(model_name) - pretrained_bands = [0, 1, 2, 3, 4, 5] - model_bands = [0, 1, 2, 3, 4, 5] state_dict = model_instance.state_dict() - torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) + torch.save(state_dict, os.path.join("tests", model_name + ".pt")) + + yield # everything after this runs after each test - # Running the terratorch CLI + os.remove(os.path.join("tests", model_name + ".pt")) + shutil.rmtree(os.path.join("tests", "all_ecos_random")) + +@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"]) +def test_finetune_multiple_backbones(model_name): command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"] _ = build_lightning_cli(command_list) + @pytest.mark.parametrize("model_name", ["prithvi_swin_B"]) def test_finetune_bands_intervals(model_name): - - model_instance = timm.create_model(model_name) - - state_dict = model_instance.state_dict() - - torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) - - # Running the terratorch CLI command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_band_interval.yaml"] _ = build_lightning_cli(command_list) -@pytest.mark.parametrize("model_name", ["prithvi_swin_B"]) -def test_finetune_bands_str(model_name): - model_instance = timm.create_model(model_name) - - state_dict = model_instance.state_dict() - - torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) - - # Running the terratorch CLI - command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"] - _ = build_lightning_cli(command_list) - @pytest.mark.parametrize("model_name", ["prithvi_swin_B"]) def test_finetune_bands_str(model_name): - - model_instance = timm.create_model(model_name) - - state_dict = model_instance.state_dict() - - torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) - - # Running the terratorch CLI - command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_metrics_from_file.yaml"] + command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"] _ = build_lightning_cli(command_list) -