Commit 88a1006

checkpoint filter fns with consistent name, add mobileclip-b pretrained cfgs
1 parent: 7d4ada6

5 files changed (+95 additions, -30 deletions)
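
All four renamed helpers now expose the same module-level entry point, checkpoint_filter_fn(state_dict, model, ...), the convention already used elsewhere in timm for remapping original upstream checkpoints into timm's key layout. As a rough illustration only (not part of this commit; the checkpoint path is a placeholder), the filter can also be applied by hand before load_state_dict:

    import torch
    import timm
    # the same name now exists in beit, efficientformer, fastvit and pvt_v2
    from timm.models.beit import checkpoint_filter_fn

    model = timm.create_model('beit_base_patch16_224', pretrained=False)

    # placeholder path: a BEiT checkpoint in the original (non-timm) format
    state_dict = torch.load('original_beit_checkpoint.pth', map_location='cpu')

    # unwrap 'model'/'module' containers and remap keys to timm's layout
    state_dict = checkpoint_filter_fn(state_dict, model)
    model.load_state_dict(state_dict)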

timm/models/beit.py
Lines changed: 2 additions & 2 deletions

@@ -591,7 +591,7 @@ def _cfg(url='', **kwargs):
 })


-def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
+def checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('module', state_dict)
     # beit v2 didn't strip module
@@ -637,7 +637,7 @@ def _create_beit(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         Beit, variant, pretrained,
-        pretrained_filter_fn=_beit_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )

timm/models/efficientformer.py
Lines changed: 2 additions & 2 deletions

@@ -556,7 +556,7 @@ def forward(self, x):
         return x


-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'stem.0.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -611,7 +611,7 @@ def _create_efficientformer(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', 4)
     model = build_model_with_cfg(
         EfficientFormer, variant, pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )

timm/models/fastvit.py
Lines changed: 2 additions & 2 deletions

@@ -1414,7 +1414,7 @@ def _cfg(url="", **kwargs):
 })


-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -1493,7 +1493,7 @@ def _create_fastvit(variant, pretrained=False, **kwargs):
         FastVit,
         variant,
         pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs
     )

timm/models/pvt_v2.py
Lines changed: 2 additions & 2 deletions

@@ -403,7 +403,7 @@ def forward(self, x):
         return x


-def _checkpoint_filter_fn(state_dict, model):
+def checkpoint_filter_fn(state_dict, model):
     """ Remap original checkpoints -> timm """
     if 'patch_embed.proj.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
@@ -430,7 +430,7 @@ def _create_pvt2(variant, pretrained=False, **kwargs):
         PyramidVisionTransformerV2,
         variant,
         pretrained,
-        pretrained_filter_fn=_checkpoint_filter_fn,
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs,
     )

timm/models/vision_transformer_hybrid.py
Lines changed: 87 additions & 22 deletions
@@ -15,7 +15,7 @@
 """
 import math
 from functools import partial
-from typing import List, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -24,10 +24,11 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to

+from ._builder import build_model_with_cfg
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
-from .vision_transformer import _create_vision_transformer, VisionTransformer
+from .vision_transformer import VisionTransformer


 class HybridEmbed(nn.Module):
@@ -159,22 +160,26 @@ class HybridEmbedWithSize(nn.Module):
     """
     def __init__(
             self,
-            backbone,
-            img_size=224,
-            patch_size=1,
-            feature_size=None,
-            in_chans=3,
-            embed_dim=768,
+            backbone: nn.Module,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 1,
+            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
+            in_chans: int = 3,
+            embed_dim: int = 768,
             bias=True,
+            proj=True,
     ):
         super().__init__(
             backbone=backbone,
             img_size=img_size,
             patch_size=patch_size,
             feature_size=feature_size,
+            feature_ratio=feature_ratio,
             in_chans=in_chans,
             embed_dim=embed_dim,
             bias=bias,
+            proj=proj,
         )

     @torch.jit.ignore
@@ -206,12 +211,8 @@ def __init__(
     ):
         super().__init__()
         if isinstance(channels, int):
-            if depth == 4:
-                channels = (channels // 8, channels // 4, channels // 2, channels)
-            elif depth == 3:
-                channels = (channels // 4, channels // 2, channels)
-            else:
-                channels = to_ntuple(depth)(channels)
+            # a default tiered channel strategy
+            channels = tuple([channels // 2**i for i in range(depth)][::-1])

         kernel_size = to_ntuple(depth)(kernel_size)
         padding = to_ntuple(depth)(padding)
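
The list comprehension reproduces the old depth-3 and depth-4 tiers exactly while generalizing to any depth with a power-of-two taper (the previous fallback simply repeated the same channel count via to_ntuple). A quick illustrative check, not part of the diff:

    depth, channels = 4, 768
    tuple([channels // 2**i for i in range(depth)][::-1])   # (96, 192, 384, 768), same as the old depth == 4 branch

    depth, channels = 3, 768
    tuple([channels // 2**i for i in range(depth)][::-1])   # (192, 384, 768), same as the old depth == 3 branch
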
@@ -235,13 +236,6 @@
             in_chs = channels[i]


-def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
-    embed_args = embed_args or {}
-    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
-    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
-    return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
-
-
 def _resnetv2(layers=(3, 4, 9), **kwargs):
     """ ResNet-V2 backbone helper"""
     padding_same = kwargs.get('padding_same', True)
@@ -257,6 +251,66 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
     return backbone


+def _convert_mobileclip(state_dict, model, prefix='image_encoder.model.'):
+    out = {}
+    for k, v in state_dict.items():
+        if not k.startswith(prefix):
+            continue
+        k = k.replace(prefix, '')
+        k = k.replace('patch_emb.', 'patch_embed.backbone.')
+        k = k.replace('block.conv', 'conv')
+        k = k.replace('block.norm', 'bn')
+        k = k.replace('post_transformer_norm.', 'norm.')
+        k = k.replace('pre_norm_mha.0', 'norm1')
+        k = k.replace('pre_norm_mha.1', 'attn')
+        k = k.replace('pre_norm_ffn.0', 'norm2')
+        k = k.replace('pre_norm_ffn.1', 'mlp.fc1')
+        k = k.replace('pre_norm_ffn.4', 'mlp.fc2')
+        k = k.replace('qkv_proj.', 'qkv.')
+        k = k.replace('out_proj.', 'proj.')
+        k = k.replace('transformer.', 'blocks.')
+        if k == 'pos_embed.pos_embed.pos_embed':
+            k = 'pos_embed'
+            v = v.squeeze(0)
+        if 'classifier.proj' in k:
+            bias_k = k.replace('classifier.proj', 'head.bias')
+            k = k.replace('classifier.proj', 'head.weight')
+            v = v.T
+            out[bias_k] = torch.zeros(v.shape[0])
+        out[k] = v
+    return out
+
+
+def checkpoint_filter_fn(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+) -> Dict[str, torch.Tensor]:
+    from .vision_transformer import checkpoint_filter_fn as _filter_fn
+
+    if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
+        state_dict = _convert_mobileclip(state_dict, model)
+
+    return _filter_fn(state_dict, model, interpolation=interpolation, antialias=antialias)
+
+
+def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
+    out_indices = kwargs.pop('out_indices', 3)
+    embed_args = embed_args or {}
+    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
+    kwargs.setdefault('embed_layer', embed_layer)
+    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
+    return build_model_with_cfg(
+        VisionTransformer,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
+
+
 def _cfg(url='', **kwargs):
     return {
         'url': url,
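
The new filter detects a raw MobileCLIP checkpoint by the characteristic 'image_encoder.model.patch_emb.0.block.conv.weight' key and routes it through _convert_mobileclip before handing off to the standard ViT filter. A small illustrative sketch of the key remapping; the toy tensors, key names, and shapes below are made up purely to exercise the converter:

    import torch
    from timm.models import vision_transformer_hybrid as vth

    toy = {
        # image-tower keys get the 'image_encoder.model.' prefix stripped and are remapped
        'image_encoder.model.patch_emb.0.block.conv.weight': torch.zeros(1),
        'image_encoder.model.transformer.0.pre_norm_mha.1.qkv_proj.weight': torch.zeros(1),
        # projection assumed stored as (embed_dim, num_classes); becomes head.weight (transposed) plus a zero head.bias
        'image_encoder.model.classifier.proj': torch.zeros(768, 512),
        # anything outside the image tower (e.g. the text encoder) is dropped
        'text_encoder.embedding.weight': torch.zeros(1),
    }
    remapped = vth._convert_mobileclip(toy, model=None)  # the model argument is unused by the converter
    print(sorted(remapped))
    # ['blocks.0.attn.qkv.weight', 'head.bias', 'head.weight', 'patch_embed.backbone.0.conv.weight']
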
@@ -331,6 +385,17 @@ def _cfg(url='', **kwargs):
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
     'vit_base_resnet50d_224.untrained': _cfg(
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
+
+    'vit_base_mci_224.apple_mclip': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0',
+    ),
+    'vit_base_mci_224.apple_mclip_lt': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0',
+    ),
 })


@@ -491,7 +556,7 @@ def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
     )
     model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
     model = _create_vision_transformer_hybrid(
-        'vit_base_resnet50d_224', backbone=backbone, embed_args=dict(proj=False),
+        'vit_base_mci_224', backbone=backbone, embed_args=dict(proj=False),
         pretrained=pretrained, **dict(model_args, **kwargs)
     )
     return model
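
With the new pretrained cfgs, the MobileCLIP-B image tower can be pulled in directly: num_classes=512 means the head produces the 512-dim CLIP image embedding, and the (0, 0, 0) / (1, 1, 1) mean and std leave inputs in the plain [0, 1] range. A rough usage sketch, assuming a timm build that includes this commit and network access to the Apple URLs:

    import torch
    import timm

    # downloads mobileclip_b.pt from the cfg URL above and remaps it via checkpoint_filter_fn
    model = timm.create_model('vit_base_mci_224.apple_mclip', pretrained=True)
    model = model.eval()

    x = torch.rand(1, 3, 224, 224)   # dummy image batch already scaled to [0, 1]; mean/std of 0/1 leave it unchanged
    with torch.no_grad():
        emb = model(x)               # shape (1, 512): the CLIP image embedding

The 'apple_mclip_lt' tag works the same way but points at the mobileclip_blt.pt checkpoint.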
