@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
@@ -211,7 +209,7 @@ def forward(self, x, context):
 class ScreenAI(nn.Module):
     def __init__(
         self,
-        patch_size: Tuple[int, int] = (16, 16),
+        patch_size: int,  # Tuple[int, int] = (16, 16),
         image_size: int = 224,
         dim: int = 512,
         depth: int = 6,
@@ -233,12 +231,6 @@ def __init__(
         self.multi_modal_encoder_depth = multi_modal_encoder_depth
         self.llm_decoder_depth = llm_decoder_depth
 
-        # Aspect-ratio-preserving grid with max 25 patches; split the image into patches
-        self.grid = (
-            image_size // patch_size[0],
-            image_size // patch_size[1],
-        )
-
         # Patch embedding
         self.patch_embedding = nn.Conv2d(
             3, dim, patch_size, patch_size
@@ -285,14 +277,15 @@ def __init__(
             for _ in range(llm_decoder_depth)
         )
 
-    def forward(self, img: Tensor, text: Tensor) -> Tensor:
+    def forward(self, text: Tensor, img: Tensor) -> Tensor:
         # Image patch
         img = rearrange(
             img,
             "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
             p1=self.patch_size[0],
             p2=self.patch_size[1],
         )
+        print(f"Image patch shape: {img.shape}")
 
         # vit
         img = self.vit(img, return_embeddings=True)
@@ -306,8 +299,8 @@ def forward(self, img: Tensor, text: Tensor) -> Tensor:
 
         # T5 Multimodal encoder
         for attn, ff in self.mme_layers:
-            x, _, _ = attn(x, x, x)
-            x = ff(x)
+            x, _, _ = attn(x, x, x) + x
+            x = ff(x) + x
 
         # Pass the k, v values into the cross attention of llm
         for cross_attn, attn in self.llm_layers:
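
The __init__ hunk above narrows patch_size to a plain int, while the forward hunk still indexes self.patch_size[0] and self.patch_size[1] in the rearrange call. A minimal sketch of one way to reconcile the two, using a hypothetical pair() helper that is not part of this repo and that assumes square patches when an int is passed:

# Hypothetical helper (not in this repo): normalize an int patch size to a
# (height, width) pair so both nn.Conv2d and the rearrange indexing keep working.
from typing import Tuple, Union


def pair(value: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
    # An int means square patches; a tuple is passed through unchanged.
    return value if isinstance(value, tuple) else (value, value)


patch_size = pair(16)  # -> (16, 16)
p1, p2 = patch_size    # usable as self.patch_size[0] and self.patch_size[1]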
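
In the last hunk, attn(x, x, x) + x adds the residual to the value that the left-hand side then unpacks into three names; if the attention module returns a tuple, as the removed lines suggest, that addition raises a TypeError. Below is a minimal sketch of the intended residual pattern, written with torch.nn.MultiheadAttention (which returns two values, not three) purely as a stand-in assumption for the repo's attention layer:

# Sketch of a self-attention + feed-forward block with residual (skip) connections.
# nn.MultiheadAttention is a stand-in assumption for the repo's attn module.
import torch
from torch import nn


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 512, heads: int = 8):
        super().__init__()
        # batch_first=True -> inputs and outputs are (batch, seq, dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unpack the attention output first, then add the residual.
        attn_out, _ = self.attn(x, x, x)
        x = attn_out + x
        x = self.ff(x) + x
        return x


tokens = torch.randn(2, 10, 512)
print(EncoderBlock()(tokens).shape)  # torch.Size([2, 10, 512])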