
Commit 7394f7f

Number of patches must be evaluated using the value defined in the config, not a hard-coded one
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent: bbef2dc

1 file changed: +2 -1


terratorch/models/backbones/vit_encoder_decoder.py

+2 -1

@@ -185,6 +185,7 @@ def __init__(
         # --------------------------------------------------------------------------
         # MAE encoder specifics
         self.patch_embed = PatchEmbed(pretrain_img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim)
+        self.patch_size = patch_size
         self.feature_info = []
         self.in_chans = in_chans
         self.num_frames = num_frames
@@ -435,7 +436,7 @@ def forward_features(self, x) -> list[torch.Tensor]:
         t, h, w = x.shape[-3:]
         # embed patches
         x = self.patch_embed(x)
-        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t, h // 16, w // 16), cls_token=True)).to(
+        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
             x
         )
         # add pos embed w/o cls token
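Storing self.patch_size in __init__ is what lets forward_features rebuild the positional embedding from the configured value instead of the literal 16. The minimal standalone sketch below (not terratorch code; num_patches_per_dim, img_size, and the sample patch sizes are hypothetical) illustrates the mismatch the commit fixes: dividing the input height and width by a hard-coded 16 only matches the token count produced by the patch embedding when the configured patch size happens to be 16.

# Minimal sketch, assuming a square input and a PatchEmbed-style tokenizer
# that yields (img_size // patch_size) tokens per spatial dimension.

def num_patches_per_dim(img_size: int, patch_size: int) -> int:
    """Tokens per spatial dimension produced by a patch-embedding layer."""
    return img_size // patch_size

img_size = 224  # hypothetical input size

for patch_size in (16, 8):  # 8 stands in for any non-default configured value
    tokens = num_patches_per_dim(img_size, patch_size)
    hardcoded_grid = img_size // 16           # grid used before this commit
    configured_grid = img_size // patch_size  # grid used after this commit
    print(
        f"patch_size={patch_size}: tokens/dim={tokens}, "
        f"hard-coded grid={hardcoded_grid} ({'ok' if hardcoded_grid == tokens else 'mismatch'}), "
        f"configured grid={configured_grid} ({'ok' if configured_grid == tokens else 'mismatch'})"
    )

With patch_size=8 the hard-coded grid would describe 14 positions per dimension while the patch embedding actually emits 28, so the positional embedding would no longer line up with the token sequence; using self.patch_size keeps the two consistent.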
