 """
import math
from functools import partial
-from typing import List, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to

+from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, create_resnetv2_stem
-from .vision_transformer import _create_vision_transformer, VisionTransformer
+from .vision_transformer import VisionTransformer


class HybridEmbed(nn.Module):
@@ -159,22 +160,26 @@ class HybridEmbedWithSize(nn.Module):
    """
    def __init__(
            self,
-            backbone,
-            img_size=224,
-            patch_size=1,
-            feature_size=None,
-            in_chans=3,
-            embed_dim=768,
+            backbone: nn.Module,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 1,
+            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
+            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
+            in_chans: int = 3,
+            embed_dim: int = 768,
            bias=True,
+            proj=True,
    ):
        super().__init__(
            backbone=backbone,
            img_size=img_size,
            patch_size=patch_size,
            feature_size=feature_size,
+            feature_ratio=feature_ratio,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=bias,
+            proj=proj,
        )

    @torch.jit.ignore
@@ -206,12 +211,8 @@ def __init__(
    ):
        super().__init__()
        if isinstance(channels, int):
-            if depth == 4:
-                channels = (channels // 8, channels // 4, channels // 2, channels)
-            elif depth == 3:
-                channels = (channels // 4, channels // 2, channels)
-            else:
-                channels = to_ntuple(depth)(channels)
+            # a default tiered channel strategy
+            channels = tuple([channels // 2 ** i for i in range(depth)][::-1])

        kernel_size = to_ntuple(depth)(kernel_size)
        padding = to_ntuple(depth)(padding)
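For illustration (not part of this commit): a minimal sanity check that the new one-liner reproduces the removed depth-specific branches. The `tiered_channels` name is only an illustrative wrapper around the same expression.

# Sketch only: verify the tiered channel derivation against the old branches.
def tiered_channels(channels: int, depth: int):
    return tuple([channels // 2 ** i for i in range(depth)][::-1])

assert tiered_channels(256, 4) == (32, 64, 128, 256)  # old depth == 4 branch: (c // 8, c // 4, c // 2, c)
assert tiered_channels(256, 3) == (64, 128, 256)      # old depth == 3 branch: (c // 4, c // 2, c)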
@@ -235,13 +236,6 @@ def __init__(
            in_chs = channels[i]


-def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
-    embed_args = embed_args or {}
-    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
-    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
-    return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
-
-
def _resnetv2(layers=(3, 4, 9), **kwargs):
    """ ResNet-V2 backbone helper"""
    padding_same = kwargs.get('padding_same', True)
@@ -257,6 +251,66 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
    return backbone


+def _convert_mobileclip(state_dict, model, prefix='image_encoder.model.'):
+    out = {}
+    for k, v in state_dict.items():
+        if not k.startswith(prefix):
+            continue
+        k = k.replace(prefix, '')
+        k = k.replace('patch_emb.', 'patch_embed.backbone.')
+        k = k.replace('block.conv', 'conv')
+        k = k.replace('block.norm', 'bn')
+        k = k.replace('post_transformer_norm.', 'norm.')
+        k = k.replace('pre_norm_mha.0', 'norm1')
+        k = k.replace('pre_norm_mha.1', 'attn')
+        k = k.replace('pre_norm_ffn.0', 'norm2')
+        k = k.replace('pre_norm_ffn.1', 'mlp.fc1')
+        k = k.replace('pre_norm_ffn.4', 'mlp.fc2')
+        k = k.replace('qkv_proj.', 'qkv.')
+        k = k.replace('out_proj.', 'proj.')
+        k = k.replace('transformer.', 'blocks.')
+        if k == 'pos_embed.pos_embed.pos_embed':
+            k = 'pos_embed'
+            v = v.squeeze(0)
+        if 'classifier.proj' in k:
+            bias_k = k.replace('classifier.proj', 'head.bias')
+            k = k.replace('classifier.proj', 'head.weight')
+            v = v.T
+            out[bias_k] = torch.zeros(v.shape[0])
+        out[k] = v
+    return out
+
+
+def checkpoint_filter_fn(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+) -> Dict[str, torch.Tensor]:
+    from .vision_transformer import checkpoint_filter_fn as _filter_fn
+
+    if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
+        state_dict = _convert_mobileclip(state_dict, model)
+
+    return _filter_fn(state_dict, model, interpolation=interpolation, antialias=antialias)
+
+
+def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
+    out_indices = kwargs.pop('out_indices', 3)
+    embed_args = embed_args or {}
+    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
+    kwargs.setdefault('embed_layer', embed_layer)
+    kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
+    return build_model_with_cfg(
+        VisionTransformer,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
+
+
def _cfg(url='', **kwargs):
    return {
        'url': url,
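For illustration (not part of this commit, and assuming a timm checkout that includes it): a small sketch of what `_convert_mobileclip` does to representative MobileCLIP image-encoder keys. The tensors are dummies; `model` is unused by the converter, so `None` is passed.

# Sketch only: exercise the key remapping on a fake MobileCLIP-style state dict.
import torch
from timm.models.vision_transformer_hybrid import _convert_mobileclip

dummy_sd = {
    'image_encoder.model.transformer.0.pre_norm_mha.1.qkv_proj.weight': torch.zeros(2304, 768),
    'image_encoder.model.post_transformer_norm.weight': torch.zeros(768),
    'text_encoder.embedding.weight': torch.zeros(4, 4),  # keys outside the image encoder are dropped
}
out = _convert_mobileclip(dummy_sd, model=None)
print(sorted(out))  # ['blocks.0.attn.qkv.weight', 'norm.weight']

The new `checkpoint_filter_fn` only triggers this conversion when it sees the MobileCLIP stem key (`image_encoder.model.patch_emb.0.block.conv.weight`), then defers to the standard ViT filter for pos-embed interpolation.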
@@ -331,6 +385,17 @@ def _cfg(url='', **kwargs):
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
    'vit_base_resnet50d_224.untrained': _cfg(
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
+
+    'vit_base_mci_224.apple_mclip': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0',
+    ),
+    'vit_base_mci_224.apple_mclip_lt': _cfg(
+        url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.conv1.0',
+    ),
})

@@ -491,7 +556,7 @@ def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
    )
    model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
    model = _create_vision_transformer_hybrid(
-        'vit_base_resnet50d_224', backbone=backbone, embed_args=dict(proj=False),
+        'vit_base_mci_224', backbone=backbone, embed_args=dict(proj=False),
        pretrained=pretrained, **dict(model_args, **kwargs)
    )
    return model
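For illustration (not part of this commit): with these changes, the MobileCLIP-B image tower should be creatable through the normal timm factory, with weights pulled from the Apple URL in the pretrained cfg and remapped on the fly by `checkpoint_filter_fn`. A minimal usage sketch, assuming a timm build containing this commit:

# Sketch only: instantiate the newly registered hybrid model with MobileCLIP weights.
import timm
import torch

model = timm.create_model('vit_base_mci_224.apple_mclip', pretrained=True)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected torch.Size([1, 512]), given num_classes=512 in the new pretrained cfg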