Skip to content

Commit 06d883d

Browse files
Minor improvements
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent aefcf02 commit 06d883d

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

terratorch/models/generic_unet_model_factory.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
from terratorch.models.model import Model, ModelFactory, ModelOutput, register_factory
1111

12-
from mmseg.models.decode_heads import ASPPHead
13-
1412
import importlib
1513

1614
@register_factory
@@ -42,8 +40,13 @@ def build_model(
4240
if task not in ["segmentation", "regression"]:
4341
msg = f"SMP models can only perform pixel wise tasks, but got task {task}"
4442
raise Exception(msg)
43+
44+
try:
45+
mmseg = importlib.import_module("mmseg.models.decode_heads")
46+
except:
47+
raise Exception("The module 'mmseg' is not installed or not accessible via PYTHONPATH.")
4548

46-
model_class = getattr(mmseg.models.decode_heads, model)
49+
model_class = getattr(mmseg, model)
4750

4851
model = model_class(
4952
dilations=dilations

tests/manufactured-finetune_unet.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ model:
9494
init_args:
9595
model_args:
9696
decoder: "unet"
97+
decoder_model: "ASPPHead"
9798
decoder_dilations: [1, 6, 12, 18]
9899
backbone_drop_path_rate: 0.3
99100
# backbone_window_size: 8

0 commit comments

Comments
 (0)