Commit d3c6532

[ADD] average CLIP embedding of input frames

1 parent 31f8c33

2 files changed: +92, −6 lines


configs/example_training/seva-clipl_dl3dv.yaml

Lines changed: 6 additions & 6 deletions
@@ -64,12 +64,12 @@ model:
           input_key: mask
           target: sgm.modules.encoders.modules.IdentityEncoder
 
-        # - is_trainable: False
-        #   ucg_rate: 0.1
-        #   input_key: input_frames
-        #   target: sgm.modules.encoders.modules.SevaFrozenOpenCLIPImageEmbedder
-        #   params:
-        #     max_crops: 2 # same as num_images in the dataloader
+        - is_trainable: False # crossattn
+          ucg_rate: 0.1
+          input_keys: ["clean_latent", "mask"]
+          target: sgm.modules.encoders.modules.SevaFrozenOpenCLIPImageEmbedder
+          params:
+            unsqueeze_dim: True
 
 
         # - is_trainable: False
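
The enabled entry replaces the commented-out draft: instead of a single input_key, the embedder is now fed two batch entries via input_keys, and the dataloader-specific max_crops parameter is dropped in favour of unsqueeze_dim. The sketch below shows, under assumptions, what this YAML resolves to at runtime; it presumes this fork keeps sgm.util.instantiate_from_config and a GeneralConditioner that passes the batch values named in input_keys positionally to forward(), i.e. forward(batch["clean_latent"], batch["mask"]). It is illustrative, not part of the commit.

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config  # assumed to exist in this fork, as in upstream sgm

# Programmatic equivalent of the enabled YAML entry.
emb_config = OmegaConf.create(
    {
        "target": "sgm.modules.encoders.modules.SevaFrozenOpenCLIPImageEmbedder",
        "params": {"unsqueeze_dim": True},
    }
)
embedder = instantiate_from_config(emb_config)  # pulls the ViT-H-14 / laion2b_s32b_b79k weights
embedder.ucg_rate = 0.1  # normally attached by the conditioner from the YAML entry

With input_keys: ["clean_latent", "mask"], the embedder receives both the frames stored under clean_latent and the boolean mask marking which of them are input views, matching the (image, mask) signature of the new forward() below.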

sgm/modules/encoders/modules.py

Lines changed: 86 additions & 0 deletions
@@ -1052,3 +1052,89 @@ def forward(self, vid):
         vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies)
 
         return vid
+
+
+class SevaFrozenOpenCLIPImageEmbedder(AbstractEmbModel):
+    """
+    Uses the OpenCLIP vision transformer encoder for images
+    """
+
+    def __init__(
+        self,
+        arch="ViT-H-14",
+        version="laion2b_s32b_b79k",
+        device="cuda",
+        max_length=77,
+        freeze=True,
+        antialias=True,
+        ucg_rate=0.0,
+        unsqueeze_dim=False,
+        init_device=None,
+    ):
+        super().__init__()
+        model, _, _ = open_clip.create_model_and_transforms(
+            arch,
+            device=torch.device(default(init_device, "cpu")),
+            pretrained=version,
+        )
+        del model.transformer
+        self.model = model
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+
+        self.antialias = antialias
+
+        self.register_buffer(
+            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
+        )
+        self.register_buffer(
+            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
+        )
+        self.ucg_rate = ucg_rate
+        self.stored_batch = None
+
+    def preprocess(self, x):
+        # normalize to [0,1]
+        x = kornia.geometry.resize(
+            x,
+            (224, 224),
+            interpolation="bicubic",
+            align_corners=True,
+            antialias=self.antialias,
+        )
+        x = (x + 1.0) / 2.0
+        # renormalize according to clip
+        x = kornia.enhance.normalize(x, self.mean, self.std)
+        return x
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    @autocast
+    def forward(self, image, mask, no_dropout=False):
+        batch_size = image.shape[0]
+        z = [self.encode_with_vision_transformer(image[b][mask[b]]).mean(0, keepdim=True)
+             for b in range(batch_size)]
+        z = torch.cat(z, dim=0)
+        z = z.to(image.dtype)
+        if self.ucg_rate > 0.0 and not no_dropout:
+            z = (
+                torch.bernoulli(
+                    (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
+                )[:, None]
+                * z
+            )
+        z = z[:, None]
+        return z
+
+    def encode_with_vision_transformer(self, img):
+        img = self.preprocess(img)
+        x = self.model.visual(img)
+        return x
+
+    def encode(self, text):
+        return self(text)
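
The commit's core behaviour is the per-sample averaging in forward(): for each batch element, only the frames selected by mask are pushed through the CLIP visual tower, and their embeddings are mean-pooled into a single conditioning token of shape (B, 1, D). Below is a self-contained sketch of that pooling with a toy encode() standing in for encode_with_vision_transformer, so it runs without the OpenCLIP checkpoint; the shapes and the stub are assumptions for illustration only.

import torch

def encode(frames: torch.Tensor) -> torch.Tensor:
    # stand-in for encode_with_vision_transformer: (N, C, H, W) -> (N, D)
    return frames.flatten(1).mean(dim=1, keepdim=True).repeat(1, 1024)

B, T = 2, 8                                   # batch of 2 samples, 8 frames each
image = torch.randn(B, T, 3, 64, 64)          # all frames of each sample
mask = torch.zeros(B, T, dtype=torch.bool)
mask[:, :2] = True                            # the first two frames are the input views

# Same pooling as SevaFrozenOpenCLIPImageEmbedder.forward: encode the masked
# frames of each sample and average them into one embedding per sample.
z = torch.cat(
    [encode(image[b][mask[b]]).mean(0, keepdim=True) for b in range(B)], dim=0
)                                             # (B, 1024)
z = z[:, None]                                # (B, 1, 1024) cross-attention token
print(z.shape)                                # torch.Size([2, 1, 1024])

In the real module, a ucg_rate > 0 additionally zeroes the pooled embedding for a random subset of samples (classifier-free guidance dropout) before the sequence dimension is added.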
