Fixing issues with interpolate_pos_encoding in prithvi #471

Merged
merged 1 commit on Mar 6, 2025
144 changes: 72 additions & 72 deletions terratorch/models/backbones/prithvi_mae.py
@@ -28,6 +28,7 @@

logger = logging.getLogger(__name__)


def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
"""
Create 3D sin/cos positional embeddings.
@@ -125,6 +126,50 @@ def _init_weights(module):
module.weight.data.fill_(1.0)


def _interpolate_pos_encoding(
pos_embed: torch.Tensor,
grid_size: tuple[int, int, int] | list[int],
patch_size: tuple[int, int, int] | list[int],
shape: tuple[int, int, int],
embed_dim: int,
):
"""
Adapted from:
- transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
"""
t, h, w = shape
t_patches = t // patch_size[0]
h_patches = h // patch_size[1]
w_patches = w // patch_size[2]

if [t_patches, h_patches, w_patches] == grid_size:
# No interpolation needed
return pos_embed
if t_patches != grid_size[0]:
# Re-compute pos embedding to handle changed num_frames
new_grid_size = (t_patches, *grid_size[1:])
new_pos_embed = get_3d_sincos_pos_embed(pos_embed.shape[-1], new_grid_size, add_cls_token=True)
new_pos_embed = torch.from_numpy(new_pos_embed).float().unsqueeze(0)
else:
new_grid_size = grid_size
new_pos_embed = pos_embed

class_pos_embed, patch_pos_embed = new_pos_embed[:, :1], new_pos_embed[:, 1:]

patch_pos_embed = patch_pos_embed.reshape(*new_grid_size, embed_dim).permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(h_patches, w_patches),
mode='bicubic',
align_corners=True,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)


class PatchEmbed(nn.Module):
"""3D version of timm.models.vision_transformer.PatchEmbed"""
def __init__(
@@ -332,39 +377,16 @@ def random_masking(self, sequence, mask_ratio, noise=None):

return sequence_unmasked, mask, ids_restore

def interpolate_pos_encoding(self, t, w, h):
"""
Adapted from:
- transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
"""
t_patches = t // self.patch_embed.patch_size[0]
w_patches = w // self.patch_embed.patch_size[1]
h_patches = h // self.patch_embed.patch_size[2]
if [t_patches, w_patches, h_patches] == self.patch_embed.grid_size:
# No interpolation needed
return self.pos_embed
if t_patches != self.patch_embed.grid_size[0]:
# Re-compute pos embedding to handle changed num_frames
grid_size = (t_patches, *self.patch_embed.grid_size[1:])
pos_embed = get_3d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size, add_cls_token=True)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
else:
grid_size = self.patch_embed.grid_size
pos_embed = self.pos_embed

class_pos_embed, patch_pos_embed = pos_embed[:, :1], pos_embed[:, 1:]

patch_pos_embed = patch_pos_embed.reshape(*grid_size, self.embed_dim).permute(0, 3, 1, 2)
def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]):

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(h_patches, w_patches),
mode='bicubic',
align_corners=True,
pos_embed = _interpolate_pos_encoding(
pos_embed=self.pos_embed,
grid_size=self.patch_embed.grid_size,
patch_size=self.patch_embed.patch_size,
shape=sample_shape,
embed_dim=self.embed_dim,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, self.embed_dim)
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
return pos_embed

def forward(
self, x: torch.Tensor,
@@ -375,12 +397,12 @@ def forward(
if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
# add time dim
x = x.unsqueeze(2)
t, h, w = x.shape[-3:]
sample_shape = x.shape[-3:]

# embed patches
x = self.patch_embed(x)

pos_embed = self.interpolate_pos_encoding(t, h, w)
pos_embed = self.interpolate_pos_encoding(sample_shape)
# add pos embed w/o cls token
x = x + pos_embed[:, 1:, :]

@@ -416,12 +438,12 @@ def forward_features(
if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
# add time dim
x = x.unsqueeze(2)
t, h, w = x.shape[-3:]
sample_shape = x.shape[-3:]

# embed patches
x = self.patch_embed(x)

pos_embed = self.interpolate_pos_encoding(t, h, w)
pos_embed = self.interpolate_pos_encoding(sample_shape)
# add pos embed w/o cls token
x = x + pos_embed[:, 1:, :]

@@ -528,39 +550,17 @@ def initialize_weights(self):
torch.nn.init.normal_(self.mask_token, std=0.02)
self.apply(_init_weights)

def interpolate_pos_encoding(self, t, w, h):
"""
Adapted from:
- transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
"""
t_patches = t // self.patch_size[0]
w_patches = w // self.patch_size[1]
h_patches = h // self.patch_size[2]
if [t_patches, w_patches, h_patches] == self.grid_size:
# No interpolation needed
return self.pos_embed
if t_patches != self.grid_size[0]:
# Re-compute pos embedding to handle changed num_frames
grid_size = (t_patches, *self.grid_size[1:])
decoder_pos_embed = get_3d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], grid_size, add_cls_token=True)
decoder_pos_embed = torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
else:
grid_size = self.grid_size
decoder_pos_embed = self.decoder_pos_embed

class_pos_embed, patch_pos_embed = decoder_pos_embed[:, :1], decoder_pos_embed[:, 1:]

patch_pos_embed = patch_pos_embed.reshape(*grid_size, self.decoder_embed_dim).permute(0, 3, 1, 2)
def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]):

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(h_patches, w_patches),
mode='bicubic',
align_corners=True,
pos_embed = _interpolate_pos_encoding(
pos_embed=self.decoder_pos_embed,
grid_size=self.grid_size,
patch_size=self.patch_size,
shape=sample_shape,
embed_dim=self.decoder_embed_dim,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, self.decoder_embed_dim)
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

return pos_embed

def forward(
self,
@@ -581,8 +581,7 @@ def forward(
x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device))

# add pos embed
t, h, w = input_size[-3:]
decoder_pos_embed = self.interpolate_pos_encoding(t, w, h)
decoder_pos_embed = self.interpolate_pos_encoding(input_size[-3:])
cls_token = cls_token + decoder_pos_embed[:, :1, :]
x = x + decoder_pos_embed[:, 1:, :]

@@ -678,7 +677,8 @@ def patchify(self, pixel_values):
Pixel values.

Returns:
torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
torch.FloatTensor of shape
`(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
Patchified pixel values.
"""
patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
@@ -688,14 +688,13 @@
patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)',
c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w)


return patchified_pixel_values

def unpatchify(self, patchified_pixel_values, image_size: tuple[int, int] | None = None):
"""
Args:
patchified_pixel_values (`torch.FloatTensor` of shape
`(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
`(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels))`:
Patchified pixel values.
image_size (`tuple[int, int]`, *optional*):
Original image size.
@@ -721,7 +720,8 @@ def forward_loss(self, pixel_values, pred, mask):
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
Pixel values.
pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
pred (`torch.FloatTensor` of shape
`(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
Predicted pixel values.
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
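For reviewers who want to sanity-check the shared helper locally, here is a minimal smoke-test sketch. It is not part of this PR: the configuration values (embed_dim=768, a 1x14x14 grid, 1x16x16 patches) and the 512x512 sample size are assumptions chosen to resemble a typical Prithvi setup, and it imports the private helper directly from the module touched above.

```python
# Hypothetical smoke test for the refactored positional-encoding helper.
import torch

from terratorch.models.backbones.prithvi_mae import (
    _interpolate_pos_encoding,
    get_3d_sincos_pos_embed,
)

embed_dim = 768
grid_size = (1, 14, 14)    # assumed pretraining grid: 1 frame, 224 / 16 = 14 patches per side
patch_size = (1, 16, 16)

# Build the positional embedding the same way the diff does (NumPy output + cls token).
pos_embed = torch.from_numpy(
    get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=True)
).float().unsqueeze(0)     # shape: (1, 1 + 1*14*14, 768)

# A larger inference sample, (t, h, w) = (1, 512, 512), gives a 1 x 32 x 32 patch grid,
# so the helper should bicubically interpolate the spatial part of the embedding.
new_pos_embed = _interpolate_pos_encoding(
    pos_embed=pos_embed,
    grid_size=grid_size,
    patch_size=patch_size,
    shape=(1, 512, 512),
    embed_dim=embed_dim,
)
assert new_pos_embed.shape == (1, 1 + 32 * 32, embed_dim)
```

With the old per-class methods, the encoder and decoder each carried their own copy of this logic; routing both through one module-level function keeps the two code paths from drifting apart.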
2 changes: 1 addition & 1 deletion terratorch/models/backbones/prithvi_vit.py
@@ -219,7 +219,7 @@ def _create_prithvi(
if loaded_keys.missing_keys:
logger.warning(f"Missing keys in ckpt_path {ckpt_path}: {loaded_keys.missing_keys}")
if loaded_keys.unexpected_keys:
logger.warning(f"Missing keys in ckpt_path {ckpt_path}: {loaded_keys.missing_keys}")
logger.warning(f"Unexpected keys in ckpt_path {ckpt_path}: {loaded_keys.unexpected_keys}")
else:
assert variant in pretrained_weights, (f"No pre-trained model found for variant {variant} "
f"(pretrained models: {pretrained_weights.keys()})")
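The one-line change above corrects the second warning: `load_state_dict(..., strict=False)` reports `missing_keys` and `unexpected_keys` separately, and the old message re-logged the missing keys under both branches. A minimal, self-contained sketch of the corrected logging pattern, using a toy module rather than the Prithvi backbone:

```python
# Illustrative only: a tiny model and a deliberately mismatched checkpoint.
import logging

import torch
from torch import nn

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

model = nn.Linear(4, 2)
# 'bias' is absent (missing key) and 'extra.weight' has no target (unexpected key).
state_dict = {"weight": torch.zeros(2, 4), "extra.weight": torch.zeros(2, 4)}

loaded_keys = model.load_state_dict(state_dict, strict=False)
if loaded_keys.missing_keys:
    logger.warning(f"Missing keys: {loaded_keys.missing_keys}")        # ['bias']
if loaded_keys.unexpected_keys:
    logger.warning(f"Unexpected keys: {loaded_keys.unexpected_keys}")  # ['extra.weight']
```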