Commit cefb8ef

Author: Carlos Gomes (committed)

add swin import, remove swin_90_us

Signed-off-by: Carlos Gomes <[email protected]>

1 parent da82a3a · commit cefb8ef

2 files changed: +2, -45 lines
+2 -5

@@ -1,6 +1,3 @@
 # import so they get registered
-from terratorch.models.backbones.prithvi_vit import TemporalViTEncoder
-
-__all__ = ["TemporalViTEncoder"]
-__all__ = ["TemporalViTEncoder"]
-__all__ = ["TemporalViTEncoder"]
+import terratorch.models.backbones.prithvi_vit
+import terratorch.models.backbones.prithvi_swin
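This hunk swaps a symbol import for bare module imports; as the retained comment says, the modules are imported only so their models get registered as a side effect. A minimal, self-contained sketch of that decorator-registry pattern follows (the registry, builder body, and lookup below are illustrative stand-ins, not terratorch's actual implementation):

```python
# Illustrative sketch only: a timm-style registry where importing the defining
# module is enough to make its models discoverable by name.
from typing import Callable

MODEL_REGISTRY: dict[str, Callable] = {}

def register_model(fn: Callable) -> Callable:
    """Record the builder under its function name at import time."""
    MODEL_REGISTRY[fn.__name__] = fn
    return fn

@register_model
def prithvi_swin_B(**kwargs):
    # Placeholder body; the real builder constructs an MMSegSwinTransformer.
    return ("prithvi_swin_B", kwargs)

# Because the decorator ran when the defining module was imported, a plain
# module import is all the importing package needs for name-based lookup:
print(MODEL_REGISTRY["prithvi_swin_B"](embed_dim=128))
```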

src/terratorch/models/backbones/prithvi_swin.py · -40
@@ -36,15 +36,6 @@ def _cfg(file: Path = "", **kwargs) -> dict:
         **kwargs,
     }

-default_cfgs = generate_default_cfgs(
-    {
-        "prithvi_swin_90_us": {
-            "hf_hub_id": "ibm-nasa-geospatial/Prithvi-100M",
-            "hf_hub_filename": "Prithvi_100M.pt"
-        }
-    }
-)
-
 def convert_weights_swin2mmseg(ckpt):
     # from https://github.com/open-mmlab/mmsegmentation/blob/main/tools/model_converters/swin2mmseg.py
     new_ckpt = OrderedDict()
@@ -215,37 +206,6 @@ def prepare_features_for_image_model(x):
     return model


-@register_model
-def prithvi_swin_90_us(
-    pretrained: bool = False,  # noqa: FBT002, FBT001
-    pretrained_bands: list[HLSBands] | None = None,
-    bands: list[int] | None = None,
-    **kwargs,
-) -> MMSegSwinTransformer:
-    """Prithvi Swin 90M"""
-    if pretrained_bands is None:
-        pretrained_bands = PRETRAINED_BANDS
-    if bands is None:
-        bands = pretrained_bands
-        logging.info(
-            f"Model bands not passed. Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\
-            Pretrained patch_embed layer may be misaligned with current bands"
-        )
-
-    model_args = {
-        "patch_size": 4,
-        "window_size": 7,
-        "embed_dim": 128,
-        "depths": (2, 2, 18, 2),
-        "in_chans": 6,
-        "num_heads": (4, 8, 16, 32),
-    }
-    transformer = _create_swin_mmseg_transformer(
-        "prithvi_swin_90_us", pretrained_bands, bands, pretrained=pretrained, **dict(model_args, **kwargs)
-    )
-    return transformer
-
-
 @register_model
 def prithvi_swin_B(
     pretrained: bool = False,  # noqa: FBT002, FBT001
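The removed prithvi_swin_90_us builder also illustrates the band-handling convention these entry points use: when the caller passes no bands, the model assumes the pretrained band ordering and logs that the pretrained patch_embed layer may be misaligned. A standalone sketch of that defaulting logic (the PRETRAINED_BANDS value and the helper name here are illustrative, not taken from terratorch):

```python
import logging

# Illustrative stand-in for terratorch's PRETRAINED_BANDS constant.
PRETRAINED_BANDS = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]

def resolve_bands(pretrained_bands=None, bands=None):
    """Hypothetical helper mirroring the defaulting seen in the removed builder."""
    if pretrained_bands is None:
        pretrained_bands = PRETRAINED_BANDS
    if bands is None:
        bands = pretrained_bands
        logging.info(
            "Model bands not passed. Assuming bands are ordered in the same way as %s. "
            "Pretrained patch_embed layer may be misaligned with current bands",
            PRETRAINED_BANDS,
        )
    return pretrained_bands, bands

# Example: no bands given, so both fall back to the pretrained ordering.
print(resolve_bands())
```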
