@@ -186,6 +186,7 @@ def __init__(
         # MAE encoder specifics
         self.patch_embed = PatchEmbed(pretrain_img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim)
         self.patch_size = patch_size
+        self.tubelet_size = tubelet_size
         self.feature_info = []
         self.in_chans = in_chans
         self.num_frames = num_frames
@@ -347,7 +348,7 @@ def forward_encoder(
         x = x.reshape(-1, self.in_chans, 1, *x.shape[-2:])
         t, h, w = x.shape[-3:]
         x = self.patch_embed(x)
-        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
+        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t // self.tubelet_size, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
             x
         )
         # add pos embed w/o cls token
@@ -378,7 +379,7 @@ def forward_decoder(self, x: torch.Tensor, ids_restore: torch.Tensor, dim_info:
         x = self.decoder_embed(x)
         t, h, w = dim_info
         decoder_pos_embed = torch.from_numpy(
-            get_3d_sincos_pos_embed(self.decoder_embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)
+            get_3d_sincos_pos_embed(self.decoder_embed_dim, (t // self.tubelet_size, h // self.patch_size, w // self.patch_size), cls_token=True)
         ).to(x)

         # append mask tokens to sequence
@@ -436,7 +437,7 @@ def forward_features(self, x) -> list[torch.Tensor]:
         t, h, w = x.shape[-3:]
         # embed patches
         x = self.patch_embed(x)
-        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
+        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t // self.tubelet_size, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
             x
         )
         # add pos embed w/o cls token
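Why the change matters: patch_embed collapses every tubelet_size frames into a single token, so the positional-embedding grid needs t // tubelet_size temporal positions to match the token sequence. A minimal sketch of the mismatch, using hypothetical shapes rather than the repository's defaults:

# Hypothetical example shapes, for illustration only.
num_frames, tubelet_size, patch_size, img_size = 4, 2, 16, 224

# patch_embed yields one token per (tubelet_size x patch_size x patch_size)
# cell, so the token grid is (t // tubelet_size, H // patch_size, W // patch_size).
tokens = (num_frames // tubelet_size) * (img_size // patch_size) ** 2

# Before the fix, the embedding grid used the full num_frames as its temporal
# extent, producing tubelet_size times too many positional rows for x:
old_rows = num_frames * (img_size // patch_size) ** 2
assert old_rows == tubelet_size * tokens  # 784 rows vs. 392 tokens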