Skip to content

Commit c0cf729

Browse files
author
Carlos Gomes
committed
simplify model loading from local ckpt
1 parent 0ef8d9a commit c0cf729

File tree

2 files changed

+15
-83
lines changed

2 files changed

+15
-83
lines changed

examples/confs/sen1floods11_vit_local_ckpt.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ model:
8787
backbone_patch_size: 16
8888
backbone_in_chans: 6
8989
backbone_num_frames: 1
90-
backbone_checkpoint_path: examples/Prithvi_100M.pt
91-
# backbone_window_size: 8
90+
backbone_pretrained_cfg_overlay:
91+
file: examples/Prithvi_100M.pt
9292
decoder_channels: 256
9393
in_channels: 6
9494
bands:

src/terratorch/models/prithvi_model_factory.py

+13-81
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import importlib
21
from collections.abc import Callable
3-
from typing import Any
42

53
import timm
4+
import torch
65
from torch import nn
7-
import torch
86

97
import terratorch.models.decoders as decoder_registry
108
from terratorch.datasets import HLSBands
@@ -15,11 +13,8 @@
1513
ModelFactory,
1614
register_factory,
1715
)
18-
1916
from terratorch.models.pixel_wise_model import PixelWiseModel
2017
from terratorch.models.scalar_output_model import ScalarOutputModel
21-
from terratorch.models.backbones.vit_encoder_decoder import TemporalViTEncoder
22-
from terratorch.models.backbones.prithvi_vit import checkpoint_filter_fn
2318

2419
PIXEL_WISE_TASKS = ["segmentation", "regression"]
2520
SCALAR_TASKS = ["classification"]
@@ -29,40 +24,6 @@
2924
class DecoderNotFoundError(Exception):
3025
pass
3126

32-
class AttrInfo:
33-
34-
def __init__(self, channels:callable):
35-
36-
self.channels = channels
37-
38-
class ModelWrapper(nn.Module):
39-
40-
def __init__(self, **kwargs) -> None:
41-
42-
super(ModelWrapper, self).__init__()
43-
44-
# Little hack because VIT does not support timm's features_only
45-
# so we do it ourselves
46-
encoder_only = kwargs.get("features_only", False)
47-
if "features_only" in kwargs:
48-
kwargs = {k: v for k, v in kwargs.items() if k != "features_only"}
49-
50-
self.config = kwargs
51-
self.model = TemporalViTEncoder(**kwargs, encoder_only=True)
52-
53-
# Sharing methods between model and wrapper
54-
self.forward = self.model.forward_features
55-
self.encode_decode_forward = self.model.forward
56-
self.prepare_features_for_image_model = self.model.prepare_features_for_image_model
57-
58-
# Adaptation to allow the wrapper to acess information about
59-
# the model
60-
self.feature_info = AttrInfo(channels=self.channels)
61-
62-
def channels(self):
63-
return self.config["num_heads"]*[self.config["embed_dim"]]
64-
65-
6627
@register_factory
6728
class PrithviModelFactory(ModelFactory):
6829
def build_model(
@@ -106,14 +67,12 @@ def build_model(
10667
aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead deciders to be added to the model.
10768
These decoders take the input from the encoder as well.
10869
rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
109-
is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.
70+
is different from the ground truth. Only applicable to pixel wise models
71+
(e.g. segmentation, pixel wise regression). Defaults to True.
11072
111-
Raises:
112-
NotImplementedError: _description_
113-
DecoderNotFoundException: _description_
11473
11574
Returns:
116-
nn.Module: _description_
75+
nn.Module: Full model with encoder, decoder and head.
11776
"""
11877
if not torch.cuda.is_available():
11978
self.CPU_ONLY = True
@@ -136,39 +95,15 @@ def build_model(
13695

13796
backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_")
13897

139-
if "checkpoint_path" in backbone_kwargs:
140-
checkpoint_path = backbone_kwargs.pop("checkpoint_path")
141-
else:
142-
checkpoint_path = None
143-
144-
# Currently the Prithvi restoring needs access to some
145-
# registries which are not always available on all the
146-
# system. The alternative is to load from a local
147-
# checkpoint.
148-
try:
149-
backbone: nn.Module = timm.create_model(
150-
backbone,
151-
pretrained=pretrained,
152-
in_chans=in_channels, # this can be removed, can be derived from bands. But is a breaking change.
153-
num_frames=num_frames,
154-
bands=bands,
155-
features_only=True,
156-
**backbone_kwargs,
157-
)
158-
except Exception:
159-
160-
print(f"Trying to load from the path defined in the config file, {checkpoint_path}.")
161-
backbone = ModelWrapper(**backbone_kwargs, features_only=True)
162-
163-
if self.CPU_ONLY:
164-
model_dict = torch.load(checkpoint_path, map_location="cpu")
165-
else:
166-
model_dict = torch.load(checkpoint_path)
167-
168-
model_dict = checkpoint_filter_fn(model_dict, model=backbone.model, pretrained_bands=bands, model_bands=bands)
169-
170-
backbone.model.load_state_dict(model_dict, strict=False)
171-
98+
backbone: nn.Module = timm.create_model(
99+
backbone,
100+
pretrained=pretrained,
101+
in_chans=in_channels, # this can be removed, can be derived from bands. But is a breaking change.
102+
num_frames=num_frames,
103+
bands=bands,
104+
features_only=True,
105+
**backbone_kwargs,
106+
)
172107
# allow decoder to be a module passed directly
173108
decoder_cls = _get_decoder(decoder)
174109

@@ -192,13 +127,10 @@ def build_model(
192127
aux_decoder_cls: nn.Module = _get_decoder(aux_decoder.decoder)
193128
aux_decoder_kwargs, kwargs = _extract_prefix_keys(args, "decoder_")
194129
aux_decoder_instance = aux_decoder_cls(backbone.feature_info.channels(), **aux_decoder_kwargs)
195-
# aux_decoder_instance = aux_decoder_cls([128, 256, 512, 1024], **decoder_kwargs)
196130

197131
aux_head_kwargs, kwargs = _extract_prefix_keys(args, "head_")
198132
if num_classes:
199133
aux_head_kwargs["num_classes"] = num_classes
200-
# aux_head: nn.Module = _get_head(task, aux_decoder_instance, num_classes=num_classes, **head_kwargs)
201-
# aux_decoder.decoder = nn.Sequential(aux_decoder_instance, aux_head)
202134
to_be_aux_decoders.append(
203135
AuxiliaryHeadWithDecoderWithoutInstantiatedHead(aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
204136
)

0 commit comments

Comments
 (0)