
Commit d001550

Author: Kye (committed)
Message: cm3leon
1 parent 3db36e5 commit d001550

File tree

6 files changed (+113, -100 lines)

6 files changed

+113
-100
lines changed

README.md

Lines changed: 6 additions & 2 deletions
@@ -38,11 +38,15 @@ To start with CM3Leon in a PyTorch environment:
 import torch
 from cm3.model import CM3
 
+# usage
 img = torch.randn(1, 3, 256, 256)
-text = torch.randint(0, 20000, (1, 1024))
+caption = torch.randint(0, 20000, (1, 1024))
 
 model = CM3()
-output = model(text, img)
+
+output = model(img, caption)
+print(output.shape) # (1, 1024, 20000)
+
 
 ```
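The updated snippet passes the image first and the caption second, and the printed shape suggests the model returns one 20,000-way distribution per caption position. A minimal training-step sketch built on that reading (the cross-entropy wiring is illustrative, not part of the commit, and assumes the forward pass returns raw logits of shape (batch, seq_len, vocab)):

import torch
import torch.nn.functional as F
from cm3.model import CM3

model = CM3()
img = torch.randn(1, 3, 256, 256)
caption = torch.randint(0, 20000, (1, 1024))

# Assumed: logits of shape (1, 1024, 20000), per the README's comment.
logits = model(img, caption)

# Next-token objective: position t predicts caption token t + 1.
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.size(-1)),
    caption[:, 1:].reshape(-1),
)
loss.backward()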
cm3/model.py

Lines changed: 39 additions & 40 deletions
@@ -11,14 +11,16 @@
     ViTransformerWrapper,
 )
 
-#logging
-logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
+# logging
+logging.basicConfig(
+    level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s"
+)
 
 
-#main model
+# main model
 class CM3(Module):
     """
-    Andromeda is a transformer-based model architecture. It initializes with 
+    Andromeda is a transformer-based model architecture. It initializes with
     a Transformer and AutoregressiveWrapper with default or user-specified parameters.
 
     Initialize the model with specified or default parameters.
@@ -41,37 +43,33 @@ class CM3(Module):
     - attn_qk_norm: Attention query-key normalization
     - attn_qk_norm_dim_scale: Attention query-key normalization dimension scale
     """
+
     def __init__(
-            self,
-            num_tokens=50432,
-            max_seq_len=8192,
-            dim=2560,
-            depth=32,
-            dim_head=128,
-            heads=24,
-            use_abs_pos_emb=False,
-            alibi_pos_bias=True,
-            alibi_num_heads=12,
-            rotary_xpos=True,
-            attn_flash=True,
-            image_size=256,
-            patch_size=32,
-            attn_one_kv_head=True, # multiquery attention
-            qk_norm=True,
-            attn_qk_norm=True,
-            attn_qk_norm_dim_scale=True,
-    ):
+        self,
+        num_tokens=50432,
+        max_seq_len=8192,
+        dim=2560,
+        depth=32,
+        dim_head=128,
+        heads=24,
+        use_abs_pos_emb=False,
+        alibi_pos_bias=True,
+        alibi_num_heads=12,
+        rotary_xpos=True,
+        attn_flash=True,
+        image_size=256,
+        patch_size=32,
+        attn_one_kv_head=True, # multiquery attention
+        qk_norm=True,
+        attn_qk_norm=True,
+        attn_qk_norm_dim_scale=True,
+    ):
         super().__init__()
 
         self.encoder = ViTransformerWrapper(
             image_size=image_size,
             patch_size=patch_size,
-            attn_layers=Encoder(
-                dim=dim,
-                depth=depth,
-                dim_head=dim_head,
-                heads=heads
-            )
+            attn_layers=Encoder(dim=dim, depth=depth, dim_head=dim_head, heads=heads),
         )
 
         self.transformer = Transformer(
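The constructor change collapses the Encoder arguments onto one line without changing behavior. For orientation, a small sketch of what this encoder produces (shapes assume the defaults above: 256×256 images with 32×32 patches give an 8×8 grid of 64 patch embeddings; the shallow depth here is only to keep the sketch light):

import torch
from x_transformers import Encoder, ViTransformerWrapper

encoder = ViTransformerWrapper(
    image_size=256,
    patch_size=32,
    attn_layers=Encoder(dim=2560, depth=2, dim_head=128, heads=24),
)

img = torch.randn(1, 3, 256, 256)
# return_embeddings=True skips any classification head and yields
# one embedding per patch: (1, 64, 2560).
embeddings = encoder(img, return_embeddings=True)

forward later hands exactly these embeddings to the decoder as cross-attention context.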
@@ -91,30 +89,32 @@ def __init__(
             # qk_norm=qk_norm,
             # attn_qk_norm=attn_qk_norm,
             # attn_qk_norm_dim_scale=attn_qk_norm_dim_scale,
-            cross_attend=True
-        )
+            cross_attend=True,
+        ),
         )
 
         self.decoder = AutoregressiveWrapper(self.transformer)
 
     def mask_and_relocate(self, text_tokens):
-        #mask image span
-        text_tokens = text_tokens.masked_fill(text_tokens==self.im_idx, self.mask_token)
+        # mask image span
+        text_tokens = text_tokens.masked_fill(
+            text_tokens == self.im_idx, self.mask_token
+        )
 
-        #relocate to end
-        image_span = text_tokens[text_tokens==self.im_end_idx].unsqueeze(1)
+        # relocate to end
+        image_span = text_tokens[text_tokens == self.im_end_idx].unsqueeze(1)
         text_tokens = torch.cat([text_tokens, image_span], dim=1)
         return text_tokens
-
+
     def cm3_loss(self, log_probs, labels):
-        #cm3 loss prediction
+        # cm3 loss prediction
         loss = nn.NLLLoss()(log_probs, labels)
         return loss
 
     # def forward(self, text_tokens, img, **kwargs):
     #     try:
     #         encoded_img = self.encoder(img, return_embeddings=True)
-
+
     #         #mask and relocate image span in text tokens
     #         text_tokens = self.mask_and_relocate(text_tokens)
 
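mask_and_relocate follows the CM3 recipe: mask the image span in place, then move it to the end of the sequence so the causal decoder infills it last. A self-contained toy version (the special-token ids are hypothetical; note that the class never defines self.im_idx, self.mask_token, or self.im_end_idx, so the method as committed would raise AttributeError):

import torch

# Hypothetical special-token ids, for illustration only.
IM_IDX, IM_END_IDX, MASK = 7, 8, 9

def mask_and_relocate(text_tokens: torch.Tensor) -> torch.Tensor:
    # Mask the image span in place...
    masked = text_tokens.masked_fill(text_tokens == IM_IDX, MASK)
    # ...then append the relocated span marker so the decoder predicts it last.
    span = masked[masked == IM_END_IDX].unsqueeze(0)
    return torch.cat([masked, span], dim=1)

tokens = torch.tensor([[1, 2, IM_IDX, IM_IDX, IM_END_IDX, 3]])
print(mask_and_relocate(tokens))  # tensor([[1, 2, 9, 9, 8, 3, 8]])

One caveat on cm3_loss: nn.NLLLoss expects log-probabilities (e.g. the output of log_softmax), not raw logits, so callers would need to apply log_softmax first.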
@@ -134,10 +134,9 @@ def cm3_loss(self, log_probs, labels):
     #         raise
 
     def forward(self, img, text):
-        try:
+        try:
             encoded = self.encoder(img, return_embeddings=True)
             return self.decoder(text, context=encoded)
         except Exception as error:
             print(f"Failed in forward method: {error}")
             raise
-
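The reformatted forward encodes the image once and hands the patch embeddings to the AutoregressiveWrapper as cross-attention context while it scores the text. For sampling rather than scoring, the wrapper exposes a generate method; a hedged sketch (the exact generate signature varies across x_transformers versions, so treat this as an assumption):

import torch
from cm3.model import CM3

model = CM3()
img = torch.randn(1, 3, 256, 256)
prompt = torch.randint(0, 20000, (1, 1))  # a single start token

# Encode once, then sample caption tokens conditioned on the image.
encoded = model.encoder(img, return_embeddings=True)
sampled = model.decoder.generate(prompt, 256, context=encoded)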

cm3/tokenizer.py

Lines changed: 13 additions & 18 deletions
@@ -10,10 +10,10 @@
 
 class Tokenizer:
     """
-    A SentencePieceTokenizer is a tokenizer that uses a pretrained SentencePiece model 
-    to convert text into tokens and vice versa. 
+    A SentencePieceTokenizer is a tokenizer that uses a pretrained SentencePiece model
+    to convert text into tokens and vice versa.
 
-    It includes the ability to add special tokens for infilling tasks and provides 
+    It includes the ability to add special tokens for infilling tasks and provides
     functionality to encode and decode text with or without implicit leading spaces.
 
     Parameters:
@@ -32,6 +32,7 @@ class Tokenizer:
 
 
     """
+
     def __init__(self, model_path: str):
         # reload tokenizer
         assert os.path.isfile(model_path), model_path
@@ -49,28 +50,23 @@ def __init__(self, model_path: str):
         self.middle_id: Optional[int] = self.sp_model.piece_to_id("▁<MID>") or None
         self.suffix_id: Optional[int] = self.sp_model.piece_to_id("▁<SUF>") or None
         self.eot_id: Optional[int] = self.sp_model.piece_to_id("▁<EOT>") or None
-
-        #generates text until a modality break token is detected => then img is sampled
+
+        # generates text until a modality break token is detected => then img is sampled
         self.break_id: Optional[int] = self.sp_model.piece_to_id("_<BREAK>") or None
         self.image_id: Optional[int] = self.sp_model.piece_to_id("_<IMG>") or None
         self.infill_id: Optional[int] = self.sp_model.piece_to_id("_<INFILL>") or None
-
-        logger.info(f"BREAK ID: {self.break_id} - IMG ID: {self.image_id} - INFILL ID: {self.infill_id}")
-
 
+        logger.info(
+            f"BREAK ID: {self.break_id} - IMG ID: {self.image_id} - INFILL ID: {self.infill_id}"
+        )
 
         logger.info(
             f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id} "
             f"- PRE ID: {self.prefix_id} - MID ID: {self.middle_id} - SUF ID: {self.suffix_id} - EOT ID: {self.eot_id}"
         )
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
 
-    def encode(
-        self,
-        s: str,
-        bos: bool,
-        eos: bool
-    ) -> List[int]:
+    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
         assert type(s) is str
         t = self.sp_model.encode(s)
         if bos:
@@ -89,7 +85,7 @@ def encode_infilling(self, s: str) -> List[int]:
     def decode_infilling(self, t: List[int]) -> str:
         """Decode a string without an implicit leading space."""
         return self.sp_model.decode([self.sp_model.piece_to_id("☺")] + t)[1:]
-
+
 
 # class CM3LeonTokenizer(Tokenizer):
 #     """
@@ -127,7 +123,7 @@ def decode_infilling(self, t: List[int]) -> str:
 #         model_path=model_path,
 #         query_text="A photo of an image segment",
 #     )
-
+
 #     def encode(
 #         self,
 #         s: str = None,
@@ -149,9 +145,8 @@ def decode_infilling(self, t: List[int]) -> str:
 #     )
 
 #     #combine text, tokens and image embeddings
-#     #starting with a <break> token followed by img embeds 
+#     #starting with a <break> token followed by img embeds
 #     # and ending with a eos token
 
 #     seq = text + [self.break_id] + img + [self.eos_id]
 #     return seq
-
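The tokenizer diff is formatting only: docstring whitespace, a collapsed encode signature, and wrapped logger calls. For reference, a usage sketch (the model path is hypothetical; Tokenizer asserts the file exists and it must point at a trained SentencePiece model):

from cm3.tokenizer import Tokenizer

# Hypothetical path to a trained SentencePiece model file.
tokenizer = Tokenizer(model_path="checkpoints/tokenizer.model")

ids = tokenizer.encode("A photo of a dog", bos=False, eos=False)

# decode_infilling decodes without the implicit leading space that
# SentencePiece would otherwise introduce.
text = tokenizer.decode_infilling(ids)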
