Skip to content

Commit 5cf5876

Browse files
Allowing the configuration to overwrite default arguments from the model constructors
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 3d9b5f5 commit 5cf5876

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

terratorch/models/backbones/prithvi_vit.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
from functools import partial
66
from pathlib import Path
7+
from collections import defaultdict
78

89
import torch
910
from timm.models import FeatureInfo
@@ -140,13 +141,22 @@ def create_prithvi_vit_100(
140141
"norm_layer": partial(nn.LayerNorm, eps=1e-6),
141142
"num_frames": 1,
142143
}
144+
145+
# It is possible to overwrite default parameters using
146+
# config file
147+
kwargs_ = defaultdict()
148+
kwargs_.update(model_args)
149+
kwargs_.update(kwargs)
150+
kwargs_ = dict(kwargs_)
151+
143152
model = _create_prithvi(
144153
model_name,
145154
pretrained=pretrained,
146155
model_bands=bands,
147156
pretrained_bands=pretrained_bands,
148-
**dict(model_args, **kwargs),
157+
**kwargs_,
149158
)
159+
150160
return model
151161

152162

0 commit comments

Comments
 (0)