
Commit 5756a81

Merge remote-tracking branch 'origin/Beckschen-vitamin' into fastvit_mobileclip

2 parents: 7f96538 + 0e77c95

File tree: 5 files changed, +578 −14 lines

tests/test_models.py

Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -71,6 +71,7 @@
 from .vision_transformer_hybrid import *
 from .vision_transformer_relpos import *
 from .vision_transformer_sam import *
+from .vitamin import *
 from .volo import *
 from .vovnet import *
 from .xception import *
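
With the new wildcard import in place, whatever the vitamin module registers becomes visible through the usual timm discovery APIs. A minimal sketch, assuming the vitamin module (added elsewhere in this commit but not shown here) registers its configs with timm's @register_model decorator:

import timm

# The wildcard import runs the vitamin module's @register_model decorators,
# so the new architectures land in timm's global model registry and can be
# listed with the same wildcard pattern the test filter above uses ('vitamin*').
print(timm.list_models('vitamin*'))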

timm/models/vision_transformer.py

Lines changed: 18 additions & 4 deletions

@@ -409,6 +409,7 @@ def __init__(
             qk_norm: bool = False,
             init_values: Optional[float] = None,
             class_token: bool = True,
+            pos_embed: str = 'learn',
             no_embed_class: bool = False,
             reg_tokens: int = 0,
             pre_norm: bool = False,

@@ -460,6 +461,7 @@ def __init__(
         super().__init__()
         assert global_pool in ('', 'avg', 'token', 'map')
         assert class_token or global_pool != 'token'
+        assert pos_embed in ('', 'none', 'learn')
         use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
         act_layer = get_act_layer(act_layer) or nn.GELU

@@ -494,7 +496,10 @@ def __init__(
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
         self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
         embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
-        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
+        if not pos_embed or pos_embed == 'none':
+            self.pos_embed = None
+        else:
+            self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
         self.pos_drop = nn.Dropout(p=pos_drop_rate)
         if patch_drop_rate > 0:
             self.patch_drop = PatchDropout(

@@ -556,7 +561,8 @@ def rescale(param, _layer_id):
     def init_weights(self, mode: str = '') -> None:
         assert mode in ('jax', 'jax_nlhb', 'moco', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
-        trunc_normal_(self.pos_embed, std=.02)
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
         if self.cls_token is not None:
             nn.init.normal_(self.cls_token, std=1e-6)
         named_apply(get_init_weights_vit(mode, head_bias), self)

@@ -583,6 +589,8 @@ def group_matcher(self, coarse: bool = False) -> Dict:
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable: bool = True) -> None:
         self.grad_checkpointing = enable
+        if hasattr(self.patch_embed, 'set_grad_checkpointing'):
+            self.patch_embed.set_grad_checkpointing(enable)

     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:

@@ -600,6 +608,9 @@ def reset_classifier(self, num_classes: int, global_pool = None) -> None:
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

     def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
+        if self.pos_embed is None:
+            return x
+
         if self.dynamic_img_size:
             B, H, W, C = x.shape
             pos_embed = resample_abs_pos_embed(
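
Taken together, the hunks above make the absolute position embedding optional: passing pos_embed='none' (or '') leaves self.pos_embed as None, init_weights() skips its initialization, and _pos_embed() returns its input unchanged. A minimal sketch of building such a model directly; the arguments besides pos_embed are ordinary VisionTransformer settings chosen for illustration, and class_token is disabled so the example does not depend on how prefix tokens interact with the no-pos-embed path. The file's remaining hunk, on checkpoint loading, follows after the sketch.

import torch
from timm.models.vision_transformer import VisionTransformer

# Tiny ViT that relies on its patch embedding alone, with no learned absolute
# position embedding. 'learn' (the default) keeps the previous behaviour.
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    embed_dim=192,
    depth=2,
    num_heads=3,
    class_token=False,
    global_pool='avg',
    pos_embed='none',
)
assert model.pos_embed is None       # _pos_embed() is now a no-op for this model

x = torch.randn(1, 3, 224, 224)
print(model(x).shape)                # torch.Size([1, 1000]) with the default classifier
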
@@ -1066,10 +1077,13 @@ def checkpoint_filter_fn(
         # IJEPA, vit in an 'encoder' submodule
         state_dict = state_dict['encoder']
         prefix = 'module.'
-    elif 'visual.trunk.pos_embed' in state_dict:
+    elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict:
         # OpenCLIP model with timm vision encoder
-        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
         prefix = 'visual.trunk.'
+        if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear):
+            # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
+            out_dict['head.weight'] = state_dict['visual.head.proj.weight']
+            out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])

     if prefix:
         # filter on & remove prefix string from keys
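
This last hunk extends checkpoint_filter_fn for OpenCLIP checkpoints that wrap a timm vision tower under visual.trunk.*: the trunk is now detected even when it has no pos_embed key (hence the extra blocks.0.norm1.weight probe), and a projection stored outside the trunk at visual.head.proj is folded into the timm classifier head. A toy sketch of that remapping on a hand-made state dict; it stands in for the real filter function and is illustration only:

import torch

# Hand-made stand-in for an OpenCLIP checkpoint with a timm trunk and an
# external projection (keys and shapes are invented for the example).
state_dict = {
    'visual.trunk.blocks.0.norm1.weight': torch.ones(8),           # matches the new detection probe
    'visual.trunk.patch_embed.proj.weight': torch.zeros(8, 3, 16, 16),
    'visual.head.proj.weight': torch.zeros(512, 8),                 # projection outside the trunk
}

# Strip the trunk prefix, as checkpoint_filter_fn does for prefixed checkpoints.
prefix = 'visual.trunk.'
out_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

# Mirror of the new head remap: the external projection becomes head.weight,
# with a zero bias of matching output width.
out_dict['head.weight'] = state_dict['visual.head.proj.weight']
out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])

print(sorted(out_dict))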

timm/models/vision_transformer_hybrid.py

Lines changed: 35 additions & 9 deletions

@@ -38,14 +38,15 @@ class HybridEmbed(nn.Module):

     def __init__(
             self,
-            backbone,
-            img_size=224,
-            patch_size=1,
-            feature_size=None,
-            feature_ratio=None,
-            in_chans=3,
-            embed_dim=768,
-            bias=True,
+            backbone: nn.Module,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 1,
+            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
+            in_chans: int = 3,
+            embed_dim: int = 768,
+            bias: bool = True,
+            proj: bool = True,
             flatten: bool = True,
             output_fmt: Optional[str] = None,
             strict_img_size: bool = True,

@@ -95,7 +96,18 @@ def __init__(
         self.strict_img_size = strict_img_size
         self.dynamic_img_pad = dynamic_img_pad

-        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+        if proj:
+            self.proj = nn.Conv2d(
+                feature_dim,
+                embed_dim,
+                kernel_size=patch_size,
+                stride=patch_size,
+                bias=bias,
+            )
+        else:
+            assert feature_dim == embed_dim,\
+                f'The feature dim ({feature_dim}) must match embed dim ({embed_dim}) when projection disabled.'
+            self.proj = nn.Identity()

     def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
         total_reduction = (
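
The new proj flag makes the trailing projection optional: when the hybrid backbone already produces features at the transformer width, the patch-size conv is replaced by nn.Identity(). A minimal sketch with a single conv standing in for a real stem (the backbone here is invented purely for illustration); the grad-checkpointing hunks for this file follow after the sketch.

import torch
import torch.nn as nn
from timm.models.vision_transformer_hybrid import HybridEmbed

# Toy stem: a 16x16/16 conv producing 384-channel features, i.e. already at embed_dim.
backbone = nn.Conv2d(3, 384, kernel_size=16, stride=16)

# proj=False is only valid when feature_dim == embed_dim (the new assert),
# in which case self.proj becomes nn.Identity().
embed = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=384, proj=False)
assert isinstance(embed.proj, nn.Identity)

tokens = embed(torch.randn(2, 3, 224, 224))
print(tokens.shape)   # torch.Size([2, 196, 384]): the 14x14 feature map flattened to tokens
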
@@ -116,6 +128,13 @@ def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
         else:
             return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]

+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True):
+        if hasattr(self.backbone, 'set_grad_checkpointing'):
+            self.backbone.set_grad_checkpointing(enable=enable)
+        elif hasattr(self.backbone, 'grad_checkpointing'):
+            self.backbone.grad_checkpointing = enable
+
     def forward(self, x):
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):

@@ -157,6 +176,13 @@ def __init__(
             bias=bias,
         )

+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True):
+        if hasattr(self.backbone, 'set_grad_checkpointing'):
+            self.backbone.set_grad_checkpointing(enable=enable)
+        elif hasattr(self.backbone, 'grad_checkpointing'):
+            self.backbone.grad_checkpointing = enable
+
     def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
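
Both embed wrappers also gain a set_grad_checkpointing hook, so enabling checkpointing on a hybrid ViT now reaches the conv stem: VisionTransformer.set_grad_checkpointing (changed earlier in this commit) forwards to patch_embed, which in turn forwards to its backbone via either set_grad_checkpointing() or a grad_checkpointing attribute. A minimal sketch with a stock timm ResNet as the stem; the backbone choice is illustrative only:

import timm
from timm.models.vision_transformer_hybrid import HybridEmbed

# Any timm ConvNet returning an unpooled NCHW feature map can act as the stem.
backbone = timm.create_model('resnet26d', num_classes=0, global_pool='')

embed = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=768)

# New in this commit: the wrapper forwards the request to its backbone.
embed.set_grad_checkpointing(True)
assert backbone.grad_checkpointing    # timm ResNet exposes this flag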
