diff --git a/terratorch/models/necks.py b/terratorch/models/necks.py
index 0365fc6b..9c81adf0 100644
--- a/terratorch/models/necks.py
+++ b/terratorch/models/necks.py
@@ -141,6 +141,18 @@ def __init__(self, channel_list: list[int], remove_cls_token=True, effective_tim
         self.remove_cls_token = remove_cls_token
         self.effective_time_dim = effective_time_dim
 
+    def collapse_dims(self, x):
+        """
+        When the encoder output has more than 3 dimensions, it is necessary
+        to reshape it by collapsing the intermediate dimensions into one.
+        """
+        # Flatten every dimension between batch and embedding into a single token axis.
+        batch = x.shape[0]
+        e = x.shape[-1]
+        collapsed_dim = np.prod(x.shape[1:-1])
+
+        return x.reshape(batch, collapsed_dim, e)
+
     def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
         out = []
         for x in features:
@@ -148,9 +160,11 @@ def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
                 x_no_token = x[:, 1:, :]
             else:
                 x_no_token = x
+            x_no_token = self.collapse_dims(x_no_token)
             number_of_tokens = x_no_token.shape[1]
             tokens_per_timestep = number_of_tokens // self.effective_time_dim
             h = int(np.sqrt(tokens_per_timestep))
+
             encoded = rearrange(
                 x_no_token,
                 "batch (t h w) e -> batch (t e) h w",
diff --git a/tests/resources/configs/manufactured-finetune_satlas_upernet.yaml b/tests/resources/configs/manufactured-finetune_satlas_upernet.yaml
new file mode 100644
index 00000000..b881a70a
--- /dev/null
+++ b/tests/resources/configs/manufactured-finetune_satlas_upernet.yaml
@@ -0,0 +1,147 @@
+# lightning.pytorch==2.1.1
+seed_everything: 42
+trainer:
+  accelerator: auto
+  strategy: auto
+  devices: auto
+  num_nodes: 1
+  # precision: 16-mixed
+  logger:
+    class_path: TensorBoardLogger
+    init_args:
+      save_dir: tests/
+      name: all_ecos_random
+  callbacks:
+    - class_path: RichProgressBar
+    - class_path: LearningRateMonitor
+      init_args:
+        logging_interval: epoch
+    - class_path: EarlyStopping
+      init_args:
+        monitor: val/loss
+        patience: 100
+  max_epochs: 1
+  check_val_every_n_epoch: 1
+  log_every_n_steps: 20
+  enable_checkpointing: true
+  default_root_dir: tests/
+data:
+  class_path: GenericNonGeoSegmentationDataModule
+  init_args:
+    batch_size: 2
+    num_workers: 4
+    train_transform:
+      #- class_path: albumentations.HorizontalFlip
+      #  init_args:
+      #    p: 0.5
+      #- class_path: albumentations.Rotate
+      #  init_args:
+      #    limit: 30
+      #    border_mode: 0 # cv2.BORDER_CONSTANT
+      #    value: 0
+      #    # mask_value: 1
+      #    p: 0.5
+      - class_path: ToTensorV2
+    dataset_bands:
+      - 0
+      - BLUE
+      - GREEN
+      - RED
+      - NIR_NARROW
+      - SWIR_1
+      - SWIR_2
+      - 1
+      - 2
+      - 3
+      - 4
+    output_bands:
+      - BLUE
+      - GREEN
+      - RED
+      - NIR_NARROW
+      - SWIR_1
+      - SWIR_2
+    rgb_indices:
+      - 2
+      - 1
+      - 0
+    check_stackability: false
+    train_data_root: tests/resources/inputs
+    train_label_data_root: tests/resources/inputs
+    val_data_root: tests/resources/inputs
+    val_label_data_root: tests/resources/inputs
+    test_data_root: tests/resources/inputs
+    test_label_data_root: tests/resources/inputs
+    img_grep: "segmentation*input*.tif"
+    label_grep: "segmentation*label*.tif"
+    means:
+      - 547.36707
+      - 898.5121
+      - 1020.9082
+      - 2665.5352
+      - 2340.584
+      - 1610.1407
+    stds:
+      - 411.4701
+      - 558.54065
+      - 815.94025
+      - 812.4403
+      - 1113.7145
+      - 1067.641
+    no_label_replace: -1
+    no_data_replace: 0
+    num_classes: 2
+model:
+  class_path: terratorch.tasks.SemanticSegmentationTask
+  init_args:
+    model_args:
+      decoder: UperNetDecoder
+      backbone_pretrained: True
+      backbone: satlas_swin_b_sentinel2_si_ms
+      # backbone: ssl4eol_resnet18_landsat_oli_tirs_toa_moco
+      # backbone_pretrain_img_size: 512
+      # decoder_scale_modules: True
+      # decoder_in_channels: 1024
+      decoder_channels: 256
+      # backbone_in_channels: 6
+      backbone_model_bands:
+        - BLUE
+        - GREEN
+        - RED
+        - NIR_NARROW
+        - SWIR_1
+        - SWIR_2
+      backbone_out_indices:
+        - 1
+        - 3
+        - 5
+        - 7
+      # num_frames: 1
+      num_classes: 2
+      head_dropout: 0.1
+      head_channel_list:
+        - 256
+      necks:
+        - name: ReshapeTokensToImage
+    loss: ce
+    #class_weights:
+    #- 0.55
+    #- 0.45
+    #- 11.106234096692113
+    #- 0.7220430107526882
+    #- 0.6870916961826052
+    #- 0.47476477946375156
+    ignore_index: -1
+    freeze_backbone: true
+    freeze_decoder: false
+    model_factory: EncoderDecoderFactory
+optimizer:
+  class_path: torch.optim.AdamW
+  init_args:
+    lr: 1.5e-5
+    weight_decay: 0.05
+lr_scheduler:
+  class_path: ReduceLROnPlateau
+  init_args:
+    monitor: val/loss
+
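
Note (not part of the patch): a minimal standalone sketch of what the new collapse_dims step does, assuming a multi-temporal encoder output of shape (batch, time, tokens, embed_dim). The hypothetical demo below flattens every dimension between batch and embedding into a single token axis, mirroring the np.prod-based reshape added to ReshapeTokensToImage above.

    import numpy as np
    import torch

    def collapse_dims(x: torch.Tensor) -> torch.Tensor:
        # Flatten all dimensions between batch and embedding into one token axis.
        batch, e = x.shape[0], x.shape[-1]
        return x.reshape(batch, int(np.prod(x.shape[1:-1])), e)

    x = torch.randn(2, 4, 196, 768)  # (batch, time, tokens, embed_dim)
    print(collapse_dims(x).shape)    # torch.Size([2, 784, 768])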