Skip to content

Commit f2d1d58

Browse files
More hardcoded values removed
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 7394f7f commit f2d1d58

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

terratorch/models/backbones/vit_encoder_decoder.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def forward_encoder(
347347
x = x.reshape(-1, self.in_chans, 1, *x.shape[-2:])
348348
t, h, w = x.shape[-3:]
349349
x = self.patch_embed(x)
350-
pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t, h // 16, w // 16), cls_token=True)).to(
350+
pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
351351
x
352352
)
353353
# add pos embed w/o cls token
@@ -378,7 +378,7 @@ def forward_decoder(self, x: torch.Tensor, ids_restore: torch.Tensor, dim_info:
378378
x = self.decoder_embed(x)
379379
t, h, w = dim_info
380380
decoder_pos_embed = torch.from_numpy(
381-
get_3d_sincos_pos_embed(self.decoder_embed_dim, (t, h // 16, w // 16), cls_token=True)
381+
get_3d_sincos_pos_embed(self.decoder_embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)
382382
).to(x)
383383

384384
# append mask tokens to sequence

0 commit comments

Comments
 (0)