@@ -389,12 +389,12 @@ def _cfg(url='', **kwargs):
389
389
'vit_base_mci_224.apple_mclip' : _cfg (
390
390
url = 'https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt' ,
391
391
num_classes = 512 ,
392
- mean = (0. , 0. , 0. ), std = (1. , 1. , 1. ), first_conv = 'patch_embed.backbone.0.conv.weight ' ,
392
+ mean = (0. , 0. , 0. ), std = (1. , 1. , 1. ), first_conv = 'patch_embed.backbone.0.conv' ,
393
393
),
394
394
'vit_base_mci_224.apple_mclip_lt' : _cfg (
395
395
url = 'https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt' ,
396
396
num_classes = 512 ,
397
- mean = (0. , 0. , 0. ), std = (1. , 1. , 1. ), first_conv = 'patch_embed.backbone.0.conv.weight ' ,
397
+ mean = (0. , 0. , 0. ), std = (1. , 1. , 1. ), first_conv = 'patch_embed.backbone.0.conv' ,
398
398
),
399
399
})
400
400
@@ -552,6 +552,7 @@ def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
552
552
stride = (4 , 2 , 2 ),
553
553
kernel_size = (4 , 2 , 2 ),
554
554
padding = 0 ,
555
+ in_chans = kwargs .get ('in_chans' , 3 ),
555
556
act_layer = nn .GELU ,
556
557
)
557
558
model_args = dict (embed_dim = 768 , depth = 12 , num_heads = 12 , no_embed_class = True )
0 commit comments