     ViTransformerWrapper,
 )

-#logging
-logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
+# logging
+logging.basicConfig(
+    level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s"
+)


-#main model
+# main model
 class CM3(Module):
     """
-    Andromeda is a transformer-based model architecture. It initializes with 
+    Andromeda is a transformer-based model architecture. It initializes with
     a Transformer and AutoregressiveWrapper with default or user-specified parameters.

     Initialize the model with specified or default parameters.
@@ -41,37 +43,33 @@ class CM3(Module):
     - attn_qk_norm: Attention query-key normalization
     - attn_qk_norm_dim_scale: Attention query-key normalization dimension scale
     """
+
     def __init__(
-        self, 
-        num_tokens=50432, 
-        max_seq_len=8192, 
-        dim=2560, 
-        depth=32, 
-        dim_head=128, 
-        heads=24, 
-        use_abs_pos_emb=False, 
-        alibi_pos_bias=True, 
-        alibi_num_heads=12, 
-        rotary_xpos=True, 
-        attn_flash=True, 
-        image_size=256, 
-        patch_size=32, 
-        attn_one_kv_head=True, # multiquery attention
-        qk_norm=True, 
-        attn_qk_norm=True, 
-        attn_qk_norm_dim_scale=True, 
-    ):
+        self,
+        num_tokens=50432,
+        max_seq_len=8192,
+        dim=2560,
+        depth=32,
+        dim_head=128,
+        heads=24,
+        use_abs_pos_emb=False,
+        alibi_pos_bias=True,
+        alibi_num_heads=12,
+        rotary_xpos=True,
+        attn_flash=True,
+        image_size=256,
+        patch_size=32,
+        attn_one_kv_head=True,  # multiquery attention
+        qk_norm=True,
+        attn_qk_norm=True,
+        attn_qk_norm_dim_scale=True,
+    ):
         super().__init__()

         self.encoder = ViTransformerWrapper(
             image_size=image_size,
             patch_size=patch_size,
-            attn_layers=Encoder(
-                dim=dim,
-                depth=depth,
-                dim_head=dim_head,
-                heads=heads
-            )
+            attn_layers=Encoder(dim=dim, depth=depth, dim_head=dim_head, heads=heads),
         )

         self.transformer = Transformer(
@@ -91,30 +89,32 @@ def __init__(
                 # qk_norm=qk_norm,
                 # attn_qk_norm=attn_qk_norm,
                 # attn_qk_norm_dim_scale=attn_qk_norm_dim_scale,
-                cross_attend=True
-            )
+                cross_attend=True,
+            ),
         )

         self.decoder = AutoregressiveWrapper(self.transformer)

     def mask_and_relocate(self, text_tokens):
-        #mask image span
-        text_tokens = text_tokens.masked_fill(text_tokens == self.im_idx, self.mask_token)
+        # mask image span
+        text_tokens = text_tokens.masked_fill(
+            text_tokens == self.im_idx, self.mask_token
+        )

-        #relocate to end
-        image_span = text_tokens[text_tokens == self.im_end_idx].unsqueeze(1) 
+        # relocate to end
+        image_span = text_tokens[text_tokens == self.im_end_idx].unsqueeze(1)
         text_tokens = torch.cat([text_tokens, image_span], dim=1)
         return text_tokens
-    
+
     def cm3_loss(self, log_probs, labels):
-        #cm3 loss prediction
+        # cm3 loss prediction
         loss = nn.NLLLoss()(log_probs, labels)
         return loss

     # def forward(self, text_tokens, img, **kwargs):
     #     try:
     #         encoded_img = self.encoder(img, return_embeddings=True)
-            
+
     #         #mask and relocate image span in text tokens
     #         text_tokens = self.mask_and_relocate(text_tokens)

@@ -134,10 +134,9 @@ def cm3_loss(self, log_probs, labels):
     #         raise

     def forward(self, img, text):
-        try: 
+        try:
             encoded = self.encoder(img, return_embeddings=True)
             return self.decoder(text, context=encoded)
         except Exception as error:
             print(f"Failed in forward method: {error}")
             raise
-    
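For a quick sanity check of the reformatted class, a minimal usage sketch follows. The import path, the smaller-than-default hyperparameters, and the assumption that `AutoregressiveWrapper` returns a scalar language-modeling loss when called with token ids plus a cross-attention `context` tensor are illustrative assumptions, not part of this commit.

```python
import torch

# Hypothetical import path; adjust to wherever CM3 is defined in this repo.
from cm3.model import CM3

# Smaller-than-default sizes so the sketch runs quickly on CPU
# (the defaults of dim=2560, depth=32 are far too large for a smoke test).
model = CM3(dim=512, depth=6, dim_head=64, heads=8, alibi_num_heads=4)

img = torch.randn(1, 3, 256, 256)          # one RGB image at the default image_size=256
text = torch.randint(0, 50432, (1, 1024))  # one sequence of token ids within num_tokens

# forward() encodes the image with the ViT encoder and feeds the embeddings
# as cross-attention context to the autoregressive text decoder.
loss = model(img, text)  # assumed to be the scalar LM loss from AutoregressiveWrapper
loss.backward()
```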