Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/mac local lambda #71

Merged
merged 5 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions terratorch/datasets/generic_pixel_wise_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch import Tensor
from torchgeo.datasets import NonGeoDataset

from terratorch.datasets.utils import HLSBands, filter_valid_files, to_tensor
from terratorch.datasets.utils import HLSBands, default_transform, filter_valid_files


class GenericPixelWiseDataset(NonGeoDataset, ABC):
Expand Down Expand Up @@ -136,7 +136,7 @@ def __init__(
self.filter_indices = None

# If no transform is given, apply only the transform to torch tensor
self.transform = transform if transform else lambda **batch: to_tensor(batch)
self.transform = transform if transform else default_transform
# self.transform = transform if transform else ToTensorV2()

def __len__(self) -> int:
Expand Down Expand Up @@ -186,10 +186,6 @@ def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands |
bands.extend(expanded_element)
else:
bands.append(element)
# check the expansion didn't result in duplicate elements
if len(set(bands)) != len(bands):
msg = "Duplicate indices detected. Indices must be unique."
raise Exception(msg)
return bands


Expand Down
4 changes: 2 additions & 2 deletions terratorch/datasets/generic_scalar_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from torchgeo.datasets.utils import rasterio_loader
from torchvision.datasets import ImageFolder

from terratorch.datasets.utils import HLSBands, filter_valid_files, to_tensor
from terratorch.datasets.utils import HLSBands, default_transform, filter_valid_files


class GenericScalarLabelDataset(NonGeoDataset, ImageFolder, ABC):
Expand Down Expand Up @@ -128,7 +128,7 @@ def is_valid_file(x):
else:
self.filter_indices = None
# If no transform is given, apply only the transform to torch tensor
self.transforms = transform if transform else lambda **batch: to_tensor(batch)
self.transforms = transform if transform else default_transform
# self.transform = transform if transform else ToTensorV2()

def __len__(self) -> int:
Expand Down
3 changes: 3 additions & 0 deletions terratorch/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def try_convert_to_hls_bands_enum(cls, x: Any):
except ValueError:
return x

def default_transform(**batch):
return to_tensor(batch)


def filter_valid_files(
files, valid_files: Iterator[str] | None = None, ignore_extensions: bool = False, allow_substring: bool = True
Expand Down
2 changes: 1 addition & 1 deletion tests/manufactured-finetune_prithvi_swin_B.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down
2 changes: 1 addition & 1 deletion tests/manufactured-finetune_prithvi_swin_B_string.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down
2 changes: 1 addition & 1 deletion tests/manufactured-finetune_prithvi_swin_L.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down
3 changes: 1 addition & 2 deletions tests/manufactured-finetune_prithvi_vit_100.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down Expand Up @@ -111,7 +111,6 @@ model:
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
head_dropout: 0.5708022831486758
head_final_act: torch.nn.ReLU
head_learned_upscale_layers: 2
Expand Down
3 changes: 1 addition & 2 deletions tests/manufactured-finetune_prithvi_vit_300.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
Expand Down Expand Up @@ -111,7 +111,6 @@ model:
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
head_dropout: 0.5708022831486758
head_final_act: torch.nn.ReLU
head_learned_upscale_layers: 2
Expand Down
55 changes: 15 additions & 40 deletions tests/test_finetune.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,39 @@
import os
import shutil

import pytest
import timm
import torch
import importlib
import terratorch
import subprocess
import os

from terratorch.cli_tools import build_lightning_cli

@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
def test_finetune_multiple_backbones(model_name):

@pytest.fixture(autouse=True)
def setup_and_cleanup(model_name):
model_instance = timm.create_model(model_name)
pretrained_bands = [0, 1, 2, 3, 4, 5]
model_bands = [0, 1, 2, 3, 4, 5]

state_dict = model_instance.state_dict()

torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
torch.save(state_dict, os.path.join("tests", model_name + ".pt"))

yield # everything after this runs after each test

# Running the terratorch CLI
os.remove(os.path.join("tests", model_name + ".pt"))
shutil.rmtree(os.path.join("tests", "all_ecos_random"))

@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
def test_finetune_multiple_backbones(model_name):
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"]
_ = build_lightning_cli(command_list)


@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_bands_intervals(model_name):

model_instance = timm.create_model(model_name)

state_dict = model_instance.state_dict()

torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))

# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_band_interval.yaml"]
_ = build_lightning_cli(command_list)

@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_bands_str(model_name):

model_instance = timm.create_model(model_name)

state_dict = model_instance.state_dict()

torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))

# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"]
_ = build_lightning_cli(command_list)

@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_bands_str(model_name):

model_instance = timm.create_model(model_name)

state_dict = model_instance.state_dict()

torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))

# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_metrics_from_file.yaml"]
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"]
_ = build_lightning_cli(command_list)

Loading