1
- import importlib
2
1
from collections .abc import Callable
3
- from typing import Any
4
2
5
3
import timm
4
+ import torch
6
5
from torch import nn
7
- import torch
8
6
9
7
import terratorch .models .decoders as decoder_registry
10
8
from terratorch .datasets import HLSBands
15
13
ModelFactory ,
16
14
register_factory ,
17
15
)
18
-
19
16
from terratorch .models .pixel_wise_model import PixelWiseModel
20
17
from terratorch .models .scalar_output_model import ScalarOutputModel
21
- from terratorch .models .backbones .vit_encoder_decoder import TemporalViTEncoder
22
- from terratorch .models .backbones .prithvi_vit import checkpoint_filter_fn
23
18
24
19
PIXEL_WISE_TASKS = ["segmentation" , "regression" ]
25
20
SCALAR_TASKS = ["classification" ]
29
24
class DecoderNotFoundError (Exception ):
30
25
pass
31
26
32
- class AttrInfo :
33
-
34
- def __init__ (self , channels :callable ):
35
-
36
- self .channels = channels
37
-
38
- class ModelWrapper (nn .Module ):
39
-
40
- def __init__ (self , ** kwargs ) -> None :
41
-
42
- super (ModelWrapper , self ).__init__ ()
43
-
44
- # Little hack because VIT does not support timm's features_only
45
- # so we do it ourselves
46
- encoder_only = kwargs .get ("features_only" , False )
47
- if "features_only" in kwargs :
48
- kwargs = {k : v for k , v in kwargs .items () if k != "features_only" }
49
-
50
- self .config = kwargs
51
- self .model = TemporalViTEncoder (** kwargs , encoder_only = True )
52
-
53
- # Sharing methods between model and wrapper
54
- self .forward = self .model .forward_features
55
- self .encode_decode_forward = self .model .forward
56
- self .prepare_features_for_image_model = self .model .prepare_features_for_image_model
57
-
58
- # Adaptation to allow the wrapper to acess information about
59
- # the model
60
- self .feature_info = AttrInfo (channels = self .channels )
61
-
62
- def channels (self ):
63
- return self .config ["num_heads" ]* [self .config ["embed_dim" ]]
64
-
65
-
66
27
@register_factory
67
28
class PrithviModelFactory (ModelFactory ):
68
29
def build_model (
@@ -106,14 +67,12 @@ def build_model(
106
67
aux_decoders (list[AuxiliaryHead] | None): List of AuxiliaryHead deciders to be added to the model.
107
68
These decoders take the input from the encoder as well.
108
69
rescale (bool): Whether to apply bilinear interpolation to rescale the model output if its size
109
- is different from the ground truth. Only applicable to pixel wise models (e.g. segmentation, pixel wise regression). Defaults to True.
70
+ is different from the ground truth. Only applicable to pixel wise models
71
+ (e.g. segmentation, pixel wise regression). Defaults to True.
110
72
111
- Raises:
112
- NotImplementedError: _description_
113
- DecoderNotFoundException: _description_
114
73
115
74
Returns:
116
- nn.Module: _description_
75
+ nn.Module: Full model with encoder, decoder and head.
117
76
"""
118
77
if not torch .cuda .is_available ():
119
78
self .CPU_ONLY = True
@@ -136,39 +95,15 @@ def build_model(
136
95
137
96
backbone_kwargs , kwargs = _extract_prefix_keys (kwargs , "backbone_" )
138
97
139
- if "checkpoint_path" in backbone_kwargs :
140
- checkpoint_path = backbone_kwargs .pop ("checkpoint_path" )
141
- else :
142
- checkpoint_path = None
143
-
144
- # Currently the Prithvi restoring needs access to some
145
- # registries which are not always available on all the
146
- # system. The alternative is to load from a local
147
- # checkpoint.
148
- try :
149
- backbone : nn .Module = timm .create_model (
150
- backbone ,
151
- pretrained = pretrained ,
152
- in_chans = in_channels , # this can be removed, can be derived from bands. But is a breaking change.
153
- num_frames = num_frames ,
154
- bands = bands ,
155
- features_only = True ,
156
- ** backbone_kwargs ,
157
- )
158
- except Exception :
159
-
160
- print (f"Trying to load from the path defined in the config file, { checkpoint_path } ." )
161
- backbone = ModelWrapper (** backbone_kwargs , features_only = True )
162
-
163
- if self .CPU_ONLY :
164
- model_dict = torch .load (checkpoint_path , map_location = "cpu" )
165
- else :
166
- model_dict = torch .load (checkpoint_path )
167
-
168
- model_dict = checkpoint_filter_fn (model_dict , model = backbone .model , pretrained_bands = bands , model_bands = bands )
169
-
170
- backbone .model .load_state_dict (model_dict , strict = False )
171
-
98
+ backbone : nn .Module = timm .create_model (
99
+ backbone ,
100
+ pretrained = pretrained ,
101
+ in_chans = in_channels , # this can be removed, can be derived from bands. But is a breaking change.
102
+ num_frames = num_frames ,
103
+ bands = bands ,
104
+ features_only = True ,
105
+ ** backbone_kwargs ,
106
+ )
172
107
# allow decoder to be a module passed directly
173
108
decoder_cls = _get_decoder (decoder )
174
109
@@ -192,13 +127,10 @@ def build_model(
192
127
aux_decoder_cls : nn .Module = _get_decoder (aux_decoder .decoder )
193
128
aux_decoder_kwargs , kwargs = _extract_prefix_keys (args , "decoder_" )
194
129
aux_decoder_instance = aux_decoder_cls (backbone .feature_info .channels (), ** aux_decoder_kwargs )
195
- # aux_decoder_instance = aux_decoder_cls([128, 256, 512, 1024], **decoder_kwargs)
196
130
197
131
aux_head_kwargs , kwargs = _extract_prefix_keys (args , "head_" )
198
132
if num_classes :
199
133
aux_head_kwargs ["num_classes" ] = num_classes
200
- # aux_head: nn.Module = _get_head(task, aux_decoder_instance, num_classes=num_classes, **head_kwargs)
201
- # aux_decoder.decoder = nn.Sequential(aux_decoder_instance, aux_head)
202
134
to_be_aux_decoders .append (
203
135
AuxiliaryHeadWithDecoderWithoutInstantiatedHead (aux_decoder .name , aux_decoder_instance , aux_head_kwargs )
204
136
)
0 commit comments