Skip to content

Commit da82a3a

Browse files
Merge pull request #23 from IBM/extend/swin
Extend/swin
2 parents 40bdf73 + 3e1c71b commit da82a3a

File tree

3 files changed

+1382
-1
lines changed

3 files changed

+1382
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
"""This module handles registering prithvi_swin models into timm.
2+
"""
3+
4+
import logging
5+
import math
6+
import warnings
7+
from collections import OrderedDict
8+
from pathlib import Path
9+
10+
import torch
11+
from timm.models import SwinTransformer
12+
from timm.models._builder import build_model_with_cfg
13+
from timm.models._registry import generate_default_cfgs, register_model
14+
from timm.models.swin_transformer import checkpoint_filter_fn as timm_swin_checkpoint_filter_fn
15+
16+
from terratorch.datasets.utils import HLSBands
17+
from terratorch.models.backbones.prithvi_select_patch_embed_weights import prithvi_select_patch_embed_weights
18+
from terratorch.models.backbones.swin_encoder_decoder import MMSegSwinTransformer
19+
20+
# Band order used as the default for ``pretrained_bands`` by the model
# entry points registered in this module.
PRETRAINED_BANDS = [
    HLSBands.BLUE, HLSBands.GREEN, HLSBands.RED,
    HLSBands.NIR_NARROW, HLSBands.SWIR_1, HLSBands.SWIR_2,
]
28+
29+
30+
def _cfg(file: Path = "", **kwargs) -> dict:
31+
return {
32+
"file": file,
33+
"source": "file",
34+
"license": "mit",
35+
# "first_conv": "patch_embed.proj",
36+
**kwargs,
37+
}
38+
39+
# Pretrained configurations known to timm for the models in this module.
# Only "prithvi_swin_90_us" has an entry here.
# NOTE(review): the hub file below is named "Prithvi_100M.pt"; confirm it is
# the checkpoint intended for the Swin backbone.
default_cfgs = generate_default_cfgs(
    {
        "prithvi_swin_90_us": {
            "hf_hub_id": "ibm-nasa-geospatial/Prithvi-100M",
            "hf_hub_filename": "Prithvi_100M.pt",
        },
    },
)
47+
48+
def convert_weights_swin2mmseg(ckpt):
    """Translate Swin checkpoint keys into the MMSeg-style naming scheme.

    Adapted from
    https://github.com/open-mmlab/mmsegmentation/blob/main/tools/model_converters/swin2mmseg.py

    Entries whose key starts with ``head`` are dropped. Keys starting with
    ``layers`` are renamed (attention, MLP and stage prefix) and the
    downsample ``reduction.`` / ``norm.`` tensors are re-ordered. In keys
    starting with ``patch_embed``, ``proj`` becomes ``projection``. Every
    other entry is copied unchanged.

    Args:
        ckpt: Mapping from parameter name to tensor.

    Returns:
        A new ``OrderedDict`` with converted names, in the input's order.
    """

    def _reorder_reduction(weight):
        # Split the input-channel axis into 4 groups, swap the two middle
        # groups, then interleave the groups back into a flat axis.
        # Presumably compensates for a different unfold ordering between
        # the two implementations -- TODO confirm.
        rows, cols = weight.shape
        grouped = weight.reshape(rows, 4, cols // 4)
        return grouped[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(rows, cols)

    def _reorder_norm(param):
        # Same group permutation as above, applied to a 1-D parameter.
        size = param.shape[0]
        grouped = param.reshape(4, size // 4)
        return grouped[[0, 2, 1, 3], :].transpose(0, 1).reshape(size)

    def _convert_stage_entry(name, tensor):
        # Branch order matters: the first matching pattern wins.
        if "attn." in name:
            name = name.replace("attn.", "attn.w_msa.")
        elif "mlp.fc1." in name:
            name = name.replace("mlp.fc1.", "ffn.layers.0.0.")
        elif "mlp.fc2." in name:
            name = name.replace("mlp.fc2.", "ffn.layers.1.")
        elif "mlp." in name:
            name = name.replace("mlp.", "ffn.")
        elif "downsample" in name:
            if "reduction." in name:
                tensor = _reorder_reduction(tensor)
            elif "norm." in name:
                tensor = _reorder_norm(tensor)
        # Only the leading "layers" becomes "stages"; a later "layers"
        # (e.g. in "ffn.layers") must stay as it is.
        return name.replace("layers", "stages", 1), tensor

    converted = OrderedDict()
    for name, tensor in ckpt.items():
        if name.startswith("head"):
            continue
        if name.startswith("layers"):
            name, tensor = _convert_stage_entry(name, tensor)
        elif name.startswith("patch_embed") and "proj" in name:
            name = name.replace("proj", "projection")
        converted[name] = tensor

    return converted
100+
101+
102+
def weights_are_swin_implementation(state_dict: dict[str, torch.Tensor]):
    """Tell whether a checkpoint uses the ``encoder.``-prefixed key layout.

    Returns:
        ``True`` if at least one key starts with ``"encoder."``, else
        ``False`` (including for an empty mapping).
    """
    return any(name.startswith("encoder.") for name in state_dict)
108+
109+
110+
def checkpoint_filter_fn(state_dict: dict[str, torch.Tensor], model: torch.nn.Module, pretrained_bands, model_bands):
    """Adapt a pretrained checkpoint so that it can be loaded into ``model``.

    Steps, in order:

    1. A checkpoint that already contains ``head.fc.weight`` is returned
       untouched.
    2. The weights are unwrapped from a ``"state_dict"`` or ``"model"``
       entry when one is present.
    3. A leading ``"module."`` is stripped from the keys when the first key
       carries it.
    4. If any key starts with ``"encoder."``, that prefix is stripped, keys
       starting with ``"decoder"`` are dropped and the result is passed
       through ``convert_weights_swin2mmseg``. Otherwise only a leading
       ``"backbone."`` is stripped.
    5. Relative position bias tables whose length differs from the model's
       are resized with bicubic interpolation. A mismatch in the number of
       heads only produces a warning.
    6. The model's current ``head.fc`` weight and bias are copied in when
       ``model.head.fc`` has a ``weight`` attribute.
    7. Patch-embedding weights are selected via
       ``prithvi_select_patch_embed_weights``.

    Args:
        state_dict: Loaded checkpoint, either the weights themselves or a
            dict wrapping them under ``"state_dict"`` / ``"model"``.
        model: Model the weights will be loaded into.
        pretrained_bands: Forwarded to ``prithvi_select_patch_embed_weights``.
        model_bands: Forwarded to ``prithvi_select_patch_embed_weights``.

    Returns:
        The filtered state dict.

    Raises:
        KeyError: If a relative position bias table in the checkpoint has no
            counterpart in ``model.state_dict()``.
    """
    if "head.fc.weight" in state_dict:
        return state_dict

    if "state_dict" in state_dict:
        raw_state = state_dict["state_dict"]
    elif "model" in state_dict:
        raw_state = state_dict["model"]
    else:
        raw_state = state_dict

    # Strip the "module." prefix. The default keeps an empty checkpoint from
    # raising StopIteration, and removeprefix leaves keys that lack the
    # prefix intact instead of chopping their first seven characters.
    first_key = next(iter(raw_state), "")
    if first_key.startswith("module."):
        raw_state = {k.removeprefix("module."): v for k, v in raw_state.items()}

    filtered = OrderedDict()
    if weights_are_swin_implementation(raw_state):
        # keep only encoder weights
        for k, v in raw_state.items():
            if k.startswith("encoder."):
                filtered[k.removeprefix("encoder.")] = v
            elif not k.startswith("decoder"):
                filtered[k] = v
        filtered = convert_weights_swin2mmseg(filtered)
    else:
        for k, v in raw_state.items():
            filtered[k.removeprefix("backbone.")] = v

    table_keys = [k for k in filtered if "relative_position_bias_table" in k]
    # Build the model's state dict once rather than once per table.
    model_state = model.state_dict() if table_keys else {}
    for table_key in table_keys:
        table_pretrained = filtered[table_key]
        table_current = model_state[table_key]
        len_pretrained, heads_pretrained = table_pretrained.size()
        len_current, heads_current = table_current.size()
        if heads_pretrained != heads_current:
            warnings.warn(f"Error in loading {table_key}, pass", stacklevel=1)
        elif len_pretrained != len_current:
            # Each table is treated as a square (side x side) grid per head.
            side_pretrained = int(len_pretrained**0.5)
            side_current = int(len_current**0.5)
            resized = torch.nn.functional.interpolate(
                table_pretrained.permute(1, 0).reshape(1, heads_pretrained, side_pretrained, side_pretrained),
                size=(side_current, side_current),
                mode="bicubic",
            )
            filtered[table_key] = resized.view(heads_current, len_current).permute(1, 0).contiguous()

    # Keep the model's own head parameters: they are written into the
    # returned dict, replacing any value the checkpoint had for those keys.
    if hasattr(model.head.fc, "weight"):
        filtered["head.fc.weight"] = model.head.fc.weight.detach().clone()
        filtered["head.fc.bias"] = model.head.fc.bias.detach().clone()

    return prithvi_select_patch_embed_weights(filtered, model, pretrained_bands, model_bands)
169+
170+
171+
def _create_swin_mmseg_transformer(
    variant: str,
    pretrained_bands: list[HLSBands],
    model_bands: list[HLSBands],
    pretrained: bool = False,  # noqa: FBT002, FBT001
    **kwargs,
):
    """Build an ``MMSegSwinTransformer`` through timm's model builder.

    Args:
        variant: Model variant name handed to ``build_model_with_cfg``.
        pretrained_bands: Bands of the pretrained weights; passed to the
            checkpoint filter and stored on the model.
        model_bands: Bands of the model being built; their count becomes
            ``in_chans``.
        pretrained: Forwarded to ``build_model_with_cfg``.
        **kwargs: Further model arguments. ``out_indices`` is extracted for
            the feature config (default: one index per entry of ``depths``),
            ``num_frames`` is discarded and ``in_chans`` is overwritten.

    Returns:
        The built model, with ``pretrained_bands``, ``model_bands`` and
        ``prepare_features_for_image_model`` attached as attributes.
    """
    depths = kwargs.get("depths", (1, 1, 3, 1))
    fallback_indices = tuple(position for position, _ in enumerate(depths))
    out_indices = kwargs.pop("out_indices", fallback_indices)

    # the current swin model is not multitemporal
    kwargs.pop("num_frames", None)
    kwargs["in_chans"] = len(model_bands)

    def _filter_with_bands(state_dict, model):
        # timm calls the filter with (state_dict, model); bind the bands here.
        return checkpoint_filter_fn(state_dict, model, pretrained_bands, model_bands)

    feature_cfg = {"flatten_sequential": True, "out_indices": out_indices}
    model: MMSegSwinTransformer = build_model_with_cfg(
        MMSegSwinTransformer,
        variant,
        pretrained,
        pretrained_filter_fn=_filter_with_bands,
        pretrained_strict=False,
        feature_cfg=feature_cfg,
        **kwargs,
    )
    model.pretrained_bands = pretrained_bands
    model.model_bands = model_bands

    def prepare_features_for_image_model(x):
        # Move the last axis of every feature map to position 1
        # (presumably channels-last -> channels-first; TODO confirm).
        prepared = []
        for feature_map in x:
            prepared.append(feature_map.permute(0, 3, 1, 2).contiguous())
        return prepared

    model.prepare_features_for_image_model = prepare_features_for_image_model
    return model
216+
217+
218+
@register_model
def prithvi_swin_90_us(
    pretrained: bool = False,  # noqa: FBT002, FBT001
    pretrained_bands: list[HLSBands] | None = None,
    bands: list[HLSBands | int] | None = None,
    **kwargs,
) -> MMSegSwinTransformer:
    """Prithvi Swin 90M

    Args:
        pretrained: Whether to load pretrained weights.
        pretrained_bands: Bands of the pretrained weights. Defaults to
            ``PRETRAINED_BANDS``.
        bands: Bands of the model being built. Defaults to
            ``pretrained_bands``; an info message is logged in that case.
        **kwargs: Overrides for, or additions to, the default model
            arguments defined below.

    Returns:
        The constructed backbone.
    """
    if pretrained_bands is None:
        pretrained_bands = PRETRAINED_BANDS
    if bands is None:
        bands = pretrained_bands
        # Adjacent literals instead of a backslash continuation, which put
        # the next line's indentation into the message. The bands reported
        # are the ones actually assumed.
        logging.info(
            "Model bands not passed. Assuming bands are ordered in the same way as %s. "
            "Pretrained patch_embed layer may be misaligned with current bands",
            pretrained_bands,
        )

    model_args = {
        "patch_size": 4,
        "window_size": 7,
        "embed_dim": 128,
        "depths": (2, 2, 18, 2),
        # Replaced by len(bands) inside _create_swin_mmseg_transformer.
        "in_chans": 6,
        "num_heads": (4, 8, 16, 32),
    }
    # Caller-supplied kwargs take precedence over the defaults above.
    transformer = _create_swin_mmseg_transformer(
        "prithvi_swin_90_us", pretrained_bands, bands, pretrained=pretrained, **{**model_args, **kwargs}
    )
    return transformer
247+
248+
249+
@register_model
def prithvi_swin_B(
    pretrained: bool = False,  # noqa: FBT002, FBT001
    pretrained_bands: list[HLSBands] | None = None,
    bands: list[HLSBands | int] | None = None,
    **kwargs,
) -> MMSegSwinTransformer:
    """Prithvi Swin B

    NOTE(review): no ``default_cfgs`` entry for this variant is visible in
    this file; confirm that ``pretrained=True`` is supported.

    Args:
        pretrained: Whether to load pretrained weights.
        pretrained_bands: Bands of the pretrained weights. Defaults to
            ``PRETRAINED_BANDS``.
        bands: Bands of the model being built. Defaults to
            ``pretrained_bands``; an info message is logged in that case.
        **kwargs: Overrides for, or additions to, the default model
            arguments defined below.

    Returns:
        The constructed backbone.
    """
    if pretrained_bands is None:
        pretrained_bands = PRETRAINED_BANDS
    if bands is None:
        bands = pretrained_bands
        # Adjacent literals instead of a backslash continuation, which put
        # the next line's indentation into the message. The bands reported
        # are the ones actually assumed.
        logging.info(
            "Model bands not passed. Assuming bands are ordered in the same way as %s. "
            "Pretrained patch_embed layer may be misaligned with current bands",
            pretrained_bands,
        )

    model_args = {
        "patch_size": 4,
        "window_size": 7,
        "embed_dim": 128,
        "depths": (2, 2, 18, 2),
        # Replaced by len(bands) inside _create_swin_mmseg_transformer.
        "in_chans": 6,
        "num_heads": (4, 8, 16, 32),
    }
    # Caller-supplied kwargs take precedence over the defaults above.
    transformer = _create_swin_mmseg_transformer(
        "prithvi_swin_B", pretrained_bands, bands, pretrained=pretrained, **{**model_args, **kwargs}
    )
    return transformer
278+
279+
280+
@register_model
def prithvi_swin_L(
    pretrained: bool = False,  # noqa: FBT002, FBT001
    pretrained_bands: list[HLSBands] | None = None,
    bands: list[HLSBands | int] | None = None,
    **kwargs,
) -> MMSegSwinTransformer:
    """Prithvi Swin L

    NOTE(review): no ``default_cfgs`` entry for this variant is visible in
    this file; confirm that ``pretrained=True`` is supported.

    Args:
        pretrained: Whether to load pretrained weights.
        pretrained_bands: Bands of the pretrained weights. Defaults to
            ``PRETRAINED_BANDS``.
        bands: Bands of the model being built. Defaults to
            ``pretrained_bands``; an info message is logged in that case.
        **kwargs: Overrides for, or additions to, the default model
            arguments defined below.

    Returns:
        The constructed backbone.
    """
    if pretrained_bands is None:
        pretrained_bands = PRETRAINED_BANDS
    if bands is None:
        bands = pretrained_bands
        # Adjacent literals instead of a backslash continuation, which put
        # the next line's indentation into the message. The bands reported
        # are the ones actually assumed.
        logging.info(
            "Model bands not passed. Assuming bands are ordered in the same way as %s. "
            "Pretrained patch_embed layer may be misaligned with current bands",
            pretrained_bands,
        )

    model_args = {
        "patch_size": 4,
        "window_size": 7,
        "embed_dim": 192,
        "depths": (2, 2, 18, 2),
        # Replaced by len(bands) inside _create_swin_mmseg_transformer.
        "in_chans": 6,
        "num_heads": (6, 12, 24, 48),
    }
    # Caller-supplied kwargs take precedence over the defaults above.
    transformer = _create_swin_mmseg_transformer(
        "prithvi_swin_L", pretrained_bands, bands, pretrained=pretrained, **{**model_args, **kwargs}
    )
    return transformer

0 commit comments

Comments
 (0)