Commit 9a7f9fe

Correcting typo
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent: 4230854

1 file changed (+4 −4)

terratorch/models/backbones/vit_encoder_decoder.py (+4 −4)

--- a/terratorch/models/backbones/vit_encoder_decoder.py
+++ b/terratorch/models/backbones/vit_encoder_decoder.py
@@ -186,7 +186,7 @@ def __init__(
         # MAE encoder specifics
         self.patch_embed = PatchEmbed(pretrain_img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim)
         self.patch_size = patch_size
-        self.tubulet_size = tubulet_size
+        self.tubelet_size = tubelet_size
         self.feature_info = []
         self.in_chans = in_chans
         self.num_frames = num_frames
@@ -348,7 +348,7 @@ def forward_encoder(
         x = x.reshape(-1, self.in_chans, 1, *x.shape[-2:])
         t, h, w = x.shape[-3:]
         x = self.patch_embed(x)
-        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t // self.tubulet_size, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
+        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t // self.tubelet_size, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
             x
         )
         # add pos embed w/o cls token
@@ -379,7 +379,7 @@ def forward_decoder(self, x: torch.Tensor, ids_restore: torch.Tensor, dim_info:
         x = self.decoder_embed(x)
         t, h, w = dim_info
         decoder_pos_embed = torch.from_numpy(
-            get_3d_sincos_pos_embed(self.decoder_embed_dim, (t // self.tubulet_size, h // self.patch_size, w // self.patch_size), cls_token=True)
+            get_3d_sincos_pos_embed(self.decoder_embed_dim, (t // self.tubelet_size, h // self.patch_size, w // self.patch_size), cls_token=True)
         ).to(x)

         # append mask tokens to sequence
@@ -437,7 +437,7 @@ def forward_features(self, x) -> list[torch.Tensor]:
         t, h, w = x.shape[-3:]
         # embed patches
         x = self.patch_embed(x)
-        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t // self.tubulet_size, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
+        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t // self.tubelet_size, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
             x
         )
         # add pos embed w/o cls token
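For context: the four hunks rename the attribute at its assignment and its three read sites in this file, so the commit is a spelling fix that keeps the attribute and its uses consistent. The only expression touched is the grid-size tuple passed to get_3d_sincos_pos_embed. A minimal sketch of how that tuple relates to the input dimensions (the helper name pos_embed_grid is hypothetical, not part of terratorch):

def pos_embed_grid(t: int, h: int, w: int,
                   tubelet_size: int, patch_size: int) -> tuple[int, int, int]:
    # Time is split into tubelets, height/width into square patches, so the
    # token grid has one cell per (tubelet, patch-row, patch-col) triple.
    return (t // tubelet_size, h // patch_size, w // patch_size)

# e.g. a 3-frame 224x224 input with tubelet_size=1 and patch_size=16 gives
# a (3, 14, 14) grid, i.e. 3 * 14 * 14 = 588 patch tokens (+1 cls token,
# since the diff passes cls_token=True).
assert pos_embed_grid(3, 224, 224, 1, 16) == (3, 14, 14)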
