Skip to content

Commit 233a5c2

Browse files
committed Aug 9, 2024
The number of channels can be estimated in different ways according to the kind of bands definition
Signed-off-by: João Lucas de Sousa Almeida <joao.l.sa.9.3@gmail.com>
1 parent ee17c96 commit 233a5c2

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed
 

‎terratorch/models/backbones/prithvi_swin.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from terratorch.datasets.utils import HLSBands
1919
from terratorch.models.backbones.prithvi_select_patch_embed_weights import prithvi_select_patch_embed_weights
2020
from terratorch.models.backbones.swin_encoder_decoder import MMSegSwinTransformer
21+
from terratorch.models.backbones.utils import _estimate_in_chans
2122

2223
PRETRAINED_BANDS = [
2324
HLSBands.BLUE,
@@ -174,7 +175,8 @@ def _create_swin_mmseg_transformer(
174175
# the current swin model is not multitemporal
175176
if "num_frames" in kwargs:
176177
kwargs = {k: v for k, v in kwargs.items() if k != "num_frames"}
177-
kwargs["in_chans"] = len(model_bands)
178+
179+
kwargs["in_chans"] = _estimate_in_chans(model_bands=model_bands)
178180

179181
def checkpoint_filter_wrapper_fn(state_dict, model):
180182
return checkpoint_filter_fn(state_dict, model, pretrained_bands, model_bands)

‎terratorch/models/backbones/prithvi_vit.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from terratorch.datasets import HLSBands
1515
from terratorch.models.backbones.prithvi_select_patch_embed_weights import prithvi_select_patch_embed_weights
1616
from terratorch.models.backbones.vit_encoder_decoder import TemporalViTEncoder
17+
from terratorch.models.backbones.utils import _estimate_in_chans
1718

1819
PRETRAINED_BANDS = [
1920
HLSBands.BLUE,
@@ -81,7 +82,7 @@ def _create_prithvi(
8182
if "features_only" in kwargs:
8283
kwargs = {k: v for k, v in kwargs.items() if k != "features_only"}
8384

84-
kwargs["in_chans"] = len(model_bands)
85+
kwargs["in_chans"] = _estimate_in_chans(model_bands=model_bands)
8586

8687
def checkpoint_filter_wrapper_fn(state_dict, model):
8788
return checkpoint_filter_fn(state_dict, model, pretrained_bands, model_bands)

‎terratorch/tasks/segmentation_tasks.py

-1
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,6 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
254254
x = batch["image"]
255255
y = batch["mask"]
256256
model_output: ModelOutput = self(x)
257-
print(x.shape, model_output.output.shape, y.shape)
258257
loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
259258
self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
260259
y_hat_hard = to_segmentation_prediction(model_output)

0 commit comments

Comments (0)
Please sign in to comment.