From bb96834b8ac0b1e734f6b57ac3c63c50508c6415 Mon Sep 17 00:00:00 2001 From: daniszw Date: Mon, 3 Mar 2025 15:59:21 -0300 Subject: [PATCH] - Fixing issue with swapped H/W in interpolate_pos_encoding - Extracting interpolate_pos_encoding and creating general function _interpolate_pos_encoding with parameters to work both for encoder and decoder. - small fix in log message in _create_prithvi. --- terratorch/models/backbones/prithvi_mae.py | 144 ++++++++++----------- terratorch/models/backbones/prithvi_vit.py | 2 +- 2 files changed, 73 insertions(+), 73 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index 2fc3a4b2..ca95177f 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -28,6 +28,7 @@ logger = logging.getLogger(__name__) + def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): """ Create 3D sin/cos positional embeddings. @@ -125,6 +126,50 @@ def _init_weights(module): module.weight.data.fill_(1.0) +def _interpolate_pos_encoding( + pos_embed: torch.Tensor, + grid_size: tuple[int, int, int] | list[int], + patch_size: tuple[int, int, int] | list[int], + shape: tuple[int, int, int], + embed_dim: int, +): + """ + Adapted from: + - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding, + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194 + """ + t, h, w = shape + t_patches = t // patch_size[0] + h_patches = h // patch_size[1] + w_patches = w // patch_size[2] + + if [t_patches, h_patches, w_patches] == grid_size: + # No interpolation needed + return pos_embed + if t_patches != grid_size[0]: + # Re-compute pos embedding to handle changed num_frames + new_grid_size = (t_patches, *grid_size[1:]) + new_pos_embed = get_3d_sincos_pos_embed(pos_embed.shape[-1], new_grid_size, add_cls_token=True) + new_pos_embed = torch.from_numpy(new_pos_embed).float().unsqueeze(0) + else: + new_grid_size = grid_size + new_pos_embed = pos_embed + + class_pos_embed, patch_pos_embed = new_pos_embed[:, :1], new_pos_embed[:, 1:] + + patch_pos_embed = patch_pos_embed.reshape(*new_grid_size, embed_dim).permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(h_patches, w_patches), + mode='bicubic', + align_corners=True, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + class PatchEmbed(nn.Module): """3D version of timm.models.vision_transformer.PatchEmbed""" def __init__( @@ -332,39 +377,16 @@ def random_masking(self, sequence, mask_ratio, noise=None): return sequence_unmasked, mask, ids_restore - def interpolate_pos_encoding(self, t, w, h): - """ - Adapted from: - - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding, - - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194 - """ - t_patches = t // self.patch_embed.patch_size[0] - w_patches = w // self.patch_embed.patch_size[1] - h_patches = h // self.patch_embed.patch_size[2] - if [t_patches, w_patches, h_patches] == self.patch_embed.grid_size: - # No interpolation needed - return self.pos_embed - if t_patches != self.patch_embed.grid_size[0]: - # Re-compute pos embedding to handle changed num_frames - grid_size = (t_patches, *self.patch_embed.grid_size[1:]) - pos_embed = 
get_3d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size, add_cls_token=True) - pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) - else: - grid_size = self.patch_embed.grid_size - pos_embed = self.pos_embed - - class_pos_embed, patch_pos_embed = pos_embed[:, :1], pos_embed[:, 1:] - - patch_pos_embed = patch_pos_embed.reshape(*grid_size, self.embed_dim).permute(0, 3, 1, 2) + def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]): - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed, - size=(h_patches, w_patches), - mode='bicubic', - align_corners=True, + pos_embed = _interpolate_pos_encoding( + pos_embed=self.pos_embed, + grid_size=self.patch_embed.grid_size, + patch_size=self.patch_embed.patch_size, + shape=sample_shape, + embed_dim=self.embed_dim, ) - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, self.embed_dim) - return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + return pos_embed def forward( self, x: torch.Tensor, @@ -375,12 +397,12 @@ def forward( if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1: # add time dim x = x.unsqueeze(2) - t, h, w = x.shape[-3:] + sample_shape = x.shape[-3:] # embed patches x = self.patch_embed(x) - pos_embed = self.interpolate_pos_encoding(t, h, w) + pos_embed = self.interpolate_pos_encoding(sample_shape) # add pos embed w/o cls token x = x + pos_embed[:, 1:, :] @@ -416,12 +438,12 @@ def forward_features( if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1: # add time dim x = x.unsqueeze(2) - t, h, w = x.shape[-3:] + sample_shape = x.shape[-3:] # embed patches x = self.patch_embed(x) - pos_embed = self.interpolate_pos_encoding(t, h, w) + pos_embed = self.interpolate_pos_encoding(sample_shape) # add pos embed w/o cls token x = x + pos_embed[:, 1:, :] @@ -528,39 +550,17 @@ def initialize_weights(self): torch.nn.init.normal_(self.mask_token, std=0.02) self.apply(_init_weights) - def interpolate_pos_encoding(self, t, w, h): - """ - Adapted from: - - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding, - - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194 - """ - t_patches = t // self.patch_size[0] - w_patches = w // self.patch_size[1] - h_patches = h // self.patch_size[2] - if [t_patches, w_patches, h_patches] == self.grid_size: - # No interpolation needed - return self.pos_embed - if t_patches != self.grid_size[0]: - # Re-compute pos embedding to handle changed num_frames - grid_size = (t_patches, *self.grid_size[1:]) - decoder_pos_embed = get_3d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], grid_size, add_cls_token=True) - decoder_pos_embed = torch.from_numpy(decoder_pos_embed).float().unsqueeze(0) - else: - grid_size = self.grid_size - decoder_pos_embed = self.decoder_pos_embed - - class_pos_embed, patch_pos_embed = decoder_pos_embed[:, :1], decoder_pos_embed[:, 1:] - - patch_pos_embed = patch_pos_embed.reshape(*grid_size, self.decoder_embed_dim).permute(0, 3, 1, 2) + def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]): - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed, - size=(h_patches, w_patches), - mode='bicubic', - align_corners=True, + pos_embed = _interpolate_pos_encoding( + pos_embed=self.decoder_pos_embed, + grid_size=self.grid_size, + patch_size=self.patch_size, + shape=sample_shape, + embed_dim=self.decoder_embed_dim, ) - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, self.decoder_embed_dim) - return 
torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + return pos_embed def forward( self, @@ -581,8 +581,7 @@ def forward( x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device)) # add pos embed - t, h, w = input_size[-3:] - decoder_pos_embed = self.interpolate_pos_encoding(t, w, h) + decoder_pos_embed = self.interpolate_pos_encoding(input_size[-3:]) cls_token = cls_token + decoder_pos_embed[:, :1, :] x = x + decoder_pos_embed[:, 1:, :] @@ -678,7 +677,8 @@ def patchify(self, pixel_values): Pixel values. Returns: - torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: + torch.FloatTensor of shape + `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: Patchified pixel values. """ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size @@ -688,14 +688,13 @@ def patchify(self, pixel_values): patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)', c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w) - return patchified_pixel_values def unpatchify(self, patchified_pixel_values, image_size: tuple[int, int] | None = None): """ Args: patchified_pixel_values (`torch.FloatTensor` of shape - `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: + `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels))`: Patchified pixel values. image_size (`tuple[int, int]`, *optional*): Original image size. @@ -721,7 +720,8 @@ def forward_loss(self, pixel_values, pred, mask): Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`): Pixel values. - pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: + pred (`torch.FloatTensor` of shape + `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: Predicted pixel values. mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Tensor indicating which patches are masked (1) and which are not (0). diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 82431200..bd2fcb8b 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -219,7 +219,7 @@ def _create_prithvi( if loaded_keys.missing_keys: logger.warning(f"Missing keys in ckpt_path {ckpt_path}: {loaded_keys.missing_keys}") if loaded_keys.unexpected_keys: - logger.warning(f"Missing keys in ckpt_path {ckpt_path}: {loaded_keys.missing_keys}") + logger.warning(f"Unexpected keys in ckpt_path {ckpt_path}: {loaded_keys.unexpected_keys}") else: assert variant in pretrained_weights, (f"No pre-trained model found for variant {variant} " f"(pretrained models: {pretrained_weights.keys()})")
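
For reference, a minimal sketch of the extracted helper in use, assuming this patch is applied. The import path follows the file touched above; embed_dim=768 and the 1x14x14 grid are illustrative Prithvi-100M-style defaults, not values taken from the diff:

    import torch
    from terratorch.models.backbones.prithvi_mae import (
        _interpolate_pos_encoding,
        get_3d_sincos_pos_embed,
    )

    embed_dim = 768                 # illustrative value, not from the diff
    grid_size = [1, 14, 14]         # pretraining grid: 1 frame, 224/16 x 224/16 patches
    patch_size = [1, 16, 16]

    # Positional embedding as stored on the model: (1, 1 + t*h*w, embed_dim), cls token first.
    pos_embed = torch.from_numpy(
        get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=True)
    ).float().unsqueeze(0)

    # Non-square input: 1 frame, 160 px high, 224 px wide -> 10 x 14 patch grid.
    new_pos_embed = _interpolate_pos_encoding(
        pos_embed=pos_embed,
        grid_size=grid_size,
        patch_size=patch_size,
        shape=(1, 160, 224),
        embed_dim=embed_dim,
    )

    # 1 cls token + 1*10*14 patch tokens. Before this patch, the encoder passed (t, h, w)
    # into a helper whose signature was (t, w, h), so h_patches and w_patches were swapped
    # and the interpolated grid came back transposed for non-square inputs.
    assert new_pos_embed.shape == (1, 1 + 1 * 10 * 14, embed_dim)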