@@ -1052,3 +1052,89 @@ def forward(self, vid):
        vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies)

        return vid


class SevaFrozenOpenCLIPImageEmbedder(AbstractEmbModel):
    """
    Uses the OpenCLIP vision transformer to embed images. For each batch
    element, the frames selected by a boolean mask are encoded and
    mean-pooled into a single CLIP image embedding.
    """

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        init_device=None,
    ):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device(default(init_device, "cpu")),
            pretrained=version,
        )
        # Only the vision tower is used; drop the text transformer to save memory.
        del model.transformer
        self.model = model
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

        self.antialias = antialias

        # OpenCLIP image normalization statistics.
        self.register_buffer(
            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        self.register_buffer(
            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )
        self.ucg_rate = ucg_rate
        self.stored_batch = None

    def preprocess(self, x):
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        # normalize from [-1, 1] to [0, 1]
        x = (x + 1.0) / 2.0
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    @autocast
    def forward(self, image, mask, no_dropout=False):
        batch_size = image.shape[0]
        # Encode the frames selected by each sample's mask and mean-pool them
        # into one embedding per batch element.
        z = [
            self.encode_with_vision_transformer(image[b][mask[b]]).mean(0, keepdim=True)
            for b in range(batch_size)
        ]
        z = torch.cat(z, dim=0)
        z = z.to(image.dtype)
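        # Unconditional-guidance dropout: with probability `ucg_rate`, zero a
        # sample's embedding so a classifier-free-guidance model also sees the
        # unconditional case.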
        if self.ucg_rate > 0.0 and not no_dropout:
            z = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
                )[:, None]
                * z
            )
        z = z[:, None]
        return z

    def encode_with_vision_transformer(self, img):
        img = self.preprocess(img)
        x = self.model.visual(img)
        return x

    def encode(self, image, mask):
        # forward() requires the frame mask, so pass it through here as well.
        return self(image, mask)
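
# A minimal usage sketch (not part of this commit; shapes, device, and batch
# values below are illustrative assumptions). Frames are expected in [-1, 1];
# `mask` picks the conditioning frames per sample, and the result is one
# pooled CLIP embedding per batch element:
#
#   embedder = SevaFrozenOpenCLIPImageEmbedder(freeze=True).to("cuda")
#   image = torch.randn(2, 8, 3, 256, 256, device="cuda")  # (B, T, C, H, W)
#   mask = torch.zeros(2, 8, dtype=torch.bool, device="cuda")
#   mask[:, 0] = True  # condition on the first frame of each clip
#   z = embedder(image, mask)  # shape (B, 1, 1024) for ViT-H-14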