Commit 5de7e68

Merge pull request #3 from IBM/feature/load_from_local_ckpt
simplify model loading from local ckpt
2 parents: 0ef8d9a + 80b3b6c

File tree

4 files changed: +33 -83 lines changed

- docs/models.md
- docs/quick_start.md
- examples/confs/sen1floods11_vit_local_ckpt.yaml
- src/terratorch/models/prithvi_model_factory.py

docs/models.md (+7)

```diff
@@ -34,6 +34,13 @@ We also provide a model factory that can build a task specific model for a downs
 
 By passing a list of bands being used to the constructor, we automatically filter out unused bands, and randomly initialize weights for new bands that were not pretrained on.
 
+!!! info
+
+    To pass your own path from where to load the weights with the PrithviModelFactory, you can make use of timm's `pretrained_cfg_overlay`.
+    E.g. to pass a local path, you can pass the parameter `backbone_pretrained_cfg_overlay = {"file": "<local_path>"}` to the model factory.
+
+    Besides `file`, you can also pass `url`, `hf_hub_id`, amongst others. Check timm's documentation for full details.
+
 :::terratorch.models.backbones.prithvi_select_patch_embed_weights
 
 ## Decoders
```
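
The documented overlay maps directly onto a factory call. A minimal sketch, assuming the `build_model` keyword names visible in the factory diff further down; the task, backbone, decoder, and band choices are illustrative, not fixed by this commit:

```python
from terratorch.datasets import HLSBands
from terratorch.models import PrithviModelFactory  # import path assumed

factory = PrithviModelFactory()
model = factory.build_model(
    task="segmentation",                # illustrative task
    backbone="prithvi_vit_100",         # assumed timm-registered backbone name
    decoder="FCNDecoder",               # assumed decoder name
    bands=[HLSBands.BLUE, HLSBands.GREEN, HLSBands.RED],
    in_channels=3,
    num_classes=2,
    pretrained=True,
    # The "backbone_" prefix routes this kwarg to timm.create_model as
    # pretrained_cfg_overlay={"file": ...}, so weights load from the local file.
    backbone_pretrained_cfg_overlay={"file": "examples/Prithvi_100M.pt"},
)
```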

docs/quick_start.md (+11)

````diff
@@ -105,6 +105,17 @@ task = PixelwiseRegressionTask(
 
 At this level of abstraction, you can also provide a configuration file (see [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html#lightning-cli)) with all the details of the training. See an example for semantic segmentation below:
 
+!!! info
+
+    To pass your own path from where to load the weights with the PrithviModelFactory, you can make use of timm's `pretrained_cfg_overlay`.
+    E.g. to pass a local path, you can add, under model_args:
+
+    ```yaml
+    backbone_pretrained_cfg_overlay:
+      file: <local_path>
+    ```
+    Besides `file`, you can also pass `url`, `hf_hub_id`, amongst others. Check timm's documentation for full details.
+
 ```yaml title="Configuration file for a Semantic Segmentation Task"
 # lightning.pytorch==2.1.1
 seed_everything: 0
````
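
For orientation, the documented snippet sits under the model section of a LightningCLI config. A hedged sketch of the surrounding structure, loosely following the sen1floods11 example config in the next file; the task class path and backbone name are assumptions, and only the overlay keys come from this commit:

```yaml
model:
  class_path: terratorch.tasks.SemanticSegmentationTask  # assumed task class
  init_args:
    model_factory: PrithviModelFactory
    model_args:
      backbone: prithvi_vit_100          # assumed backbone name
      backbone_pretrained: true
      # Load backbone weights from a local file via timm's
      # pretrained_cfg_overlay instead of downloading them.
      backbone_pretrained_cfg_overlay:
        file: examples/Prithvi_100M.pt
```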

examples/confs/sen1floods11_vit_local_ckpt.yaml (+2, -2)

```diff
@@ -87,8 +87,8 @@ model:
       backbone_patch_size: 16
       backbone_in_chans: 6
       backbone_num_frames: 1
-      backbone_checkpoint_path: examples/Prithvi_100M.pt
-      # backbone_window_size: 8
+      backbone_pretrained_cfg_overlay:
+        file: examples/Prithvi_100M.pt
       decoder_channels: 256
       in_channels: 6
       bands:
```
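
The new keys lean on plain timm behaviour rather than a custom loading path: `timm.create_model` accepts a `pretrained_cfg_overlay` dict that patches the model's pretrained config before weights are resolved. A minimal standalone sketch with placeholder names:

```python
import timm

# pretrained_cfg_overlay patches the model's pretrained config at creation
# time; "file" points weight loading at a local checkpoint, while "url" or
# "hf_hub_id" would point it at a remote source instead.
model = timm.create_model(
    "prithvi_vit_100",  # placeholder: any model name registered with timm
    pretrained=True,
    pretrained_cfg_overlay={"file": "examples/Prithvi_100M.pt"},
)
```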

src/terratorch/models/prithvi_model_factory.py (+13, -81)

```diff
@@ -1,10 +1,8 @@
-import importlib
 from collections.abc import Callable
-from typing import Any
 
 import timm
+import torch
 from torch import nn
-import torch
 
 import terratorch.models.decoders as decoder_registry
 from terratorch.datasets import HLSBands
@@ -15,11 +13,8 @@
     ModelFactory,
     register_factory,
 )
-
 from terratorch.models.pixel_wise_model import PixelWiseModel
 from terratorch.models.scalar_output_model import ScalarOutputModel
-from terratorch.models.backbones.vit_encoder_decoder import TemporalViTEncoder
-from terratorch.models.backbones.prithvi_vit import checkpoint_filter_fn
 
 PIXEL_WISE_TASKS = ["segmentation", "regression"]
 SCALAR_TASKS = ["classification"]
@@ -29,40 +24,6 @@
 class DecoderNotFoundError(Exception):
     pass
 
-class AttrInfo:
-
-    def __init__(self, channels:callable):
-
-        self.channels = channels
-
-class ModelWrapper(nn.Module):
-
-    def __init__(self, **kwargs) -> None:
-
-        super(ModelWrapper, self).__init__()
-
-        # Little hack because VIT does not support timm's features_only
-        # so we do it ourselves
-        encoder_only = kwargs.get("features_only", False)
-        if "features_only" in kwargs:
-            kwargs = {k: v for k, v in kwargs.items() if k != "features_only"}
-
-        self.config = kwargs
-        self.model = TemporalViTEncoder(**kwargs, encoder_only=True)
-
-        # Sharing methods between model and wrapper
-        self.forward = self.model.forward_features
-        self.encode_decode_forward = self.model.forward
-        self.prepare_features_for_image_model = self.model.prepare_features_for_image_model
-
-        # Adaptation to allow the wrapper to acess information about
-        # the model
-        self.feature_info = AttrInfo(channels=self.channels)
-
-    def channels(self):
-        return self.config["num_heads"]*[self.config["embed_dim"]]
-
-
 @register_factory
 class PrithviModelFactory(ModelFactory):
     def build_model(
@@ -106,14 +67,12 @@ def build_model(
             aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead deciders to be added to the model.
                 These decoders take the input from the encoder as well.
             rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
-                is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.
+                is different from the ground truth. Only applicable to pixel wise models
+                (e.g. segmentation, pixel wise regression). Defaults to True.
 
-        Raises:
-            NotImplementedError: _description_
-            DecoderNotFoundException: _description_
 
         Returns:
-            nn.Module: _description_
+            nn.Module: Full model with encoder, decoder and head.
         """
         if not torch.cuda.is_available():
             self.CPU_ONLY = True
@@ -136,39 +95,15 @@ def build_model(
 
         backbone_kwargs, kwargs = _extract_prefix_keys(kwargs, "backbone_")
 
-        if "checkpoint_path" in backbone_kwargs:
-            checkpoint_path = backbone_kwargs.pop("checkpoint_path")
-        else:
-            checkpoint_path = None
-
-        # Currently the Prithvi restoring needs access to some
-        # registries which are not always available on all the
-        # system. The alternative is to load from a local
-        # checkpoint.
-        try:
-            backbone: nn.Module = timm.create_model(
-                backbone,
-                pretrained=pretrained,
-                in_chans=in_channels,  # this can be removed, can be derived from bands. But is a breaking change.
-                num_frames=num_frames,
-                bands=bands,
-                features_only=True,
-                **backbone_kwargs,
-            )
-        except Exception:
-
-            print(f"Trying to load from the path defined in the config file, {checkpoint_path}.")
-            backbone = ModelWrapper(**backbone_kwargs, features_only=True)
-
-            if self.CPU_ONLY:
-                model_dict = torch.load(checkpoint_path, map_location="cpu")
-            else:
-                model_dict = torch.load(checkpoint_path)
-
-            model_dict = checkpoint_filter_fn(model_dict, model=backbone.model, pretrained_bands=bands, model_bands=bands)
-
-            backbone.model.load_state_dict(model_dict, strict=False)
-
+        backbone: nn.Module = timm.create_model(
+            backbone,
+            pretrained=pretrained,
+            in_chans=in_channels,  # this can be removed, can be derived from bands. But is a breaking change.
+            num_frames=num_frames,
+            bands=bands,
+            features_only=True,
+            **backbone_kwargs,
+        )
         # allow decoder to be a module passed directly
         decoder_cls = _get_decoder(decoder)
 
@@ -192,13 +127,10 @@ def build_model(
             aux_decoder_cls: nn.Module = _get_decoder(aux_decoder.decoder)
             aux_decoder_kwargs, kwargs = _extract_prefix_keys(args, "decoder_")
             aux_decoder_instance = aux_decoder_cls(backbone.feature_info.channels(), **aux_decoder_kwargs)
-            # aux_decoder_instance = aux_decoder_cls([128, 256, 512, 1024], **decoder_kwargs)
 
             aux_head_kwargs, kwargs = _extract_prefix_keys(args, "head_")
             if num_classes:
                 aux_head_kwargs["num_classes"] = num_classes
-            # aux_head: nn.Module = _get_head(task, aux_decoder_instance, num_classes=num_classes, **head_kwargs)
-            # aux_decoder.decoder = nn.Sequential(aux_decoder_instance, aux_head)
             to_be_aux_decoders.append(
                 AuxiliaryHeadWithDecoderWithoutInstantiatedHead(aux_decoder.name, aux_decoder_instance, aux_head_kwargs)
             )
```
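
The simplification works because `_extract_prefix_keys` peels every `backbone_`-prefixed kwarg, including `backbone_pretrained_cfg_overlay`, off into `backbone_kwargs`, which is then forwarded verbatim to `timm.create_model`. The real helper lives elsewhere in this repo; the following is only a hypothetical sketch of the assumed behaviour:

```python
def _extract_prefix_keys(kwargs: dict, prefix: str) -> tuple[dict, dict]:
    """Split kwargs into (prefix-stripped matches, remainder).

    Hypothetical sketch: with prefix "backbone_", the dict
    {"backbone_pretrained_cfg_overlay": {...}, "decoder_channels": 256}
    becomes ({"pretrained_cfg_overlay": {...}}, {"decoder_channels": 256}).
    """
    extracted: dict = {}
    remaining: dict = {}
    for key, value in kwargs.items():
        if key.startswith(prefix):
            extracted[key[len(prefix):]] = value  # strip the prefix
        else:
            remaining[key] = value
    return extracted, remaining
```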
