Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extends smp_model_factory class #56

Merged
merged 6 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ By passing a list of bands being used to the constructor, we automatically filte

## Model Factory
### :::terratorch.models.PrithviModelFactory
### :::terratorch.models.SMPModelFactory

# Adding new model types
Adding a new model type is as simple as creating a new factory that produces models. See, for instance, the example below for a potential `SMPModelFactory`.
Expand Down
148 changes: 141 additions & 7 deletions terratorch/models/prithvi_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Callable

import segmentation_models_pytorch as smp
import timm
import torch
from torch import nn
Expand All @@ -17,6 +18,7 @@
)
from terratorch.models.pixel_wise_model import PixelWiseModel
from terratorch.models.scalar_output_model import ScalarOutputModel
from terratorch.models.smp_model_factory import make_smp_encoder, register_custom_encoder

PIXEL_WISE_TASKS = ["segmentation", "regression"]
SCALAR_TASKS = ["classification"]
Expand All @@ -26,6 +28,7 @@
class DecoderNotFoundError(Exception):
    """Exception type signalling that a decoder could not be found."""


@register_factory
class PrithviModelFactory(ModelFactory):
def build_model(
Expand All @@ -34,7 +37,8 @@ def build_model(
backbone: str | nn.Module,
decoder: str | nn.Module,
bands: list[HLSBands | int],
in_channels: int | None = None, # this should be removed, can be derived from bands. But it is a breaking change
in_channels: int
| None = None, # this should be removed, can be derived from bands. But it is a breaking change
num_classes: int | None = None,
pretrained: bool = True, # noqa: FBT001, FBT002
num_frames: int = 1,
Expand Down Expand Up @@ -96,6 +100,10 @@ def build_model(
raise NotImplementedError(msg)

backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_")
# These params are used in case we need a SMP decoder
# but should not be used for timm encoder
output_stride = backbone_kwargs.pop("output_stride", None)
out_channels = backbone_kwargs.pop("out_channels", None)

backbone: nn.Module = timm.create_model(
backbone,
Expand All @@ -106,14 +114,24 @@ def build_model(
features_only=True,
**backbone_kwargs,
)
# allow decoder to be a module passed directly
decoder_cls = _get_decoder(decoder)

decoder_kwargs, kwargs = _extract_prefix_keys(kwargs, "decoder_")

# TODO: remove this
decoder: nn.Module = decoder_cls(backbone.feature_info.channels(), **decoder_kwargs)
# decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs)
if decoder.startswith("smp_"):
decoder: nn.Module = _get_smp_decoder(
decoder,
backbone_kwargs,
decoder_kwargs,
out_channels,
in_channels,
num_classes,
output_stride,
)
else:
# allow decoder to be a module passed directly
decoder_cls = _get_decoder(decoder)
decoder: nn.Module = decoder_cls(backbone.feature_info.channels(), **decoder_kwargs)
# decoder: nn.Module = decoder_cls([128, 256, 512, 1024], **decoder_kwargs)

head_kwargs, kwargs = _extract_prefix_keys(kwargs, "head_")
if num_classes:
Expand Down Expand Up @@ -148,6 +166,46 @@ def build_model(
)


class SMPDecoderForPrithviWrapper(nn.Module):
    """Adapter that forwards selected embeddings to a wrapped SMP decoder.

    The wrapper receives an indexable collection of embeddings, picks the
    entries named by ``in_index`` and hands them to the decoder as positional
    arguments.

    Attributes:
        decoder: The wrapped decoder; it is called with the selected embeddings.
        channels: The channel count supplied at construction, exposed through
            ``output_embed_dim``.
        in_index: A single integer index, or an iterable of indices, selecting
            which embeddings are passed on.
    """

    def __init__(self, decoder, num_channels, in_index=-1) -> None:
        """Store the decoder together with its selection settings.

        Args:
            decoder: The SMP decoder module to wrap.
            num_channels: The number of output channels of the decoder.
            in_index: Index or indices of the input embeddings to forward to
                the decoder. Defaults to -1 (the last embedding).
        """
        super().__init__()
        self.decoder = decoder
        self.channels = num_channels
        self.in_index = in_index

    @property
    def output_embed_dim(self):
        """Return the channel count given at construction time."""
        return self.channels

    def forward(self, x):
        """Call the decoder on the embeddings selected by ``in_index``.

        Args:
            x: Indexable collection of embeddings.

        Returns:
            Whatever the wrapped decoder returns for the selected embeddings.
        """
        # A lone integer is treated as a one-element selection.
        indices = [self.in_index] if isinstance(self.in_index, int) else self.in_index
        return self.decoder(*(x[idx] for idx in indices))


def _build_appropriate_model(
task: str,
backbone: nn.Module,
Expand Down Expand Up @@ -178,6 +236,82 @@ def _build_appropriate_model(
)


def _get_smp_decoder(
    decoder: str,
    backbone_kwargs: dict,
    decoder_kwargs: dict,
    out_channels: list[int] | int,
    in_channels: int,
    num_classes: int,
    output_stride: int,
):
    """Build a Segmentation Models Pytorch (SMP) decoder wrapped for use with a custom backbone.

    A full SMP model is created around a registered placeholder encoder so that
    SMP sizes the decoder from the encoder parameters; only the model's
    ``decoder`` attribute is kept and wrapped in ``SMPDecoderForPrithviWrapper``.

    Args:
        decoder (str): Name of the SMP model class to use, optionally prefixed with ``smp_``.
        backbone_kwargs (dict): Parameters for configuring the backbone. This dict is not
            modified; a copy extended with ``out_channels`` and ``output_stride`` is used
            when registering the placeholder encoder.
        decoder_kwargs (dict): Decoder parameters. Keys prefixed with ``aux_`` are passed
            as ``aux_params``; keys prefixed with ``smp_`` are passed to the SMP model
            constructor (``smp_encoder_depth`` is required); ``in_index`` selects the
            embeddings handed to the decoder and defaults to -1.
        out_channels (Union[list[int], int]): Output channels of the encoder. When a
            sequence is given, its last entry is reported as the wrapper's channel count;
            when a single int is given, that value is used directly.
        in_channels (int): The number of input channels.
        num_classes (int): The number of output classes for the model.
        output_stride (int): The output stride passed to the placeholder encoder.

    Returns:
        SMPDecoderForPrithviWrapper: The wrapped SMP decoder.

    Raises:
        ValueError: If the specified decoder is not an attribute of the SMP package.
        KeyError: If ``smp_encoder_depth`` is missing from ``decoder_kwargs``.
    """
    decoder = decoder.removeprefix("smp_")
    decoder_module = getattr(smp, decoder, None)
    if decoder_module is None:
        msg = f"Decoder {decoder} is not supported in SMP."
        raise ValueError(msg)

    # Little hack to make the SMP model accept our encoder: a dummy encoder is
    # registered so that the encoder parameters reach the SMP constructor.
    aux_kwargs, decoder_kwargs = _extract_prefix_keys(decoder_kwargs, "aux_")
    smp_kwargs, decoder_kwargs = _extract_prefix_keys(decoder_kwargs, "smp_")
    # SMP expects ``aux_params=None`` when no auxiliary head is requested.
    aux_kwargs = aux_kwargs or None

    # Work on a copy: the caller popped these keys from ``backbone_kwargs`` on
    # purpose, so they must not be written back into the caller's dict.
    encoder_kwargs = {
        **backbone_kwargs,
        "out_channels": out_channels,
        "output_stride": output_stride,
    }

    dummy_encoder = make_smp_encoder()

    register_custom_encoder(dummy_encoder, encoder_kwargs, None)

    # Kept from the original flow; the resulting instance is not used below.
    dummy_encoder = dummy_encoder(
        depth=smp_kwargs["encoder_depth"],
        output_stride=output_stride,
        out_channels=out_channels,
    )

    model_args = {
        "encoder_name": "SMPEncoderWrapperWithPFFIM",
        "encoder_weights": None,
        "in_channels": in_channels,
        "classes": num_classes,
        **smp_kwargs,
    }

    # Creates the model with the dummy encoder and the real decoder.
    model = decoder_module(**model_args, aux_params=aux_kwargs)

    # ``out_channels`` may be a single int (see the annotation); indexing it would fail.
    num_channels = out_channels if isinstance(out_channels, int) else out_channels[-1]

    return SMPDecoderForPrithviWrapper(
        decoder=model.decoder,
        num_channels=num_channels,
        # Fall back to the wrapper's own default when no index is configured.
        in_index=decoder_kwargs.get("in_index", -1),
    )


def _get_decoder(decoder: str | nn.Module) -> nn.Module:
if isinstance(decoder, nn.Module):
return decoder
Expand All @@ -197,7 +331,7 @@ def _extract_prefix_keys(d: dict, prefix: str) -> dict:
remaining_dict = {}
for k, v in d.items():
if k.startswith(prefix):
extracted_dict[k.split(prefix)[1]] = v
extracted_dict[k[len(prefix) :]] = v
else:
remaining_dict[k] = v

Expand Down
Loading
Loading