Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve/bands definition #54

Merged
merged 22 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
63d4733
Bands could be define by intervals
Joao-L-S-Almeida Jul 18, 2024
8731019
Constructing the bands using the definition by interval
Joao-L-S-Almeida Jul 19, 2024
1dd6650
Extending the supported formats for bands to include list[int]
Joao-L-S-Almeida Jul 19, 2024
a0ce8aa
Extending the supported formats for bands to include list[int]
Joao-L-S-Almeida Jul 19, 2024
295128f
Testing the definition by interval using a dedicated yaml file
Joao-L-S-Almeida Jul 19, 2024
64dcf5d
Special case for bands_list=:None
Joao-L-S-Almeida Jul 19, 2024
e48965f
Basic support to use simple strings to name the bands
Joao-L-S-Almeida Jul 22, 2024
5dba481
Strings are allowed to define bands
Joao-L-S-Almeida Jul 22, 2024
174f2f1
Testing to use strings to define a model
Joao-L-S-Almeida Jul 22, 2024
989bf80
Exception for None inputs
Joao-L-S-Almeida Jul 22, 2024
831c662
Support for str
Joao-L-S-Almeida Jul 22, 2024
de533dd
YAML file for testing string as bands
Joao-L-S-Almeida Jul 22, 2024
7a94a82
This test is no longer required
Joao-L-S-Almeida Jul 22, 2024
0a84c42
Band intervals should be tuples with two entries
Joao-L-S-Almeida Jul 23, 2024
2888252
More compact way to check if the bands are defined by interval
Joao-L-S-Almeida Jul 23, 2024
303bfe8
This warning is not necessary
Joao-L-S-Almeida Jul 23, 2024
25b10c6
Reformatting using black
Joao-L-S-Almeida Jul 23, 2024
10551a0
Minor improvements
Joao-L-S-Almeida Jul 23, 2024
aedca32
Missing imports
Joao-L-S-Almeida Jul 23, 2024
5708e0a
More tests to check if the bands ar properly returned
Joao-L-S-Almeida Jul 24, 2024
3faf7d3
accept mixed band specifications
CarlosGomes98 Jul 26, 2024
1a254e4
improve docstring comments
CarlosGomes98 Jul 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions terratorch/datamodules/generic_pixel_wise_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def __init__(
test_split: Path | None = None,
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
dataset_bands: list[HLSBands | int] | None = None,
predict_dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
constant_scale: float = 1,
rgb_indices: list[int] | None = None,
train_transform: A.Compose | None | list[A.BasicTransform] = None,
Expand Down Expand Up @@ -330,9 +330,9 @@ def __init__(
test_split: Path | None = None,
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
dataset_bands: list[HLSBands | int] | None = None,
predict_dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
constant_scale: float = 1,
rgb_indices: list[int] | None = None,
train_transform: A.Compose | None | list[A.BasicTransform] = None,
Expand Down
61 changes: 43 additions & 18 deletions terratorch/datasets/generic_pixel_wise_dataset.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
# Copyright contributors to the Terratorch project

"""Module containing generic dataset classes
"""
"""Module containing generic dataset classes"""

import glob
import os
from abc import ABC
from functools import partial
from pathlib import Path
from typing import Any

import albumentations as A
import matplotlib as mpl
import numpy as np
import rioxarray
import torch
import xarray as xr
from albumentations.pytorch import ToTensorV2
from einops import rearrange
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
Expand All @@ -43,8 +39,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[int] | None = None,
dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
no_data_replace: float | None = None,
Expand Down Expand Up @@ -73,8 +69,8 @@ def __init__(
that must be present in file names to be included (as in mmsegmentation), or exact
matches (e.g. eurosat). Defaults to True.
rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2].
dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be refered to by output_bands. Defaults to None.
output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
constant_scale (float): Factor to multiply image values by. Defaults to 1.
transform (Albumentations.Compose | None): Albumentations transform to be applied.
Should end with ToTensorV2(). If used through the generic_data_module,
Expand All @@ -88,6 +84,7 @@ def __init__(
expected 0. Defaults to False.
"""
super().__init__()

self.split_file = split

label_data_root = label_data_root if label_data_root is not None else data_root
Expand Down Expand Up @@ -120,19 +117,24 @@ def __init__(
)
self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

self.dataset_bands = dataset_bands
self.output_bands = output_bands
self.dataset_bands = self._generate_bands_intervals(dataset_bands)
self.output_bands = self._generate_bands_intervals(output_bands)

if self.output_bands and not self.dataset_bands:
msg = "If output bands provided, dataset_bands must also be provided"
return Exception(msg) # noqa: PLE0101

# There is a special condition if the bands are defined as simple strings.
if self.output_bands:
if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
msg = "Output bands must be a subset of dataset bands"
raise Exception(msg)

self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

else:
self.filter_indices = None

# If no transform is given, apply only to transform to torch tensor
self.transform = transform if transform else lambda **batch: to_tensor(batch)
# self.transform = transform if transform else ToTensorV2()
Expand All @@ -141,7 +143,7 @@ def __len__(self) -> int:
return len(self.image_files)

def __getitem__(self, index: int) -> dict[str, Any]:
image = self._load_file(self.image_files[index], nan_replace = self.no_data_replace).to_numpy()
image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace).to_numpy()
# to channels last
if self.expand_temporal_dimension:
image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands))
Expand All @@ -151,9 +153,12 @@ def __getitem__(self, index: int) -> dict[str, Any]:
image = image[..., self.filter_indices]
output = {
"image": image.astype(np.float32) * self.constant_scale,
"mask": self._load_file(self.segmentation_mask_files[index], nan_replace = self.no_label_replace).to_numpy()[0],
"mask": self._load_file(self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[
0
],
"filename": self.image_files[index],
}

if self.reduce_zero_label:
output["mask"] -= 1
if self.transform:
Expand All @@ -166,6 +171,26 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr
data = data.fillna(nan_replace)
return data

def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | tuple[int]] | None = None):
if bands_intervals is None:
return None
bands = []
for element in bands_intervals:
# if its an interval
if isinstance(element, tuple):
if len(element) != 2: # noqa: PLR2004
msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive"
raise Exception(msg)
expanded_element = list(range(element[0], element[1] + 1))
bands.extend(expanded_element)
else:
bands.append(element)
# check the expansion didnt result in duplicate elements
if len(set(bands)) != len(bands):
msg = "Duplicate indices detected. Indices must be unique."
raise Exception(msg)
return bands


class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
"""GenericNonGeoSegmentationDataset"""
Expand All @@ -181,8 +206,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[str] | None = None,
dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
class_names: list[str] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
Expand Down Expand Up @@ -348,8 +373,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[int] | None = None,
dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
no_data_replace: float | None = None,
Expand Down
16 changes: 10 additions & 6 deletions terratorch/datasets/generic_scalar_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[int] | None = None,
dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
no_data_replace: float = 0,
Expand All @@ -64,8 +64,8 @@ def __init__(
that must be present in file names to be included (as in mmsegmentation), or exact
matches (e.g. eurosat). Defaults to True.
rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2].
dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be refered to by output_bands. Defaults to None.
output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
constant_scale (float): Factor to multiply image values by. Defaults to 1.
transform (Albumentations.Compose | None): Albumentations transform to be applied.
Should end with ToTensorV2(). If used through the generic_data_module,
Expand Down Expand Up @@ -110,17 +110,21 @@ def is_valid_file(x):

self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

self.dataset_bands = dataset_bands
self.output_bands = output_bands
self.dataset_bands = self._generate_bands_intervals(dataset_bands)
self.output_bands = self._generate_bands_intervals(output_bands)

if self.output_bands and not self.dataset_bands:
msg = "If output bands provided, dataset_bands must also be provided"
return Exception(msg) # noqa: PLE0101

# There is a special condition if the bands are defined as simple strings.
if self.output_bands:
if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
msg = "Output bands must be a subset of dataset bands"
raise Exception(msg)

self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

else:
self.filter_indices = None
# If no transform is given, apply only to transform to torch tensor
Expand Down
136 changes: 136 additions & 0 deletions tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
# precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: tests/
name: all_ecos_random
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: EarlyStopping
init_args:
monitor: val/loss
patience: 100
max_epochs: 5
check_val_every_n_epoch: 1
log_every_n_steps: 20
enable_checkpointing: true
default_root_dir: tests/
data:
class_path: GenericNonGeoPixelwiseRegressionDataModule
init_args:
batch_size: 2
num_workers: 4
train_transform:
- class_path: albumentations.HorizontalFlip
init_args:
p: 0.5
- class_path: albumentations.Rotate
init_args:
limit: 30
border_mode: 0 # cv2.BORDER_CONSTANT
value: 0
# mask_value: 1
p: 0.5
- class_path: ToTensorV2
dataset_bands:
- [0, 11]
output_bands:
- [1, 3]
- [4, 6]
rgb_indices:
- 2
- 1
- 0
train_data_root: tests/
train_label_data_root: tests/
val_data_root: tests/
val_label_data_root: tests/
test_data_root: tests/
test_label_data_root: tests/
img_grep: "regression*input*.tif"
label_grep: "regression*label*.tif"
means:
- 547.36707
- 898.5121
- 1020.9082
- 2665.5352
- 2340.584
- 1610.1407
stds:
- 411.4701
- 558.54065
- 815.94025
- 812.4403
- 1113.7145
- 1067.641
no_label_replace: -1
no_data_replace: 0

model:
class_path: terratorch.tasks.PixelwiseRegressionTask
init_args:
model_args:
decoder: UperNetDecoder
pretrained: true
backbone: prithvi_swin_B
backbone_pretrained_cfg_overlay:
file: tests/prithvi_swin_B.pt
backbone_drop_path_rate: 0.3
# backbone_window_size: 8
decoder_channels: 256
in_channels: 6
bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
head_dropout: 0.5708022831486758
head_final_act: torch.nn.ReLU
head_learned_upscale_layers: 2
loss: rmse
#aux_heads:
# - name: aux_head
# decoder: IdentityDecoder
# decoder_args:
# decoder_out_index: 2
# head_dropout: 0,5
# head_channel_list:
# - 64
# head_final_act: torch.nn.ReLU
#aux_loss:
# aux_head: 0.4
ignore_index: -1
freeze_backbone: true
freeze_decoder: false
model_factory: PrithviModelFactory

# uncomment this block for tiled inference
# tiled_inference_parameters:
# h_crop: 224
# h_stride: 192
# w_crop: 224
# w_stride: 192
# average_patches: true
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.00013524680528283027
weight_decay: 0.047782217873995426
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss

Loading
Loading