Skip to content

Commit 4cbf229

Browse files
Merge pull request #25 from IBM/fix/swin_instantiation
Fix/swin instantiation
2 parents da82a3a + 313b9f9 commit 4cbf229

File tree

5 files changed

+9
-51
lines changed

5 files changed

+9
-51
lines changed

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ dependencies = [
3939
"lightly>=1.4.25",
4040
"h5py>=3.10.0",
4141
"geobench>=1.0.0",
42-
"mlflow>=2.12.1"
42+
"mlflow>=2.12.1",
43+
"lightning<=2.2.5"
4344
]
4445

4546
[project.optional-dependencies]
+2-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
11
# import so they get registered
2-
from terratorch.models.backbones.prithvi_vit import TemporalViTEncoder
3-
4-
__all__ = ["TemporalViTEncoder"]
5-
__all__ = ["TemporalViTEncoder"]
6-
__all__ = ["TemporalViTEncoder"]
2+
import terratorch.models.backbones.prithvi_vit
3+
import terratorch.models.backbones.prithvi_swin

src/terratorch/models/backbones/prithvi_swin.py

-40
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,6 @@ def _cfg(file: Path = "", **kwargs) -> dict:
3636
**kwargs,
3737
}
3838

39-
default_cfgs = generate_default_cfgs(
40-
{
41-
"prithvi_swin_90_us": {
42-
"hf_hub_id": "ibm-nasa-geospatial/Prithvi-100M",
43-
"hf_hub_filename": "Prithvi_100M.pt"
44-
}
45-
}
46-
)
47-
4839
def convert_weights_swin2mmseg(ckpt):
4940
# from https://github.com/open-mmlab/mmsegmentation/blob/main/tools/model_converters/swin2mmseg.py
5041
new_ckpt = OrderedDict()
@@ -215,37 +206,6 @@ def prepare_features_for_image_model(x):
215206
return model
216207

217208

218-
@register_model
219-
def prithvi_swin_90_us(
220-
pretrained: bool = False, # noqa: FBT002, FBT001
221-
pretrained_bands: list[HLSBands] | None = None,
222-
bands: list[int] | None = None,
223-
**kwargs,
224-
) -> MMSegSwinTransformer:
225-
"""Prithvi Swin 90M"""
226-
if pretrained_bands is None:
227-
pretrained_bands = PRETRAINED_BANDS
228-
if bands is None:
229-
bands = pretrained_bands
230-
logging.info(
231-
f"Model bands not passed. Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\
232-
Pretrained patch_embed layer may be misaligned with current bands"
233-
)
234-
235-
model_args = {
236-
"patch_size": 4,
237-
"window_size": 7,
238-
"embed_dim": 128,
239-
"depths": (2, 2, 18, 2),
240-
"in_chans": 6,
241-
"num_heads": (4, 8, 16, 32),
242-
}
243-
transformer = _create_swin_mmseg_transformer(
244-
"prithvi_swin_90_us", pretrained_bands, bands, pretrained=pretrained, **dict(model_args, **kwargs)
245-
)
246-
return transformer
247-
248-
249209
@register_model
250210
def prithvi_swin_B(
251211
pretrained: bool = False, # noqa: FBT002, FBT001

tests/test_backbones.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ def input_386():
2828
return torch.ones((1, NUM_CHANNELS, 386, 386))
2929

3030

31-
@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"]) #["prithvi_swin_90_us", "prithvi_vit_100", "prithvi_vit_300"])
31+
@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
3232
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
3333
def test_can_create_backbones_from_timm(model_name, test_input, request):
3434
backbone = timm.create_model(model_name, pretrained=False)
3535
input_tensor = request.getfixturevalue(test_input)
3636
backbone(input_tensor)
3737

3838

39-
@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"])
39+
@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
4040
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
4141
def test_can_create_backbones_from_timm_features_only(model_name, test_input, request):
4242
backbone = timm.create_model(model_name, pretrained=False, features_only=True)

tests/test_prithvi_tasks.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def model_input() -> torch.Tensor:
2020
return torch.ones((1, NUM_CHANNELS, 224, 224))
2121

2222

23-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
23+
@pytest.mark.parametrize("backbone",["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
2424
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
2525
@pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"])
2626
def test_create_segmentation_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
@@ -38,7 +38,7 @@ def test_create_segmentation_task(backbone, decoder, loss, model_factory: Prithv
3838
)
3939

4040

41-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
41+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
4242
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
4343
@pytest.mark.parametrize("loss", ["mae", "rmse", "huber"])
4444
def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
@@ -55,7 +55,7 @@ def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviM
5555
)
5656

5757

58-
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
58+
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
5959
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
6060
@pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"])
6161
def test_create_classification_task(backbone, decoder, loss, model_factory: PrithviModelFactory):

0 commit comments

Comments
 (0)