@@ -409,6 +409,7 @@ def __init__(
             qk_norm: bool = False,
             init_values: Optional[float] = None,
             class_token: bool = True,
+            pos_embed: str = 'learn',
             no_embed_class: bool = False,
             reg_tokens: int = 0,
             pre_norm: bool = False,
@@ -460,6 +461,7 @@ def __init__(
         super().__init__()
         assert global_pool in ('', 'avg', 'token', 'map')
         assert class_token or global_pool != 'token'
+        assert pos_embed in ('', 'none', 'learn')
         use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
         act_layer = get_act_layer(act_layer) or nn.GELU
@@ -494,7 +496,10 @@ def __init__(
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
         self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
         embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
-        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
+        if not pos_embed or pos_embed == 'none':
+            self.pos_embed = None
+        else:
+            self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
         self.pos_drop = nn.Dropout(p=pos_drop_rate)
         if patch_drop_rate > 0:
             self.patch_drop = PatchDropout(
@@ -556,7 +561,8 @@ def rescale(param, _layer_id):
     def init_weights(self, mode: str = '') -> None:
         assert mode in ('jax', 'jax_nlhb', 'moco', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
-        trunc_normal_(self.pos_embed, std=.02)
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
         if self.cls_token is not None:
             nn.init.normal_(self.cls_token, std=1e-6)
         named_apply(get_init_weights_vit(mode, head_bias), self)
@@ -583,6 +589,8 @@ def group_matcher(self, coarse: bool = False) -> Dict:
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable: bool = True) -> None:
         self.grad_checkpointing = enable
+        if hasattr(self.patch_embed, 'set_grad_checkpointing'):
+            self.patch_embed.set_grad_checkpointing(enable)
 
     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
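
The hasattr check above establishes a duck-typed contract rather than a formal interface: any patch embed module that exposes its own set_grad_checkpointing will now be toggled together with the transformer blocks. A rough sketch of a conforming module (the class and attribute names below are illustrative, not taken from timm):

import torch.nn as nn

class MyHybridPatchEmbed(nn.Module):
    """Illustrative patch embed wrapping a heavy conv backbone of its own."""
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.grad_checkpointing = False

    def set_grad_checkpointing(self, enable: bool = True):
        # Reached via the hasattr() check in VisionTransformer.set_grad_checkpointing().
        self.grad_checkpointing = enable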
@@ -600,6 +608,9 @@ def reset_classifier(self, num_classes: int, global_pool=None) -> None:
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
+        if self.pos_embed is None:
+            return x
+
         if self.dynamic_img_size:
             B, H, W, C = x.shape
             pos_embed = resample_abs_pos_embed(
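
Together with the constructor change above, this early return is what makes pos_embed='none' a complete no-op path through the model. A minimal usage sketch, assuming the new argument is forwarded through timm.create_model's model kwargs like the other VisionTransformer constructor options:

import timm
import torch

# Sketch under the assumption that create_model forwards 'pos_embed'
# to VisionTransformer.__init__ along with the other model kwargs.
model = timm.create_model('vit_base_patch16_224', pretrained=False, pos_embed='none')
assert model.pos_embed is None  # no learned position embedding was allocated

x = torch.randn(1, 3, 224, 224)
y = model(x)  # _pos_embed() returns the tokens unchanged; forward still works
print(y.shape)  # torch.Size([1, 1000])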
@@ -1066,10 +1077,13 @@ def checkpoint_filter_fn(
         # IJEPA, vit in an 'encoder' submodule
         state_dict = state_dict['encoder']
         prefix = 'module.'
-    elif 'visual.trunk.pos_embed' in state_dict:
+    elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict:
         # OpenCLIP model with timm vision encoder
-        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
         prefix = 'visual.trunk.'
+        if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear):
+            # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
+            out_dict['head.weight'] = state_dict['visual.head.proj.weight']
+            out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
 
     if prefix:
         # filter on & remove prefix string from keys
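
On the head remap: OpenCLIP's visual.head.proj is typically a bias-free projection, while timm's classifier head is an nn.Linear with a bias, so the hunk above copies the weight and synthesizes a zero bias of matching width. A small standalone sketch of the same shape logic (illustrative tensors, not real checkpoint keys):

import torch
import torch.nn as nn

proj_weight = torch.randn(512, 768)  # stands in for state_dict['visual.head.proj.weight']
head = nn.Linear(768, 512)           # stands in for model.head

# Copy the projection weight and fabricate a zero bias so load_state_dict
# sees a complete (weight, bias) pair for the nn.Linear head.
head.load_state_dict({
    'weight': proj_weight,
    'bias': torch.zeros(proj_weight.shape[0]),
})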