diff --git a/docs/models.md b/docs/models.md index 33b27008..1f5627a6 100644 --- a/docs/models.md +++ b/docs/models.md @@ -61,6 +61,7 @@ By passing a list of bands being used to the constructor, we automatically filte ## Model Factory ### :::terratorch.models.PrithviModelFactory +### :::terratorch.models.SMPModelFactory # Adding new model types Adding new model types is as simple as creating a new factory that produces models. See for instance the example below for a potential `SMPModelFactory` diff --git a/examples/confs/smp_model_factory.yaml b/examples/confs/smp_model_factory.yaml new file mode 100644 index 00000000..ea8d9cb1 --- /dev/null +++ b/examples/confs/smp_model_factory.yaml @@ -0,0 +1,93 @@ +benchmark_suffix: smp_test +experiment_name: smp_test +backbone: + # backbone: resnet18 + # backbone_args: + # pretrained: False + # output_stride: 2 + # smp_decoder_channels: 512 + # smp_encoder_depth: 5 + + backbone: swin3d.swin3d_backbone.Swin3dBackbone + backbone_args: + pretrained: False + output_stride: 2 + out_channels: + - 192 + - 384 + - 768 + - 768 + smp_decoder_channels: 768 + smp_encoder_depth: 5 + + +tasks: + - name: cashew + type: segmentation + loss: ce + model_factory: SMPModelFactory + bands: + - RED + - GREEN + - BLUE + num_classes: 7 + max_epochs: 60 + direction: max + datamodule: + class_path: terratorch.datamodules.MBeninSmallHolderCashewsNonGeoDataModule + init_args: + batch_size: 16 + num_workers: 4 + train_transform: + - class_path: albumentations.Resize + init_args: + always_apply: True + height: 224 + width: 224 + - class_path: ToTensorV2 + test_transform: + - class_path: albumentations.Resize + init_args: + always_apply: True + height: 224 + width: 224 + - class_path: ToTensorV2 + val_transform: + - class_path: albumentations.Resize + init_args: + height: 224 + width: 224 + - class_path: ToTensorV2 + data_root: "/dccstor/geofm-finetuning/geobench/segmentation_v1.0" + bands: + - "RED" + - "GREEN" + - "BLUE" + decoder: Unet + decoder_args: + channels: 128 + metric: val/Multiclass Jaccard Index + +n_trials: 16 +save_models: False +storage_uri: /dccstor/geofm-finetuning/pedrohc/smp_test +optimization_space: + model: + - DeepLabV3 + lr: + min: 6e-5 + max: 1e-3 + type: real + log: true + batch_size: + - 8 + - 16 + - 32 + decoder_channels: + - 32 + - 64 + - 128 + head_dropout: + min: 0.2 + max: 0.8 + type: real diff --git a/pyproject.toml b/pyproject.toml index d18c6bef..d04e5901 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,9 @@ dependencies = [ "geobench>=1.0.0", "mlflow>=2.12.1", # broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977 - "lightning>=2, <=2.2.5" + "lightning>=2, <=2.2.5", + # see issue #64 + "albumentations<=1.4.10" ] [project.optional-dependencies] diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index 16c4a31c..434f7488 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -3,11 +3,11 @@ """ This module contains generic data modules for instantiation at runtime. 
""" - +import os from collections.abc import Callable, Iterable from pathlib import Path from typing import Any - +import numpy as np import albumentations as A import kornia.augmentation as K import torch @@ -17,7 +17,7 @@ from torchgeo.transforms import AugmentationSequential from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset, GenericNonGeoSegmentationDataset, HLSBands - +from terratorch.io.file import load_from_file_or_attribute def wrap_in_compose_is_list(transform_list): # set check shapes to false because of the multitemporal case @@ -79,8 +79,8 @@ def __init__( test_data_root: Path, img_grep: str, label_grep: str, - means: list[float], - stds: list[float], + means: list[float] | str, + stds: list[float] | str, num_classes: int, predict_data_root: Path | None = None, train_label_data_root: Path | None = None, @@ -91,9 +91,9 @@ def __init__( test_split: Path | None = None, ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, - dataset_bands: list[HLSBands | int] | None = None, - predict_dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, train_transform: A.Compose | None | list[A.BasicTransform] = None, @@ -198,6 +198,9 @@ def __init__( # K.Normalize(means, stds), # data_keys=["image"], # ) + means = load_from_file_or_attribute(means) + stds = load_from_file_or_attribute(stds) + self.aug = Normalize(means, stds) # self.aug = Normalize(means, stds) @@ -317,8 +320,8 @@ def __init__( train_data_root: Path, val_data_root: Path, test_data_root: Path, - means: list[float], - stds: list[float], + means: list[float] | str, + stds: list[float] | str, predict_data_root: Path | None = None, img_grep: str | None = "*", label_grep: str | None = "*", @@ -330,9 +333,9 @@ def __init__( test_split: Path | None = None, ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, - dataset_bands: list[HLSBands | int] | None = None, - predict_dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, train_transform: A.Compose | None | list[A.BasicTransform] = None, @@ -430,6 +433,9 @@ def __init__( # K.Normalize(means, stds), # data_keys=["image"], # ) + means = load_from_file_or_attribute(means) + stds = load_from_file_or_attribute(stds) + self.aug = Normalize(means, stds) self.no_data_replace = no_data_replace self.no_label_replace = no_label_replace diff --git a/terratorch/datamodules/generic_scalar_label_data_module.py b/terratorch/datamodules/generic_scalar_label_data_module.py index 5cd7470a..71f75533 100644 --- a/terratorch/datamodules/generic_scalar_label_data_module.py +++ b/terratorch/datamodules/generic_scalar_label_data_module.py @@ -22,12 +22,12 @@ HLSBands, ) +from terratorch.io.file import load_from_file_or_attribute def wrap_in_compose_is_list(transform_list): # set check shapes to false because of the multitemporal case 
    return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list

-
 class Normalize(Callable):
     def __init__(self, means, stds):
         super().__init__()
@@ -68,8 +68,8 @@ def __init__(
         train_data_root: Path,
         val_data_root: Path,
         test_data_root: Path,
-        means: list[float],
-        stds: list[float],
+        means: list[float] | str,
+        stds: list[float] | str,
         num_classes: int,
         predict_data_root: Path | None = None,
         train_split: Path | None = None,
@@ -166,6 +166,10 @@ def __init__(
         #     K.Normalize(means, stds),
         #     data_keys=["image"],
         # )
+
+        means = load_from_file_or_attribute(means)
+        stds = load_from_file_or_attribute(stds)
+
         self.aug = Normalize(means, stds)
         # self.aug = Normalize(means, stds)

diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py
index 083a2241..2c9a66d0 100644
--- a/terratorch/datasets/generic_pixel_wise_dataset.py
+++ b/terratorch/datasets/generic_pixel_wise_dataset.py
@@ -1,11 +1,10 @@
 # Copyright contributors to the Terratorch project

-"""Module containing generic dataset classes
-"""
+"""Module containing generic dataset classes"""
+
 import glob
 import os
 from abc import ABC
-from functools import partial
 from pathlib import Path
 from typing import Any

@@ -13,11 +12,8 @@
 import matplotlib as mpl
 import numpy as np
 import rioxarray
-import torch
 import xarray as xr
-from albumentations.pytorch import ToTensorV2
 from einops import rearrange
-from matplotlib import cm
 from matplotlib import pyplot as plt
 from matplotlib.figure import Figure
 from matplotlib.patches import Rectangle
@@ -43,8 +39,8 @@ def __init__(
         ignore_split_file_extensions: bool = True,
         allow_substring_split_file: bool = True,
         rgb_indices: list[int] | None = None,
-        dataset_bands: list[HLSBands | int] | None = None,
-        output_bands: list[HLSBands | int] | None = None,
+        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
+        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
         constant_scale: float = 1,
         transform: A.Compose | None = None,
         no_data_replace: float | None = None,
@@ -73,8 +69,8 @@ def __init__(
                 that must be present in file names to be included (as in mmsegmentation), or exact
                 matches (e.g. eurosat). Defaults to True.
             rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2].
-            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
-            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
+            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset.
+                This parameter names the input channels (bands) using HLSBands, ints, int intervals
+                (inclusive tuples), or strings, so that they can then be referred to by output_bands.
+                Defaults to None.
+            output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output
+                by the dataset, as named by dataset_bands.
             constant_scale (float): Factor to multiply image values by. Defaults to 1.
             transform (Albumentations.Compose | None): Albumentations transform to be applied.
                 Should end with ToTensorV2(). If used through the generic_data_module,
@@ -88,6 +84,7 @@ def __init__(
                 expected 0. Defaults to False.
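            Example:
                Illustrative only; the paths and band lists below are placeholders:

                    ds = GenericNonGeoSegmentationDataset(
                        "data/",
                        2,  # num_classes
                        dataset_bands=[(0, 11)],        # a tuple interval expands to indices 0..11
                        output_bands=[(1, 3), (4, 6)],  # selects channel indices 1..6
                    )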
""" super().__init__() + self.split_file = split label_data_root = label_data_root if label_data_root is not None else data_root @@ -120,19 +117,24 @@ def __init__( ) self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices - self.dataset_bands = dataset_bands - self.output_bands = output_bands + self.dataset_bands = self._generate_bands_intervals(dataset_bands) + self.output_bands = self._generate_bands_intervals(output_bands) + if self.output_bands and not self.dataset_bands: msg = "If output bands provided, dataset_bands must also be provided" return Exception(msg) # noqa: PLE0101 + # There is a special condition if the bands are defined as simple strings. if self.output_bands: if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands): msg = "Output bands must be a subset of dataset bands" raise Exception(msg) + self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands] + else: self.filter_indices = None + # If no transform is given, apply only to transform to torch tensor self.transform = transform if transform else lambda **batch: to_tensor(batch) # self.transform = transform if transform else ToTensorV2() @@ -141,7 +143,7 @@ def __len__(self) -> int: return len(self.image_files) def __getitem__(self, index: int) -> dict[str, Any]: - image = self._load_file(self.image_files[index], nan_replace = self.no_data_replace).to_numpy() + image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace).to_numpy() # to channels last if self.expand_temporal_dimension: image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands)) @@ -151,13 +153,17 @@ def __getitem__(self, index: int) -> dict[str, Any]: image = image[..., self.filter_indices] output = { "image": image.astype(np.float32) * self.constant_scale, - "mask": self._load_file(self.segmentation_mask_files[index], nan_replace = self.no_label_replace).to_numpy()[0], - "filename": self.image_files[index], + "mask": self._load_file(self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[ + 0 + ] } + if self.reduce_zero_label: output["mask"] -= 1 if self.transform: output = self.transform(**output) + output["filename"] = self.image_files[index] + return output def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray: @@ -166,6 +172,26 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr data = data.fillna(nan_replace) return data + def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | tuple[int]] | None = None): + if bands_intervals is None: + return None + bands = [] + for element in bands_intervals: + # if its an interval + if isinstance(element, tuple): + if len(element) != 2: # noqa: PLR2004 + msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive" + raise Exception(msg) + expanded_element = list(range(element[0], element[1] + 1)) + bands.extend(expanded_element) + else: + bands.append(element) + # check the expansion didnt result in duplicate elements + if len(set(bands)) != len(bands): + msg = "Duplicate indices detected. Indices must be unique." 
+            raise Exception(msg)
+        return bands
+

 class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
     """GenericNonGeoSegmentationDataset"""

     def __init__(
         self,
@@ -181,8 +207,8 @@ def __init__(
         ignore_split_file_extensions: bool = True,
         allow_substring_split_file: bool = True,
         rgb_indices: list[str] | None = None,
-        dataset_bands: list[HLSBands | int] | None = None,
-        output_bands: list[HLSBands | int] | None = None,
+        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
+        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
         class_names: list[str] | None = None,
         constant_scale: float = 1,
         transform: A.Compose | None = None,
@@ -348,8 +374,8 @@ def __init__(
         ignore_split_file_extensions: bool = True,
         allow_substring_split_file: bool = True,
         rgb_indices: list[int] | None = None,
-        dataset_bands: list[HLSBands | int] | None = None,
-        output_bands: list[HLSBands | int] | None = None,
+        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
+        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
         constant_scale: float = 1,
         transform: A.Compose | None = None,
         no_data_replace: float | None = None,
diff --git a/terratorch/datasets/generic_scalar_label_dataset.py b/terratorch/datasets/generic_scalar_label_dataset.py
index 85b16a75..bd82e3b0 100644
--- a/terratorch/datasets/generic_scalar_label_dataset.py
+++ b/terratorch/datasets/generic_scalar_label_dataset.py
@@ -42,8 +42,8 @@ def __init__(
         ignore_split_file_extensions: bool = True,
         allow_substring_split_file: bool = True,
         rgb_indices: list[int] | None = None,
-        dataset_bands: list[HLSBands | int] | None = None,
-        output_bands: list[HLSBands | int] | None = None,
+        dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
+        output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
         constant_scale: float = 1,
         transform: A.Compose | None = None,
         no_data_replace: float = 0,
@@ -64,8 +64,8 @@ def __init__(
                 that must be present in file names to be included (as in mmsegmentation), or exact
                 matches (e.g. eurosat). Defaults to True.
             rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2].
-            dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
-            output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
+            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset.
+                This parameter names the input channels (bands) using HLSBands, ints, int intervals
+                (inclusive tuples), or strings, so that they can then be referred to by output_bands.
+                Defaults to None.
+            output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output
+                by the dataset, as named by dataset_bands.
             constant_scale (float): Factor to multiply image values by. Defaults to 1.
             transform (Albumentations.Compose | None): Albumentations transform to be applied.
                 Should end with ToTensorV2(). If used through the generic_data_module,
@@ -110,17 +110,21 @@ def is_valid_file(x):

        self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

-        self.dataset_bands = dataset_bands
-        self.output_bands = output_bands
+        self.dataset_bands = self._generate_bands_intervals(dataset_bands)
+        self.output_bands = self._generate_bands_intervals(output_bands)
+
        if self.output_bands and not self.dataset_bands:
            msg = "If output bands provided, dataset_bands must also be provided"
            return Exception(msg)  # noqa: PLE0101

+        # There is a special condition if the bands are defined as simple strings.
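+        # Illustrative example: with dataset_bands=["B1", "B2", "B3", "B4"] and
+        # output_bands=["B2", "B4"], the subset check below passes and
+        # filter_indices becomes [1, 3], later used to select those channels.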
        if self.output_bands:
            if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
                msg = "Output bands must be a subset of dataset bands"
                raise Exception(msg)
+            self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]
+
        else:
            self.filter_indices = None

        # If no transform is given, apply only to transform to torch tensor
@@ -139,13 +143,12 @@ def __getitem__(self, index: int) -> dict[str, Any]:

        output = {
            "image": image.astype(np.float32) * self.constant_scale,
-            "label": label,
-            "filename": self.samples[index][
-                0
-            ],  # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
+            "label": label,  # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
        }
        if self.transforms:
            output = self.transforms(**output)
+        output["filename"] = self.image_files[index]
+
        return output

    def _load_file(self, path) -> xr.DataArray:
diff --git a/terratorch/io/file.py b/terratorch/io/file.py
index bb8ef9fc..6cab0acd 100644
--- a/terratorch/io/file.py
+++ b/terratorch/io/file.py
@@ -1,6 +1,7 @@
 import os
 import importlib
 from torch import nn
+import numpy as np

 def open_generic_torch_model(model: type | str = None,
                              model_kwargs: dict = None,
@@ -51,3 +52,21 @@ def load_torch_weights(model:nn.Module=None, save_dir: str = None, name: str = N
     )

     return model
+
+def load_from_file_or_attribute(value: list[float] | str):
+
+    if isinstance(value, list):
+        return value
+    elif isinstance(value, str):  # It can be the path for a file
+        if os.path.isfile(value):
+            try:
+                content = np.genfromtxt(value).tolist()
+            except Exception as err:
+                msg = f"File must be a plain-text table of numbers, but received {value}"
+                raise Exception(msg) from err
+        else:
+            raise Exception(f"The input {value} does not exist or is not a file.")
+
+        return content
diff --git a/terratorch/models/decoders/upernet_decoder.py b/terratorch/models/decoders/upernet_decoder.py
index 90d2973c..5db0f188 100644
--- a/terratorch/models/decoders/upernet_decoder.py
+++ b/terratorch/models/decoders/upernet_decoder.py
@@ -1,5 +1,3 @@
-# Copyright contributors to the Terratorch project
-
 import torch
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor, nn
@@ -7,8 +5,6 @@
 """
 Adapted from https://github.com/yassouali/pytorch-segmentation/blob/master/models/upernet.py
 """
-
-
 class ConvModule(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, padding=0, inplace=False) -> None:  # noqa: FBT002
         super().__init__()
@@ -19,103 +15,6 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0, inplace=Fa
     def forward(self, x):
         return self.act(self.norm(self.conv(x)))

-
-# class PSPModule(nn.Module):
-#     # In the original inmplementation they use precise RoI pooling
-#     # Instead of using adaptative average pooling
-#     def __init__(self, in_channels: int, bin_sizes: list[int] | None = None):
-#         super().__init__()
-#         if bin_sizes is None:
-#             bin_sizes = [1, 2, 3, 6]
-#         out_channels = in_channels // len(bin_sizes)
-#         self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s) for b_s in bin_sizes])
-#         self.bottleneck = nn.Sequential(
-#             nn.Conv2d(
-#                 in_channels + (out_channels * len(bin_sizes)),
-#                 in_channels,
-#                 kernel_size=3,
-#                 padding=1,
-#                 bias=False,
-#             ),
-#             nn.BatchNorm2d(in_channels),
-#             nn.ReLU(inplace=True),
-#             nn.Dropout2d(0.1),
-#         )
-
-#     def _make_stages(self, in_channels, out_channels, bin_sz):
-#         prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
-#         conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
-#         bn = nn.BatchNorm2d(out_channels)
-#         relu = nn.ReLU(inplace=True)
-#         return nn.Sequential(prior, conv, bn, relu)

-#     def forward(self, features):
-#         h, w = features.size()[2], features.size()[3]
-#         pyramids = [features]
-#         pyramids.extend(
-#             [F.interpolate(stage(features), size=(h, w), mode="bilinear", align_corners=True) for stage in self.stages]
-#         )
-#         output = self.bottleneck(torch.cat(pyramids, dim=1))
-#         return output

-
-# def up_and_add(x, y):
-#     return F.interpolate(x, size=(y.size(2), y.size(3)), mode="bilinear", align_corners=True) + y

-
-# class FPNFuse(nn.Module):
-#     def __init__(self, feature_channels=None, fpn_out=256):
-#         super().__init__()
-#         if feature_channels is None:
-#             feature_channels = [256, 512, 1024, 2048]
-#         if not feature_channels[0] == fpn_out:
-#             msg = f"First index of feature channel ({feature_channels[0]}) did not match fpn_out ({fpn_out})"
-#             raise Exception(msg)
-#         self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1) for ft_size in feature_channels[1:]])
-#         self.smooth_conv = nn.ModuleList(
-#             [nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)] * (len(feature_channels) - 1)
-#         )
-#         self.conv_fusion = nn.Sequential(
-#             nn.Conv2d(
-#                 len(feature_channels) * fpn_out,
-#                 fpn_out,
-#                 kernel_size=3,
-#                 padding=1,
-#                 bias=False,
-#             ),
-#             nn.BatchNorm2d(fpn_out),
-#             nn.ReLU(inplace=True),
-#         )

-#     def forward(self, features):
-#         features[1:] = [conv1x1(feature) for feature, conv1x1 in zip(features[1:], self.conv1x1, strict=False)]
-#         p = [up_and_add(features[i], features[i - 1]) for i in reversed(range(1, len(features)))]
-#         p = [smooth_conv(x) for smooth_conv, x in zip(self.smooth_conv, p, strict=False)]
-#         p = list(reversed(p))
-#         p.append(features[-1])  # P = [P1, P2, P3, P4]
-#         h, w = p[0].size(2), p[0].size(3)
-#         p[1:] = [F.interpolate(feature, size=(h, w), mode="bilinear", align_corners=True) for feature in p[1:]]

-#         x = self.conv_fusion(torch.cat(p, dim=1))
-#         return x

-
-# class UperNetDecoder(nn.Module):
-#     def __init__(self, embed_dim: list[int]) -> None:
-#         super().__init__()
-#         self.embed_dim = embed_dim
-#         self.output_embed_dim = embed_dim[0]
-#         self.PPN = PSPModule(embed_dim[-1])
-#         self.FPN = FPNFuse(embed_dim, fpn_out=self.output_embed_dim)

-#     def forward(self, x: Tensor):
-#         x = [f.clone() for f in x]
-#         x[-1] = self.PPN(x[-1])
-#         x = self.FPN(x)

-#         return x

-
 # Adapted from MMSegmentation
 class UperNetDecoder(nn.Module):
     """UperNetDecoder. Adapted from MMSegmentation."""
@@ -126,6 +25,7 @@ def __init__(
         self,
         embed_dim: list[int],
         pool_scales: tuple[int] = (1, 2, 3, 6),
         channels: int = 256,
         align_corners: bool = True,  # noqa: FBT001, FBT002
+        scale_modules: bool = False
     ):
         """Constructor
@@ -134,10 +34,29 @@
         Args:
             embed_dim (list[int]): Input embedding dimension for each input.
             pool_scales (tuple[int], optional): Pooling scales used in Pooling Pyramid
                 Module applied on the last feature. Default: (1, 2, 3, 6).
             channels (int, optional): Channels used in the decoder. Defaults to 256.
-            align_corners (bool, optional): Whter to align corners in rescaling. Defaults to True.
+            align_corners (bool, optional): Whether to align corners in rescaling. Defaults to True.
+            scale_modules (bool, optional): Whether to apply scale modules to the inputs. Needed for plain ViT.
+                Defaults to False.
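+
+            Example:
+                Illustrative sketch only; dims and shapes below are placeholders for a
+                plain ViT backbone tapped at four layers:
+
+                    decoder = UperNetDecoder([1024] * 4, scale_modules=True)
+                    feats = decoder([torch.randn(2, 1024, 14, 14) for _ in range(4)])
+                    # fpn1 upsamples the first map 4x and fpn2 the second 2x, so with
+                    # the default channels=256 feats comes out as (2, 256, 56, 56).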
""" super().__init__() - self.embed_dim = embed_dim + self.scale_modules = scale_modules + if scale_modules: + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d(embed_dim[0], + embed_dim[0] // 2, 2, 2), + nn.BatchNorm2d(embed_dim[0] // 2), + nn.GELU(), + nn.ConvTranspose2d(embed_dim[0] // 2, + embed_dim[0] // 4, 2, 2)) + self.fpn2 = nn.Sequential( + nn.ConvTranspose2d(embed_dim[1], + embed_dim[1] // 2, 2, 2)) + self.fpn3 = nn.Sequential(nn.Identity()) + self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2)) + self.embed_dim = [embed_dim[0] // 4, embed_dim[1] // 2, embed_dim[2], embed_dim[3]] + else: + self.embed_dim = embed_dim + self.output_embed_dim = channels self.channels = channels self.align_corners = align_corners @@ -192,6 +111,14 @@ def forward(self, inputs): feats (Tensor): A tensor of shape (batch_size, self.channels, H, W) which is feature map for last layer of decoder head. """ + + if self.scale_modules: + scaled_inputs = [] + scaled_inputs.append(self.fpn1(inputs[0])) + scaled_inputs.append(self.fpn2(inputs[1])) + scaled_inputs.append(self.fpn3(inputs[2])) + scaled_inputs.append(self.fpn4(inputs[3])) + inputs = scaled_inputs # build laterals laterals = [lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)] laterals.append(self.psp_forward(inputs)) diff --git a/terratorch/models/prithvi_model_factory.py b/terratorch/models/prithvi_model_factory.py index 5347bca5..eb07796a 100644 --- a/terratorch/models/prithvi_model_factory.py +++ b/terratorch/models/prithvi_model_factory.py @@ -2,6 +2,7 @@ from collections.abc import Callable +import segmentation_models_pytorch as smp import timm import torch from torch import nn @@ -17,6 +18,7 @@ ) from terratorch.models.pixel_wise_model import PixelWiseModel from terratorch.models.scalar_output_model import ScalarOutputModel +from terratorch.models.smp_model_factory import make_smp_encoder, register_custom_encoder PIXEL_WISE_TASKS = ["segmentation", "regression"] SCALAR_TASKS = ["classification"] @@ -26,6 +28,7 @@ class DecoderNotFoundError(Exception): pass + @register_factory class PrithviModelFactory(ModelFactory): def build_model( @@ -34,7 +37,8 @@ def build_model( backbone: str | nn.Module, decoder: str | nn.Module, bands: list[HLSBands | int], - in_channels: int | None = None, # this should be removed, can be derived from bands. But it is a breaking change + in_channels: int + | None = None, # this should be removed, can be derived from bands. 
But it is a breaking change num_classes: int | None = None, pretrained: bool = True, # noqa: FBT001, FBT002 num_frames: int = 1, @@ -96,6 +100,10 @@ def build_model( raise NotImplementedError(msg) backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_") + # These params are used in case we need a SMP decoder + # but should not be used for timm encoder + output_stride = backbone_kwargs.pop("output_stride", None) + out_channels = backbone_kwargs.pop("out_channels", None) backbone: nn.Module = timm.create_model( backbone, @@ -106,14 +114,24 @@ def build_model( features_only=True, **backbone_kwargs, ) - # allow decoder to be a module passed directly - decoder_cls = _get_decoder(decoder) decoder_kwargs, kwargs = _extract_prefix_keys(kwargs, "decoder_") - # TODO: remove this - decoder: nn.Module = decoder_cls(backbone.feature_info.channels(), **decoder_kwargs) - # decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs) + if decoder.startswith("smp_"): + decoder: nn.Module = _get_smp_decoder( + decoder, + backbone_kwargs, + decoder_kwargs, + out_channels, + in_channels, + num_classes, + output_stride, + ) + else: + # allow decoder to be a module passed directly + decoder_cls = _get_decoder(decoder) + decoder: nn.Module = decoder_cls(backbone.feature_info.channels(), **decoder_kwargs) + # decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs) head_kwargs, kwargs = _extract_prefix_keys(kwargs, "head_") if num_classes: @@ -148,6 +166,46 @@ def build_model( ) +class SMPDecoderForPrithviWrapper(nn.Module): + """ + A wrapper for SMP decoders designed to handle single or multiple embeddings with specified indices. + + Attributes: + decoder (nn.Module): The SMP decoder module being wrapped. + channels (int): The number of output channels of the decoder. + in_index (Union[int, List[int]]): Index or indices of the embeddings to pass to the decoder. + + Methods: + forward(x: List[torch.Tensor]) -> torch.Tensor: + Forward pass for embeddings with specified indices. + """ + + def __init__(self, decoder, num_channels, in_index=-1) -> None: + """ + Args: + decoder (nn.Module): The SMP decoder module to be wrapped. + num_channels (int): The number of output channels of the decoder. + in_index (Union[int, List[int]], optional): Index or indices of the input embeddings to pass to the decoder. + Defaults to -1. + """ + super().__init__() + self.decoder = decoder + self.channels = num_channels + self.in_index = in_index + + @property + def output_embed_dim(self): + return self.channels + + def forward(self, x): + if isinstance(self.in_index, int): + selected_inputs = [x[self.in_index]] + else: + selected_inputs = [x[i] for i in self.in_index] + + return self.decoder(*selected_inputs) + + def _build_appropriate_model( task: str, backbone: nn.Module, @@ -178,6 +236,82 @@ def _build_appropriate_model( ) +def _get_smp_decoder( + decoder: str, + backbone_kwargs: dict, + decoder_kwargs: dict, + out_channels: list[int] | int, + in_channels: int, + num_classes: int, + output_stride: int, +): + """ + Creates and configures a decoder from the Segmentation Models Pytorch (SMP) library. + + This function constructs a decoder module based on the specified parameters and wraps it in a + custom wrapper that allows handling single or multiple embeddings. It also ensures that the + appropriate encoder parameters are passed and registered correctly. + + Args: + decoder (str): The name of the SMP decoder to use. 
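+            For example, "smp_Unet" resolves to smp.Unet once the "smp_" prefix is stripped.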
+ backbone_kwargs (dict): Dictionary of parameters for configuring the backbone. + decoder_kwargs (dict): Dictionary of parameters specific to the decoder. + out_channels (Union[list[int], int]): The number of output channels for each layer of the decoder. + in_channels (int): The number of input channels. + num_classes (int): The number of output classes for the model. + output_stride (int): The output stride of the decoder. + + Returns: + SMPDecoderForPrithviWrapper: A wrapped decoder module configured based on the provided parameters. + + Raises: + ValueError: If the specified decoder is not supported by SMP. + """ + decoder = decoder.removeprefix("smp_") + decoder_module = getattr(smp, decoder, None) + if decoder_module is None: + msg = f"Decoder {decoder} is not supported in SMP." + raise ValueError(msg) + + # Little hack to make SMP model accept our encoder. + # passes a dummy encoder to be changed later. + # this is needed to pass encoder params. + aux_kwargs, decoder_kwargs = _extract_prefix_keys(decoder_kwargs, "aux_") + smp_kwargs, decoder_kwargs = _extract_prefix_keys(decoder_kwargs, "smp_") + backbone_kwargs["out_channels"] = out_channels + backbone_kwargs["output_stride"] = output_stride + aux_kwargs = None if aux_kwargs == {} else aux_kwargs + + dummy_encoder = make_smp_encoder() + + register_custom_encoder(dummy_encoder, backbone_kwargs, None) + + dummy_encoder = dummy_encoder( + depth=smp_kwargs["encoder_depth"], + output_stride=backbone_kwargs["output_stride"], + out_channels=backbone_kwargs["out_channels"], + ) + + model_args = { + "encoder_name": "SMPEncoderWrapperWithPFFIM", + "encoder_weights": None, + "in_channels": in_channels, + "classes": num_classes, + **smp_kwargs, + } + + # Creates model with dummy encoder and decoder. + model = decoder_module(**model_args, aux_params=aux_kwargs) + + smp_decoder = SMPDecoderForPrithviWrapper( + decoder=model.decoder, + num_channels=out_channels[-1], + in_index=decoder_kwargs["in_index"], + ) + + return smp_decoder + + def _get_decoder(decoder: str | nn.Module) -> nn.Module: if isinstance(decoder, nn.Module): return decoder @@ -197,7 +331,7 @@ def _extract_prefix_keys(d: dict, prefix: str) -> dict: remaining_dict = {} for k, v in d.items(): if k.startswith(prefix): - extracted_dict[k.split(prefix)[1]] = v + extracted_dict[k[len(prefix) :]] = v else: remaining_dict[k] = v diff --git a/terratorch/models/smp_model_factory.py b/terratorch/models/smp_model_factory.py index ee58c0d1..46d8637e 100644 --- a/terratorch/models/smp_model_factory.py +++ b/terratorch/models/smp_model_factory.py @@ -1,83 +1,268 @@ # Copyright contributors to the Terratorch project -""" -This is just an example of a possible structure to include SMP models -Right now it always returns a UNET, but could easily be extended to many of the models provided by SMP. -""" +import importlib +from collections.abc import Callable import segmentation_models_pytorch as smp +import torch +import torch.nn.functional as F # noqa: N812 +from segmentation_models_pytorch.encoders import encoders as smp_encoders from torch import nn +from terratorch.datasets import HLSBands from terratorch.models.model import Model, ModelFactory, ModelOutput, register_factory +class SMPModelWrapper(Model, nn.Module): + """ + Wrapper class for SMP models. + + This class provides additional functionalities on top of SMP models. + + Attributes: + rescale (bool): Whether to rescale the output to match the input dimensions. + smp_model (nn.Module): The base SMP model being wrapped. 
+        final_act (nn.Module): The final activation function to be applied on the output.
+        squeeze_single_class (bool): Whether to squeeze the output if there is a single output class.
+
+    Methods:
+        forward(x: torch.Tensor) -> ModelOutput:
+            Forward pass through the model, optionally rescaling the output.
+        freeze_encoder() -> None:
+            Freezes the parameters of the encoder part of the model.
+        freeze_decoder() -> None:
+            Freezes the parameters of the decoder part of the model.
+    """
+
+    def __init__(self, smp_model, rescale=True, relu=False, squeeze_single_class=False) -> None:  # noqa: FBT002
+        """
+        Args:
+            smp_model (nn.Module): The base SMP model to be wrapped.
+            rescale (bool, optional): Whether to rescale the output to match the input dimensions. Defaults to True.
+            relu (bool, optional): Whether to apply ReLU activation on the output.
+                If False, Identity activation is used. Defaults to False.
+            squeeze_single_class (bool, optional): Whether to squeeze the output if there is a single output class.
+                Defaults to False.
+        """
+        super().__init__()
+        self.rescale = rescale
+        self.smp_model = smp_model
+        self.final_act = nn.ReLU() if relu else nn.Identity()
+        self.squeeze_single_class = squeeze_single_class
+
+    def forward(self, x):
+        input_size = x.shape[-2:]
+        smp_output = self.smp_model(x)
+        smp_output = self.final_act(smp_output)
+
+        # TODO: support auxiliary head labels
+        if isinstance(smp_output, tuple):
+            smp_output, labels = smp_output
+
+        if smp_output.shape[1] == 1 and self.squeeze_single_class:
+            smp_output = smp_output.squeeze(1)
+
+        if self.rescale and smp_output.shape[-2:] != input_size:
+            smp_output = F.interpolate(smp_output, size=input_size, mode="bilinear")
+        return ModelOutput(smp_output)
+
+    def freeze_encoder(self):
+        freeze_module(self.smp_model.encoder)
+
+    def freeze_decoder(self):
+        freeze_module(self.smp_model.decoder)
+
+
 @register_factory
 class SMPModelFactory(ModelFactory):
     def build_model(
         self,
         task: str,
         backbone: str,
-        decoder: str,
-        in_channels: int,
-        pretrained: str | bool | None = True,
+        model: str,
+        bands: list[HLSBands | int],
+        in_channels: int | None = None,
         num_classes: int = 1,
-        regression_relu: bool = False,
+        pretrained: str | bool | None = True,  # noqa: FBT002
+        prepare_features_for_image_model: Callable | None = None,
+        regression_relu: bool = False,  # noqa: FBT001, FBT002
         **kwargs,
     ) -> Model:
-        """Factory to create model based on SMP.
+        """
+        Factory method for creating SMP (Segmentation Models Pytorch) based models with optional customization.

-        Args:
-            task (str): Must be "segmentation".
-            backbone (str): Name of backbone.
-            decoder (str): Decoder architecture. Currently only supports "unet".
-            in_channels (int): Number of input channels.
-            pretrained(str | bool): Which weights to use for the backbone. If true, will use "imagenet". If false or None, random weights. Defaults to True.
-            num_classes (int): Number of classes.
-            regression_relu (bool). Whether to apply a ReLU if task is regression. Defaults to False.
+        This factory handles the instantiation of segmentation and regression models using specified
+        encoders and decoders from the SMP library, along with custom modifications and extensions such
+        as auxiliary decoders or modified encoders.
+
+        Args:
+            task (str): Specifies the task for which the model is being built. Supported tasks are
+                "segmentation".
+            backbone (str): Specifies the backbone model to be used.
+            model (str): Specifies the decoder architecture to be used for constructing the
+ bands (list[terratorch.datasets.HLSBands | int]): A list specifying the bands that the model + will operate on. These are expected to be from terratorch.datasets.HLSBands. + in_channels (int, optional): Specifies the number of input channels. Defaults to None. + num_classes (int, optional): The number of output classes for the model. + pretrained (bool | Path, optional): Indicates whether to load pretrained weights for the + backbone. Can also specify a path to weights. Defaults to True. + num_frames (int, optional): Specifies the number of timesteps the model should handle. Useful + for temporal models. + regression_relu (bool): Whether to apply ReLU activation in the case of regression tasks. + **kwargs: Additional arguments that might be passed to further customize the backbone, decoder, + or any auxiliary heads. These should be prefixed appropriately + + Raises: + ValueError: If the specified decoder is not supported by SMP. + Exception: If the specified task is not "segmentation" Returns: - Model: SMP model wrapped in SMPModelWrapper. + nn.Module: A model instance wrapped in SMPModelWrapper configured according to the specified + parameters and tasks. """ - if task not in ["segmentation", "regression"]: - msg = f"SMP models can only perform pixel wise tasks, but got task {task}" + if task != "segmentation": + msg = f"SMP models can only perform segmentation, but got task {task}" raise Exception(msg) - # backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_") + + bands = [HLSBands.try_convert_to_hls_bands_enum(b) for b in bands] + if in_channels is None: + in_channels = len(bands) + + # Gets decoder module. + model_module = getattr(smp, model, None) + if model_module is None: + msg = f"Decoder {model} is not supported in SMP." + raise ValueError(msg) + + backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_") # Encoder params should be prefixed backbone_ + smp_kwargs = _extract_prefix_keys(backbone_kwargs, "smp_") # Smp model params should be prefixed smp_ + aux_params = _extract_prefix_keys(backbone_kwargs, "aux_") # Auxiliary head params should be prefixed aux_ + aux_params = None if aux_params == {} else aux_params + if isinstance(pretrained, bool): if pretrained: pretrained = "imagenet" else: pretrained = None - if decoder == "unet": - model = smp.Unet( - encoder_name=backbone, encoder_weights=pretrained, in_channels=in_channels, classes=num_classes - ) + + # If encoder not currently supported by SMP (custom encoder). + if backbone not in smp_encoders: + # These params must be included in the config file with appropriate prefix. + required_params = { + "encoder_depth": smp_kwargs, + "out_channels": backbone_kwargs, + "output_stride": backbone_kwargs, + } + + for param, config_dict in required_params.items(): + if param not in config_dict: + msg = f"Config must include the '{param}' parameter" + raise ValueError(msg) + + # Using new encoder. + backbone_class = make_smp_encoder(backbone) + backbone_kwargs["prepare_features_for_image_model"] = prepare_features_for_image_model + # Registering custom encoder into SMP. + register_custom_encoder(backbone_class, backbone_kwargs, pretrained) + + model_args = { + "encoder_name": "SMPEncoderWrapperWithPFFIM", + "encoder_weights": pretrained, + "in_channels": in_channels, + "classes": num_classes, + **smp_kwargs, + } + # Using SMP encoder. 
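+        # Illustrative key flow (names here are placeholders): a kwarg such as
+        # "backbone_smp_encoder_depth" loses its "backbone_" prefix in the first
+        # _extract_prefix_keys call above and its "smp_" prefix in the second,
+        # reaching the SMP constructor as smp_kwargs={"encoder_depth": ...}.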
else: - msg = "Only unet decoder implemented" - raise NotImplementedError(msg) + model_args = { + "encoder_name": backbone, + "encoder_weights": pretrained, + "in_channels": in_channels, + "classes": num_classes, + **smp_kwargs, + } + + model = model_module(**model_args, aux_params=aux_params) + return SMPModelWrapper( model, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression" ) -class SMPModelWrapper(Model, nn.Module): - def __init__(self, smp_model, relu=False, squeeze_single_class=False) -> None: - super().__init__() - self.smp_model = smp_model - self.final_act = nn.ReLU() if relu else nn.Identity() - self.squeeze_single_class = squeeze_single_class +# Registers a custom encoder into SMP. +def register_custom_encoder(encoder, params, pretrained): + smp_encoders["SMPEncoderWrapperWithPFFIM"] = { + "encoder": encoder, + "params": params, + "pretrained_settings": pretrained, + } - def forward(self, *args, **kwargs): - smp_output = self.smp_model(*args, **kwargs) - smp_output = self.final_act(smp_output) - if smp_output.shape[1] == 1 and self.squeeze_single_class: - smp_output = smp_output.squeeze(1) - return ModelOutput(smp_output) - def freeze_encoder(self): - raise NotImplementedError() +def make_smp_encoder(encoder=None): + if isinstance(encoder, str): + base_class = _get_class_from_string(encoder) + else: + base_class = nn.Module - def freeze_decoder(self): - raise NotImplementedError() + # Wrapper needed to include SMP params and PFFIM + class SMPEncoderWrapperWithPFFIM(base_class): + def __init__( + self, + depth: int, + output_stride: int, + out_channels: list[int], + prepare_features_for_image_model: Callable | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self._depth = depth + self._output_stride = output_stride + self._out_channels = out_channels + self.model = None + + if prepare_features_for_image_model: + self.prepare_features_for_image_model = prepare_features_for_image_model + elif not hasattr(super(), "prepare_features_for_image_model"): + self.prepare_features_for_image_model = lambda x: x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.model: + features = self.model(x) + if hasattr(self.model, "prepare_features_for_image_model"): + return self.model.prepare_features_for_image_model(features) + + features = super().forward(x) + return self.prepare_features_for_image_model(features) + + @property + def out_channels(self): + if hasattr(super(), "out_channels"): + return super().out_channels() + + return self._out_channels + + @property + def output_stride(self): + if hasattr(super(), "output_stride"): + return super().output_stride() + + return min(self._output_stride, 2**self._depth) + + def set_in_channels(self, in_channels, pretrained): + if hasattr(super(), "set_in_channels"): + return super().set_in_channels(in_channels, pretrained) + else: + pass + + def make_dilated(self, output_stride): + if hasattr(super(), "make_dilated"): + return super().make_dilated(output_stride) + else: + pass + + return SMPEncoderWrapperWithPFFIM def _extract_prefix_keys(d: dict, prefix: str) -> dict: @@ -92,3 +277,28 @@ def _extract_prefix_keys(d: dict, prefix: str) -> dict: del d[k] return extracted_dict + + +def _get_class_from_string(class_path): + try: + module_path, name = class_path.rsplit(".", 1) + except ValueError as vr: + msg = "Path must contain a '.' 
separating module from the class name" + raise ValueError(msg) from vr + + try: + module = importlib.import_module(module_path) + except ImportError as ie: + msg = f"Could not import module '{module_path}'." + raise ImportError(msg) from ie + + try: + return getattr(module, name) + except AttributeError as ae: + msg = f"The class '{name}' was not found in the module '{module_path}'." + raise AttributeError(msg) from ae + + +def freeze_module(module: nn.Module): + for param in module.parameters(): + param.requires_grad_(False) diff --git a/terratorch/tasks/classification_tasks.py b/terratorch/tasks/classification_tasks.py index dbcfd7e9..5820e23d 100644 --- a/terratorch/tasks/classification_tasks.py +++ b/terratorch/tasks/classification_tasks.py @@ -154,17 +154,17 @@ def configure_metrics(self) -> None: class_names = self.hparams["class_names"] metrics = MetricCollection( { - "Overall Accuracy": MulticlassAccuracy( + "Overall_Accuracy": MulticlassAccuracy( num_classes=num_classes, ignore_index=ignore_index, average="micro", ), - "Average Accuracy": MulticlassAccuracy( + "Average_Accuracy": MulticlassAccuracy( num_classes=num_classes, ignore_index=ignore_index, average="macro", ), - "Multiclass Accuracy Class": ClasswiseWrapper( + "Multiclass_Accuracy_Class": ClasswiseWrapper( MulticlassAccuracy( num_classes=num_classes, ignore_index=ignore_index, @@ -172,13 +172,13 @@ def configure_metrics(self) -> None: ), labels=class_names, ), - "Multiclass Jaccard Index": MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index), - "Multiclass Jaccard Index Class": ClasswiseWrapper( + "Multiclass_Jaccard_Index": MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index), + "Multiclass_Jaccard_Index_Class": ClasswiseWrapper( MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None), labels=class_names, ), # why FBetaScore - "Multiclass F1 Score": MulticlassFBetaScore( + "Multiclass_F1_Score": MulticlassFBetaScore( num_classes=num_classes, ignore_index=ignore_index, beta=1.0, diff --git a/terratorch/tasks/multilabel_classification_tasks.py b/terratorch/tasks/multilabel_classification_tasks.py index 68e68283..baf38523 100644 --- a/terratorch/tasks/multilabel_classification_tasks.py +++ b/terratorch/tasks/multilabel_classification_tasks.py @@ -42,13 +42,13 @@ def configure_losses(self) -> None: def configure_metrics(self) -> None: metrics = MetricCollection( { - "Overall Accuracy": MultilabelAccuracy( + "Overall_Accuracy": MultilabelAccuracy( num_labels=self.hparams["model_args"]["num_classes"], average="micro" ), - "Average Accuracy": MultilabelAccuracy( + "Average_Accuracy": MultilabelAccuracy( num_labels=self.hparams["model_args"]["num_classes"], average="macro" ), - "Multilabel F1 Score": MultilabelFBetaScore( + "Multilabel_F1_Score": MultilabelFBetaScore( num_labels=self.hparams["model_args"]["num_classes"], beta=1.0, average="micro" ), } diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index 7070b9c5..5f123351 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -168,13 +168,13 @@ def configure_metrics(self) -> None: class_names = self.hparams["class_names"] metrics = MetricCollection( { - "Multiclass Accuracy": MulticlassAccuracy( + "Multiclass_Accuracy": MulticlassAccuracy( num_classes=num_classes, ignore_index=ignore_index, multidim_average="global", average="micro", ), - "Multiclass Accuracy Class": ClasswiseWrapper( + 
"Multiclass_Accuracy_Class": ClasswiseWrapper( MulticlassAccuracy( num_classes=num_classes, ignore_index=ignore_index, @@ -183,18 +183,18 @@ def configure_metrics(self) -> None: ), labels=class_names, ), - "Multiclass Jaccard Index Micro": MulticlassJaccardIndex( + "Multiclass_Jaccard_Index_Micro": MulticlassJaccardIndex( num_classes=num_classes, ignore_index=ignore_index, average="micro" ), - "Multiclass Jaccard Index": MulticlassJaccardIndex( + "Multiclass_Jaccard_Index": MulticlassJaccardIndex( num_classes=num_classes, ignore_index=ignore_index, ), - "Multiclass Jaccard Index Class": ClasswiseWrapper( + "Multiclass_Jaccard_Index_Class": ClasswiseWrapper( MulticlassJaccardIndex(num_classes=num_classes, ignore_index=ignore_index, average=None), labels=class_names, ), - "Multiclass F1 Score": MulticlassF1Score( + "Multiclass_F1_Score": MulticlassF1Score( num_classes=num_classes, ignore_index=ignore_index, multidim_average="global", diff --git a/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml b/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml new file mode 100644 index 00000000..8697cd63 --- /dev/null +++ b/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml @@ -0,0 +1,136 @@ +# lightning.pytorch==2.1.1 +seed_everything: 42 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + # precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: tests/ + name: all_ecos_random + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 100 + max_epochs: 5 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_checkpointing: true + default_root_dir: tests/ +data: + class_path: GenericNonGeoPixelwiseRegressionDataModule + init_args: + batch_size: 2 + num_workers: 4 + train_transform: + - class_path: albumentations.HorizontalFlip + init_args: + p: 0.5 + - class_path: albumentations.Rotate + init_args: + limit: 30 + border_mode: 0 # cv2.BORDER_CONSTANT + value: 0 + # mask_value: 1 + p: 0.5 + - class_path: ToTensorV2 + dataset_bands: + - [0, 11] + output_bands: + - [1, 3] + - [4, 6] + rgb_indices: + - 2 + - 1 + - 0 + train_data_root: tests/ + train_label_data_root: tests/ + val_data_root: tests/ + val_label_data_root: tests/ + test_data_root: tests/ + test_label_data_root: tests/ + img_grep: "regression*input*.tif" + label_grep: "regression*label*.tif" + means: + - 547.36707 + - 898.5121 + - 1020.9082 + - 2665.5352 + - 2340.584 + - 1610.1407 + stds: + - 411.4701 + - 558.54065 + - 815.94025 + - 812.4403 + - 1113.7145 + - 1067.641 + no_label_replace: -1 + no_data_replace: 0 + +model: + class_path: terratorch.tasks.PixelwiseRegressionTask + init_args: + model_args: + decoder: UperNetDecoder + pretrained: true + backbone: prithvi_swin_B + backbone_pretrained_cfg_overlay: + file: tests/prithvi_swin_B.pt + backbone_drop_path_rate: 0.3 + # backbone_window_size: 8 + decoder_channels: 256 + in_channels: 6 + bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + num_frames: 1 + head_dropout: 0.5708022831486758 + head_final_act: torch.nn.ReLU + head_learned_upscale_layers: 2 + loss: rmse + #aux_heads: + # - name: aux_head + # decoder: IdentityDecoder + # decoder_args: + # decoder_out_index: 2 + # head_dropout: 0,5 + # head_channel_list: + # - 64 + # head_final_act: torch.nn.ReLU + #aux_loss: + # aux_head: 0.4 + ignore_index: -1 + freeze_backbone: true + freeze_decoder: 
false + model_factory: PrithviModelFactory + + # uncomment this block for tiled inference + # tiled_inference_parameters: + # h_crop: 224 + # h_stride: 192 + # w_crop: 224 + # w_stride: 192 + # average_patches: true +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.00013524680528283027 + weight_decay: 0.047782217873995426 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss + diff --git a/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml b/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml new file mode 100644 index 00000000..91a72a3c --- /dev/null +++ b/tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml @@ -0,0 +1,124 @@ +# lightning.pytorch==2.1.1 +seed_everything: 42 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + # precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: tests/ + name: all_ecos_random + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 100 + max_epochs: 5 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_checkpointing: true + default_root_dir: tests/ +data: + class_path: GenericNonGeoPixelwiseRegressionDataModule + init_args: + batch_size: 2 + num_workers: 4 + train_transform: + - class_path: albumentations.HorizontalFlip + init_args: + p: 0.5 + - class_path: albumentations.Rotate + init_args: + limit: 30 + border_mode: 0 # cv2.BORDER_CONSTANT + value: 0 + # mask_value: 1 + p: 0.5 + - class_path: ToTensorV2 + dataset_bands: + - [0, 11] + output_bands: + - [1, 3] + - [4, 6] + rgb_indices: + - 2 + - 1 + - 0 + train_data_root: tests/ + train_label_data_root: tests/ + val_data_root: tests/ + val_label_data_root: tests/ + test_data_root: tests/ + test_label_data_root: tests/ + img_grep: "regression*input*.tif" + label_grep: "regression*label*.tif" + means: tests/means.txt + stds: tests/stds.txt + no_label_replace: -1 + no_data_replace: 0 + +model: + class_path: terratorch.tasks.PixelwiseRegressionTask + init_args: + model_args: + decoder: UperNetDecoder + pretrained: true + backbone: prithvi_swin_B + backbone_pretrained_cfg_overlay: + file: tests/prithvi_swin_B.pt + backbone_drop_path_rate: 0.3 + # backbone_window_size: 8 + decoder_channels: 256 + in_channels: 6 + bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + num_frames: 1 + head_dropout: 0.5708022831486758 + head_final_act: torch.nn.ReLU + head_learned_upscale_layers: 2 + loss: rmse + #aux_heads: + # - name: aux_head + # decoder: IdentityDecoder + # decoder_args: + # decoder_out_index: 2 + # head_dropout: 0,5 + # head_channel_list: + # - 64 + # head_final_act: torch.nn.ReLU + #aux_loss: + # aux_head: 0.4 + ignore_index: -1 + freeze_backbone: true + freeze_decoder: false + model_factory: PrithviModelFactory + + # uncomment this block for tiled inference + # tiled_inference_parameters: + # h_crop: 224 + # h_stride: 192 + # w_crop: 224 + # w_stride: 192 + # average_patches: true +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.00013524680528283027 + weight_decay: 0.047782217873995426 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss + diff --git a/tests/manufactured-finetune_prithvi_swin_B_string.yaml b/tests/manufactured-finetune_prithvi_swin_B_string.yaml new file mode 100644 index 00000000..a7aa84c2 --- /dev/null +++ 
b/tests/manufactured-finetune_prithvi_swin_B_string.yaml @@ -0,0 +1,149 @@ +# lightning.pytorch==2.1.1 +seed_everything: 42 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + # precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: tests/ + name: all_ecos_random + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 100 + max_epochs: 5 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_checkpointing: true + default_root_dir: tests/ +data: + class_path: GenericNonGeoPixelwiseRegressionDataModule + init_args: + batch_size: 2 + num_workers: 4 + train_transform: + - class_path: albumentations.HorizontalFlip + init_args: + p: 0.5 + - class_path: albumentations.Rotate + init_args: + limit: 30 + border_mode: 0 # cv2.BORDER_CONSTANT + value: 0 + # mask_value: 1 + p: 0.5 + - class_path: ToTensorV2 + dataset_bands: + - "band_1" + - "band_2" + - "band_3" + - "band_4" + - "band_5" + - "band_6" + - "band_7" + - "band_8" + - "band_9" + - "band_10" + output_bands: + - "band_2" + - "band_3" + - "band_4" + - "band_5" + - "band_6" + - "band_7" + rgb_indices: + - 2 + - 1 + - 0 + train_data_root: tests/ + train_label_data_root: tests/ + val_data_root: tests/ + val_label_data_root: tests/ + test_data_root: tests/ + test_label_data_root: tests/ + img_grep: "regression*input*.tif" + label_grep: "regression*label*.tif" + means: + - 547.36707 + - 898.5121 + - 1020.9082 + - 2665.5352 + - 2340.584 + - 1610.1407 + stds: + - 411.4701 + - 558.54065 + - 815.94025 + - 812.4403 + - 1113.7145 + - 1067.641 + no_label_replace: -1 + no_data_replace: 0 + +model: + class_path: terratorch.tasks.PixelwiseRegressionTask + init_args: + model_args: + decoder: UperNetDecoder + pretrained: true + backbone: prithvi_swin_B + backbone_pretrained_cfg_overlay: + file: tests/prithvi_swin_B.pt + backbone_drop_path_rate: 0.3 + # backbone_window_size: 8 + decoder_channels: 256 + in_channels: 6 + bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + num_frames: 1 + head_dropout: 0.5708022831486758 + head_final_act: torch.nn.ReLU + head_learned_upscale_layers: 2 + loss: rmse + #aux_heads: + # - name: aux_head + # decoder: IdentityDecoder + # decoder_args: + # decoder_out_index: 2 + # head_dropout: 0,5 + # head_channel_list: + # - 64 + # head_final_act: torch.nn.ReLU + #aux_loss: + # aux_head: 0.4 + ignore_index: -1 + freeze_backbone: true + freeze_decoder: false + model_factory: PrithviModelFactory + + # uncomment this block for tiled inference + # tiled_inference_parameters: + # h_crop: 224 + # h_stride: 192 + # w_crop: 224 + # w_stride: 192 + # average_patches: true +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.00013524680528283027 + weight_decay: 0.047782217873995426 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss + diff --git a/tests/means.txt b/tests/means.txt new file mode 100644 index 00000000..56900a7f --- /dev/null +++ b/tests/means.txt @@ -0,0 +1,7 @@ +411.4701 +558.54065 +815.94025 +812.4403 +1113.7145 +1067.641 + diff --git a/tests/stds.txt b/tests/stds.txt new file mode 100644 index 00000000..ad602006 --- /dev/null +++ b/tests/stds.txt @@ -0,0 +1,7 @@ +547.36707 +898.5121 +1020.9082 +2665.5352 +2340.584 +1610.1407 + diff --git a/tests/test_finetune.py b/tests/test_finetune.py index bb48b94e..8592e639 100644 --- a/tests/test_finetune.py +++ 
@@ -23,21 +23,42 @@ def test_finetune_multiple_backbones(model_name):
     command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"]
     _ = build_lightning_cli(command_list)
 
-"""
-@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
-def test_finetune_multiple_backbones(model_name):
+@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
+def test_finetune_bands_intervals(model_name):
     model_instance = timm.create_model(model_name)
-    pretrained_bands = [0, 1, 2, 3, 4, 5]
-    model_bands = [0, 1, 2, 3, 4, 5]
 
     state_dict = model_instance.state_dict()
 
     torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
 
     # Running the terratorch CLI
-    command_str = f"python terratorch/__main__.py fit -c tests/manufactured-finetune_{model_name}.yaml"
-    command_out = subprocess.run(command_str, shell=True)
+    command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_band_interval.yaml"]
+    _ = build_lightning_cli(command_list)
+
+@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
+def test_finetune_bands_str(model_name):
+
+    model_instance = timm.create_model(model_name)
+
+    state_dict = model_instance.state_dict()
+
+    torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
+
+    # Running the terratorch CLI
+    command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"]
+    _ = build_lightning_cli(command_list)
 
-    assert not command_out.returncode
-    """
+@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
+def test_finetune_metrics_from_file(model_name):
+
+    model_instance = timm.create_model(model_name)
+
+    state_dict = model_instance.state_dict()
+
+    torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
+
+    # Running the terratorch CLI
+    command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_metrics_from_file.yaml"]
+    _ = build_lightning_cli(command_list)
+
diff --git a/tests/test_generic_dataset.py b/tests/test_generic_dataset.py
index 7ec70933..9ac8bd2a 100644
--- a/tests/test_generic_dataset.py
+++ b/tests/test_generic_dataset.py
@@ -6,7 +6,7 @@
 import torch
 from _pytest.tmpdir import TempPathFactory
 
-from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset, GenericNonGeoSegmentationDataset
+from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset, GenericNonGeoSegmentationDataset, HLSBands
 
 REGRESSION_IMAGE_PATH = "tests/regression_test_input.tif"
 REGRESSION_LABEL_PATH = "tests/regression_test_label.tif"
@@ -14,6 +14,56 @@
 SEGMENTATION_LABEL_PATH = "tests/segmentation_test_label.tif"
 NUM_CLASSES_SEGMENTATION = 2
 
+# Testing bands
+# HLS bands
+HLS_dataset_bands = [
+    "COASTAL_AEROSOL",
+    "BLUE",
+    "GREEN",
+    "RED",
+    "NIR_NARROW",
+    "SWIR_1",
+    "SWIR_2",
+    "CIRRUS",
+    "THERMAL_INFRARED_1",
+    "THERMAL_INFRARED_2",
+]
+
+HLS_output_bands = [
+    "BLUE",
+    "GREEN",
+    "RED",
+    "NIR_NARROW",
+    "SWIR_1",
+    "SWIR_2",
+]
+
+HLS_expected_filter_bands = list(range(1, 7))
+# Integer interval bands
+int_dataset_bands = [(0, 20)]
+int_output_bands = [(1, 6), (10, 12)]
+# Simple string bands
+str_dataset_bands = [f"band_{j}" for j in range(20)]
+str_output_bands = [f"band_{j}" for j in range(1, 7)] + [f"band_{j}" for j in range(10, 13)]
+
+expected_filter_indices = list(range(1, 7)) + list(range(10, 13))
+
+
+# Mixed case
+mixed_dataset_bands = [
+    (0, 10),
+    HLSBands.RED,
+    HLSBands.BLUE,
+    HLSBands.GREEN,
+    "extra_band_1",
+    "extra_band_2",
+    200,
+    201,
+    202,
+]
+mixed_output_bands = [1, 2, HLSBands.BLUE, "extra_band_1", 201, 202]
+expected_mixed_filter_indices = [1, 2, 12, 14, 17, 18]
+
 
 @pytest.fixture(scope="session")
 def split_file_path(tmp_path_factory):
@@ -59,6 +109,64 @@ def test_data_type_regression_float_float(self, regression_dataset):
         assert torch.is_floating_point(regression_dataset[0]["image"])
         assert torch.is_floating_point(regression_dataset[0]["mask"])
 
+    @pytest.fixture(scope="class")
+    def regression_dataset_with_HLS_bands(self, data_root_regression, split_file_path):
+        return GenericNonGeoPixelwiseRegressionDataset(
+            data_root_regression,
+            dataset_bands=HLS_dataset_bands,
+            output_bands=HLS_output_bands,
+            image_grep="input_data/*_img.tif",
+            label_grep="label_data/*_label.tif",
+            split=split_file_path,
+        ), HLS_expected_filter_bands
+
+    @pytest.fixture(scope="class")
+    def regression_dataset_with_interval_bands(self, data_root_regression, split_file_path):
+        return GenericNonGeoPixelwiseRegressionDataset(
+            data_root_regression,
+            dataset_bands=int_dataset_bands,
+            output_bands=int_output_bands,
+            image_grep="input_data/*_img.tif",
+            label_grep="label_data/*_label.tif",
+            split=split_file_path,
+        ), expected_filter_indices
+
+    @pytest.fixture(scope="class")
+    def regression_dataset_with_str_bands(self, data_root_regression, split_file_path):
+        return GenericNonGeoPixelwiseRegressionDataset(
+            data_root_regression,
+            dataset_bands=str_dataset_bands,
+            output_bands=str_output_bands,
+            image_grep="input_data/*_img.tif",
+            label_grep="label_data/*_label.tif",
+            split=split_file_path,
+        ), expected_filter_indices
+
+    @pytest.fixture(scope="class")
+    def regression_dataset_with_mixed_bands(self, data_root_regression, split_file_path):
+        return GenericNonGeoPixelwiseRegressionDataset(
+            data_root_regression,
+            dataset_bands=mixed_dataset_bands,
+            output_bands=mixed_output_bands,
+            image_grep="input_data/*_img.tif",
+            label_grep="label_data/*_label.tif",
+            split=split_file_path,
+        ), expected_mixed_filter_indices
+
+    @pytest.mark.parametrize(
+        "dataset",
+        [
+            "regression_dataset_with_HLS_bands",
+            "regression_dataset_with_str_bands",
+            "regression_dataset_with_interval_bands",
+            "regression_dataset_with_mixed_bands",
+        ],
+    )
+    def test_correct_filter(self, dataset, request):
+        fixture, expected = request.getfixturevalue(dataset)
+        assert fixture.filter_indices == expected
+
 
 class TestGenericSegmentationDataset:
     @pytest.fixture(scope="class")
@@ -94,3 +202,65 @@ def test_file_discovery_generic_segmentation_dataset(self, segmentation_dataset)
     def test_data_type_regression_float_long(self, segmentation_dataset):
         assert torch.is_floating_point(segmentation_dataset[0]["image"])
         assert not torch.is_floating_point(segmentation_dataset[0]["mask"])
+
+    @pytest.fixture(scope="class")
+    def segmentation_dataset_with_HLS_bands(self, data_root_segmentation, split_file_path):
+        return GenericNonGeoSegmentationDataset(
+            data_root_segmentation,
+            NUM_CLASSES_SEGMENTATION,
+            dataset_bands=HLS_dataset_bands,
+            output_bands=HLS_output_bands,
+            image_grep="input_data/*_img.tif",
+            label_grep="label_data/*_label.tif",
+            split=split_file_path,
+        ), HLS_expected_filter_bands
+
+    @pytest.fixture(scope="class")
+    def segmentation_dataset_with_interval_bands(self, data_root_segmentation, split_file_path):
+        return GenericNonGeoSegmentationDataset(
+            data_root_segmentation,
+            NUM_CLASSES_SEGMENTATION,
+            dataset_bands=int_dataset_bands,
+            output_bands=int_output_bands,
+            image_grep="input_data/*_img.tif",
+            label_grep="label_data/*_label.tif",
+            split=split_file_path,
+        ), expected_filter_indices
+
+    @pytest.fixture(scope="class")
+    def segmentation_dataset_with_str_bands(self, data_root_segmentation, split_file_path):
+        return GenericNonGeoSegmentationDataset(
+            data_root_segmentation,
+            NUM_CLASSES_SEGMENTATION,
+            dataset_bands=str_dataset_bands,
+            output_bands=str_output_bands,
+            image_grep="input_data/*_img.tif",
+            label_grep="label_data/*_label.tif",
+            split=split_file_path,
+        ), expected_filter_indices
+
+    @pytest.fixture(scope="class")
+    def segmentation_dataset_with_mixed_bands(self, data_root_segmentation, split_file_path):
+        return GenericNonGeoSegmentationDataset(
+            data_root_segmentation,
+            NUM_CLASSES_SEGMENTATION,
+            dataset_bands=mixed_dataset_bands,
+            output_bands=mixed_output_bands,
+            image_grep="input_data/*_img.tif",
+            label_grep="label_data/*_label.tif",
+            split=split_file_path,
+        ), expected_mixed_filter_indices
+
+    @pytest.mark.parametrize(
+        "dataset",
+        [
+            "segmentation_dataset_with_HLS_bands",
+            "segmentation_dataset_with_str_bands",
+            "segmentation_dataset_with_interval_bands",
+            "segmentation_dataset_with_mixed_bands",
+        ],
+    )
+    def test_correct_filter(self, dataset, request):
+        fixture, expected = request.getfixturevalue(dataset)
+        assert fixture.filter_indices == expected
diff --git a/tests/test_prithvi_model_factory.py b/tests/test_prithvi_model_factory.py
index 6d3a4b82..3239671d 100644
--- a/tests/test_prithvi_model_factory.py
+++ b/tests/test_prithvi_model_factory.py
@@ -8,7 +8,7 @@
 from terratorch.models import PrithviModelFactory
 from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS
-#from terratorch.models.backbones.prithvi_vit import default_cfgs as vit_default_cfgs
+# from terratorch.models.backbones.prithvi_vit import default_cfgs as vit_default_cfgs
 from terratorch.models.model import AuxiliaryHead
 
 NUM_CHANNELS = 6
@@ -17,6 +17,11 @@
 EXPECTED_REGRESSION_OUTPUT_SHAPE = (1, 224, 224)
 EXPECTED_CLASSIFICATION_OUTPUT_SHAPE = (1, NUM_CLASSES)
 
+PIXELWISE_TASK_EXPECTED_OUTPUT = [
+    ("regression", EXPECTED_REGRESSION_OUTPUT_SHAPE),
+    ("segmentation", EXPECTED_SEGMENTATION_OUTPUT_SHAPE),
+]
+
 
 @pytest.fixture(scope="session")
 def model_factory() -> PrithviModelFactory:
@@ -27,6 +32,7 @@ def model_factory() -> PrithviModelFactory:
 def model_input() -> torch.Tensor:
     return torch.ones((1, NUM_CHANNELS, 224, 224))
 
+
 @pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
 def test_create_classification_model(backbone, model_factory: PrithviModelFactory, model_input):
     model = model_factory.build_model(
@@ -43,6 +49,7 @@ def test_create_classification_model(backbone, model_factory: PrithviModelFactor
     with torch.no_grad():
         assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
 
+
 @pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
 def test_create_classification_model_no_in_channels(backbone, model_factory: PrithviModelFactory, model_input):
     model = model_factory.build_model(
@@ -58,133 +65,115 @@ def test_create_classification_model_no_in_channels(backbone, model_factory: Pri
     with torch.no_grad():
         assert model(model_input).output.shape == EXPECTED_CLASSIFICATION_OUTPUT_SHAPE
 
-@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
-@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
-def test_create_segmentation_model(backbone, decoder, model_factory: PrithviModelFactory, model_input):
-    model = model_factory.build_model(
-        "segmentation",
-        backbone=backbone,
-        decoder=decoder,
-        in_channels=NUM_CHANNELS,
-        bands=PRETRAINED_BANDS,
-        pretrained=False,
-        num_classes=NUM_CLASSES,
-    )
-    model.eval()
-
-    with torch.no_grad():
-        assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
 
 @pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
+@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
-def test_create_segmentation_model_no_in_channels(backbone, decoder, model_factory: PrithviModelFactory, model_input):
-    model = model_factory.build_model(
-        "segmentation",
-        backbone=backbone,
-        decoder=decoder,
-        bands=PRETRAINED_BANDS,
-        pretrained=False,
-        num_classes=NUM_CLASSES,
-    )
+def test_create_pixelwise_model(backbone, task, expected, decoder, model_factory: PrithviModelFactory, model_input):
+    model_args = {
+        "task": task,
+        "backbone": backbone,
+        "decoder": decoder,
+        "in_channels": NUM_CHANNELS,
+        "bands": PRETRAINED_BANDS,
+        "pretrained": False,
+    }
+
+    if task == "segmentation":
+        model_args["num_classes"] = NUM_CLASSES
+    if decoder == "UperNetDecoder":
+        model_args["out_indices"] = [1, 2, 3, 4]
+        model_args["scale_modules"] = True
+
+    model = model_factory.build_model(**model_args)
     model.eval()
 
     with torch.no_grad():
-        assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
+        assert model(model_input).output.shape == expected
 
 
 @pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
+@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
-def test_create_segmentation_model_with_aux_heads(backbone, decoder, model_factory: PrithviModelFactory, model_input):
-    aux_heads_name = ["first_aux", "second_aux"]
-    model = model_factory.build_model(
-        "segmentation",
-        backbone=backbone,
-        decoder=decoder,
-        in_channels=NUM_CHANNELS,
-        bands=PRETRAINED_BANDS,
-        pretrained=False,
-        num_classes=NUM_CLASSES,
-        aux_decoders=[AuxiliaryHead(name, "FCNDecoder", None) for name in aux_heads_name],
-    )
+def test_create_pixelwise_model_no_in_channels(
+    backbone, task, expected, decoder, model_factory: PrithviModelFactory, model_input
+):
+    model_args = {
+        "task": task,
+        "backbone": backbone,
+        "decoder": decoder,
+        "bands": PRETRAINED_BANDS,
+        "pretrained": False,
+    }
+
+    if task == "segmentation":
+        model_args["num_classes"] = NUM_CLASSES
+    if decoder == "UperNetDecoder":
+        model_args["out_indices"] = [1, 2, 3, 4]
+        model_args["scale_modules"] = True
+
+    model = model_factory.build_model(**model_args)
     model.eval()
 
     with torch.no_grad():
-        model_output = model(model_input)
-        assert model_output.output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
-
-        assert len(model_output.auxiliary_heads.keys() & aux_heads_name) == len(aux_heads_name)
-        for _, output in model_output.auxiliary_heads.items():
-            assert output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
+        assert model(model_input).output.shape == expected
 
 
 @pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
+@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
-def test_create_regression_model(backbone, decoder, model_factory: PrithviModelFactory, model_input):
-    model = model_factory.build_model(
-        "regression",
-        backbone=backbone,
-        decoder=decoder,
-        in_channels=NUM_CHANNELS,
-        bands=PRETRAINED_BANDS,
-        pretrained=False,
-    )
-    model.eval()
-
-    with torch.no_grad():
-        assert model(model_input).output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
-
-@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
-@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
-def test_create_regression_model_no_in_channels(backbone, decoder, model_factory: PrithviModelFactory, model_input):
-    model = model_factory.build_model(
-        "regression",
-        backbone=backbone,
-        decoder=decoder,
-        bands=PRETRAINED_BANDS,
-        pretrained=False,
-    )
-    model.eval()
-
-    with torch.no_grad():
-        assert model(model_input).output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
-
-@pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
-@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
-def test_create_regression_model_with_aux_heads(backbone, decoder, model_factory: PrithviModelFactory, model_input):
+def test_create_pixelwise_model_with_aux_heads(
+    backbone, task, expected, decoder, model_factory: PrithviModelFactory, model_input
+):
     aux_heads_name = ["first_aux", "second_aux"]
-    model = model_factory.build_model(
-        "regression",
-        backbone=backbone,
-        decoder=decoder,
-        in_channels=NUM_CHANNELS,
-        bands=PRETRAINED_BANDS,
-        pretrained=False,
-        aux_decoders=[AuxiliaryHead(name, "FCNDecoder", None) for name in aux_heads_name],
-    )
+    model_args = {
+        "task": task,
+        "backbone": backbone,
+        "decoder": decoder,
+        "in_channels": NUM_CHANNELS,
+        "bands": PRETRAINED_BANDS,
+        "pretrained": False,
+        "aux_decoders": [AuxiliaryHead(name, "FCNDecoder", None) for name in aux_heads_name],
+    }
+    if task == "segmentation":
+        model_args["num_classes"] = NUM_CLASSES
+
+    if decoder == "UperNetDecoder":
+        model_args["out_indices"] = [1, 2, 3, 4]
+        model_args["scale_modules"] = True
+
+    model = model_factory.build_model(**model_args)
     model.eval()
 
     with torch.no_grad():
         model_output = model(model_input)
-        assert model_output.output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
+        assert model_output.output.shape == expected
 
         assert len(model_output.auxiliary_heads.keys() & aux_heads_name) == len(aux_heads_name)
         for _, output in model_output.auxiliary_heads.items():
-            assert output.shape == EXPECTED_REGRESSION_OUTPUT_SHAPE
+            assert output.shape == expected
 
 
 @pytest.mark.parametrize("backbone", ["prithvi_vit_100"])
+@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
-def test_create_model_with_extra_bands(backbone, decoder, model_factory: PrithviModelFactory):
-    model = model_factory.build_model(
-        "segmentation",
-        backbone=backbone,
-        decoder=decoder,
-        in_channels=NUM_CHANNELS + 1,
-        bands=[*PRETRAINED_BANDS, 7],  # add an extra band
-        pretrained=False,
-        num_classes=NUM_CLASSES,
-    )
+def test_create_pixelwise_model_with_extra_bands(backbone, task, expected, decoder, model_factory: PrithviModelFactory):
+    model_args = {
+        "task": task,
+        "backbone": backbone,
+        "decoder": decoder,
+        "in_channels": NUM_CHANNELS + 1,
+        "bands": [*PRETRAINED_BANDS, 7],  # add an extra band
+        "pretrained": False,
+    }
+    if task == "segmentation":
+        model_args["num_classes"] = NUM_CLASSES
+
+    if decoder == "UperNetDecoder":
+        model_args["out_indices"] = [1, 2, 3, 4]
+        model_args["scale_modules"] = True
+    model = model_factory.build_model(**model_args)
     model.eval()
     model_input = torch.ones((1, NUM_CHANNELS + 1, 224, 224))
     with torch.no_grad():
-        assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
+        assert model(model_input).output.shape == expected
diff --git a/tests/test_prithvi_tasks.py b/tests/test_prithvi_tasks.py
index cccb1193..44348de6 100644
--- a/tests/test_prithvi_tasks.py
+++ b/tests/test_prithvi_tasks.py
@@ -22,21 +22,26 @@ def model_input() -> torch.Tensor:
     return torch.ones((1, NUM_CHANNELS, 224, 224))
 
-@pytest.mark.parametrize("backbone",["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
+@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
 @pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"])
 def test_create_segmentation_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
+    model_args = {
+        "backbone": backbone,
+        "decoder": decoder,
+        "in_channels": NUM_CHANNELS,
+        "bands": PRETRAINED_BANDS,
+        "pretrained": False,
+        "num_classes": NUM_CLASSES,
+    }
+
+    if decoder == "UperNetDecoder" and backbone.startswith("prithvi_vit"):
+        model_args["out_indices"] = [1, 2, 3, 4]
+        model_args["scale_modules"] = True
     SemanticSegmentationTask(
-        {
-            "backbone": backbone,
-            "decoder": decoder,
-            "in_channels": NUM_CHANNELS,
-            "bands": PRETRAINED_BANDS,
-            "pretrained": False,
-            "num_classes": NUM_CLASSES,
-        },
+        model_args,
         model_factory,
-        loss=loss
+        loss=loss,
     )
 
 
@@ -44,16 +49,22 @@ def test_create_segmentation_task(backbone, decoder, loss, model_factory: Prithv
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
 @pytest.mark.parametrize("loss", ["mae", "rmse", "huber"])
 def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
+    model_args = {
+        "backbone": backbone,
+        "decoder": decoder,
+        "in_channels": NUM_CHANNELS,
+        "bands": PRETRAINED_BANDS,
+        "pretrained": False,
+    }
+
+    if decoder == "UperNetDecoder" and backbone.startswith("prithvi_vit"):
+        model_args["out_indices"] = [1, 2, 3, 4]
+        model_args["scale_modules"] = True
+
     PixelwiseRegressionTask(
-        {
-            "backbone": backbone,
-            "decoder": decoder,
-            "in_channels": NUM_CHANNELS,
-            "bands": PRETRAINED_BANDS,
-            "pretrained": False,
-        },
+        model_args,
         model_factory,
-        loss=loss
+        loss=loss,
     )
 
 
@@ -61,15 +72,21 @@ def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviM
 @pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
 @pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"])
 def test_create_classification_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
+    model_args = {
+        "backbone": backbone,
+        "decoder": decoder,
+        "in_channels": NUM_CHANNELS,
+        "bands": PRETRAINED_BANDS,
+        "pretrained": False,
+        "num_classes": NUM_CLASSES,
+    }
+
+    if decoder == "UperNetDecoder" and backbone.startswith("prithvi_vit"):
+        model_args["out_indices"] = [1, 2, 3, 4]
+        model_args["scale_modules"] = True
+
     ClassificationTask(
-        {
-            "backbone": backbone,
-            "decoder": decoder,
-            "in_channels": NUM_CHANNELS,
-            "bands": PRETRAINED_BANDS,
-            "pretrained": False,
-            "num_classes": NUM_CLASSES,
-        },
+        model_args,
         model_factory,
-        loss=loss
+        loss=loss,
     )
diff --git a/tests/test_smp_model_factory.py b/tests/test_smp_model_factory.py
new file mode 100644
index 00000000..11b0a67b
--- /dev/null
+++ b/tests/test_smp_model_factory.py
@@ -0,0 +1,80 @@
+# Copyright contributors to the Terratorch project
+
+import os
+
+import pytest
+import torch
+
+from terratorch.models import SMPModelFactory
+from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS
+
+# from terratorch.models.backbones.prithvi_vit import default_cfgs as vit_default_cfgs
+
+NUM_CHANNELS = 6
+NUM_CLASSES = 2
+EXPECTED_SEGMENTATION_OUTPUT_SHAPE = (1, NUM_CLASSES, 224, 224)
+EXPECTED_REGRESSION_OUTPUT_SHAPE = (1, 224, 224)
+EXPECTED_CLASSIFICATION_OUTPUT_SHAPE = (1, NUM_CLASSES)
+
+
+@pytest.fixture(scope="session")
+def model_factory() -> SMPModelFactory:
+    return SMPModelFactory()
+
+
+@pytest.fixture(scope="session")
+def model_input() -> torch.Tensor:
+    return torch.ones((1, NUM_CHANNELS, 224, 224))
+
+
+@pytest.mark.parametrize("backbone", ["timm-regnetx_002"])
+@pytest.mark.parametrize("model", ["Unet", "DeepLabV3"])
+def test_create_segmentation_model(backbone, model, model_factory: SMPModelFactory, model_input):
+    model = model_factory.build_model(
+        "segmentation",
+        backbone=backbone,
+        model=model,
+        in_channels=NUM_CHANNELS,
+        bands=PRETRAINED_BANDS,
+        pretrained=False,
+        num_classes=NUM_CLASSES,
+    )
+    model.eval()
+
+    with torch.no_grad():
+        assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
+
+
+@pytest.mark.parametrize("backbone", ["timm-regnetx_002"])
+@pytest.mark.parametrize("model", ["Unet", "DeepLabV3"])
+def test_create_segmentation_model_no_in_channels(backbone, model, model_factory: SMPModelFactory, model_input):
+    model = model_factory.build_model(
+        "segmentation",
+        backbone=backbone,
+        model=model,
+        bands=PRETRAINED_BANDS,
+        pretrained=False,
+        num_classes=NUM_CLASSES,
+    )
+    model.eval()
+
+    with torch.no_grad():
+        assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
+
+
+@pytest.mark.parametrize("backbone", ["timm-regnetx_002"])
+@pytest.mark.parametrize("model", ["Unet", "DeepLabV3"])
+def test_create_model_with_extra_bands(backbone, model, model_factory: SMPModelFactory):
+    model = model_factory.build_model(
+        "segmentation",
+        backbone=backbone,
+        model=model,
+        in_channels=NUM_CHANNELS + 1,
+        bands=[*PRETRAINED_BANDS, 7],  # add an extra band
+        pretrained=False,
+        num_classes=NUM_CLASSES,
+    )
+    model.eval()
+    model_input = torch.ones((1, NUM_CHANNELS + 1, 224, 224))
+    with torch.no_grad():
+        assert model(model_input).output.shape == EXPECTED_SEGMENTATION_OUTPUT_SHAPE
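+
+
+# Minimal standalone usage sketch for SMPModelFactory, runnable as a script.
+# Illustrative only: it assumes the same build_model() signature and the
+# `.output` field on the forward return that the tests above exercise.
+if __name__ == "__main__":
+    factory = SMPModelFactory()
+    smp_model = factory.build_model(
+        "segmentation",
+        backbone="timm-regnetx_002",
+        model="Unet",
+        in_channels=NUM_CHANNELS,
+        bands=PRETRAINED_BANDS,
+        pretrained=False,
+        num_classes=NUM_CLASSES,
+    )
+    smp_model.eval()
+    with torch.no_grad():
+        prediction = smp_model(torch.ones((1, NUM_CHANNELS, 224, 224)))
+    # Expected: (1, NUM_CLASSES, 224, 224)
+    print(prediction.output.shape)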