
Commit 4230854

committed Aug 16, 2024
When tubulet_size > 1, the number of temporal patches must be properly evaluated
Signed-off-by: João Lucas de Sousa Almeida <joao.l.sa.9.3@gmail.com>
1 parent f2d1d58 commit 4230854

1 file changed (+4, -3 lines)
 

terratorch/models/backbones/vit_encoder_decoder.py (+4, -3)
@@ -186,6 +186,7 @@ def __init__(
         # MAE encoder specifics
         self.patch_embed = PatchEmbed(pretrain_img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim)
         self.patch_size = patch_size
+        self.tubulet_size = tubulet_size
         self.feature_info = []
         self.in_chans = in_chans
         self.num_frames = num_frames
@@ -347,7 +348,7 @@ def forward_encoder(
         x = x.reshape(-1, self.in_chans, 1, *x.shape[-2:])
         t, h, w = x.shape[-3:]
         x = self.patch_embed(x)
-        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
+        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t // self.tubulet_size, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
             x
         )
         # add pos embed w/o cls token
@@ -378,7 +379,7 @@ def forward_decoder(self, x: torch.Tensor, ids_restore: torch.Tensor, dim_info:
         x = self.decoder_embed(x)
         t, h, w = dim_info
         decoder_pos_embed = torch.from_numpy(
-            get_3d_sincos_pos_embed(self.decoder_embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)
+            get_3d_sincos_pos_embed(self.decoder_embed_dim, (t // self.tubulet_size, h // self.patch_size, w // self.patch_size), cls_token=True)
         ).to(x)

         # append mask tokens to sequence
@@ -436,7 +437,7 @@ def forward_features(self, x) -> list[torch.Tensor]:
         t, h, w = x.shape[-3:]
         # embed patches
         x = self.patch_embed(x)
-        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
+        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t // self.tubulet_size, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
             x
         )
         # add pos embed w/o cls token
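The sketch below (not part of the commit) illustrates why the temporal dimension passed to get_3d_sincos_pos_embed must be divided by the tubelet size: the patch embedding groups tubelet_size frames into one temporal patch, so the positional-embedding grid needs num_frames // tubelet_size entries along time or its length will not match the token sequence. The concrete values used here (num_frames=4, tubelet_size=2, img_size=224, patch_size=16) are illustrative assumptions, not values taken from the repository.

# Sketch: token count from the patch embedding vs. size of the pos-embed grid.
# All values below are assumed for illustration only.
num_frames = 4      # temporal frames in the input
tubelet_size = 2    # frames grouped into one temporal patch
img_size = 224      # spatial input size
patch_size = 16     # spatial patch size

t_patches = num_frames // tubelet_size   # 2 temporal patches
h_patches = img_size // patch_size       # 14
w_patches = img_size // patch_size       # 14

# Tokens produced by the patch embedding (excluding the cls token):
num_tokens = t_patches * h_patches * w_patches   # 2 * 14 * 14 = 392

# Before the fix the grid was built with the raw frame count t = num_frames,
# giving 4 * 14 * 14 = 784 positions, which cannot be added to the 392 patch
# tokens; dividing by the tubelet size makes the two shapes agree.
print(num_tokens)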
