
Commit 39aacd0

Author: Kye (committed)
[MODULE LAYERS]
1 parent e15b72d commit 39aacd0

2 files changed (+9 -16 lines)

example.py (+4 -4)
@@ -1,4 +1,4 @@
-import torch
+import torch
 from screenai.main import ScreenAI
 
 # Create a tensor
@@ -7,20 +7,20 @@
 
 # Model
 model = ScreenAI(
-    patch_size=(4, 6),
+    patch_size=16,
     image_size=224,
     dim=512,
     depth=6,
     heads=8,
     vit_depth=4,
     multi_modal_encoder_depth=4,
     llm_decoder_depth=4,
-    mm_encoder_ff_mult=4
+    mm_encoder_ff_mult=4,
 )
 
 
 # Forward
-out = model(image, text)
+out = model(text, image)
 
 # Print the output shape
 print(out.shape)
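
For orientation, a minimal sketch of how the updated example.py is meant to be driven after this commit. The keyword arguments and the flipped (text, image) call order come directly from the diff above; the input tensors (a batch of token ids and a 224x224 RGB image) are assumed placeholders, since their creation lines fall outside the visible hunks.

import torch

from screenai.main import ScreenAI

# Assumed placeholder inputs (not shown in the hunks above):
# a batch of token ids for the text branch and one 224x224 RGB image.
text = torch.randint(0, 20000, (1, 256))   # assumed vocab size and sequence length
image = torch.randn(1, 3, 224, 224)         # (batch, channels, height, width)

# Keyword arguments exactly as in the updated example.py
model = ScreenAI(
    patch_size=16,
    image_size=224,
    dim=512,
    depth=6,
    heads=8,
    vit_depth=4,
    multi_modal_encoder_depth=4,
    llm_decoder_depth=4,
    mm_encoder_ff_mult=4,
)

# The commit flips the call order to (text, image)
out = model(text, image)
print(out.shape)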

screenai/main.py (+5 -12)
@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
@@ -211,7 +209,7 @@ def forward(self, x, context):
 class ScreenAI(nn.Module):
     def __init__(
         self,
-        patch_size: Tuple[int, int] = (16, 16),
+        patch_size: int, # Tuple[int, int] = (16, 16),
         image_size: int = 224,
         dim: int = 512,
         depth: int = 6,
@@ -233,12 +231,6 @@ def __init__(
         self.multi_modal_encoder_depth = multi_modal_encoder_depth
         self.llm_decoder_depth = llm_decoder_depth
 
-        # Aspect ratio preserving gride with max 25 patches, split up the image into patches
-        self.grid = (
-            image_size // patch_size[0],
-            image_size // patch_size[1],
-        )
-
         # Patch embedding
         self.patch_embedding = nn.Conv2d(
             3, dim, patch_size, patch_size
@@ -285,14 +277,15 @@ def __init__(
             for _ in range(llm_decoder_depth)
         )
 
-    def forward(self, img: Tensor, text: Tensor) -> Tensor:
+    def forward(self, text: Tensor, img: Tensor) -> Tensor:
         # Image patch
         img = rearrange(
             img,
             "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
             p1=self.patch_size[0],
             p2=self.patch_size[1],
         )
+        print(f"Image patch shape: {img.shape}")
 
         # vit
         img = self.vit(img, return_embeddings=True)
@@ -306,8 +299,8 @@ def forward(self, img: Tensor, text: Tensor) -> Tensor:
 
         # T5 Multimodal encoder
         for attn, ff in self.mme_layers:
-            x, _, _ = attn(x, x, x)
-            x = ff(x)
+            x, _, _ = attn(x, x, x) + x
+            x = ff(x) + x
 
         # Pass the k, v values into the cross attention of llm
         for cross_attn, attn in self.llm_layers:
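
For reference, a self-contained sketch of what the einops rearrange kept in forward produces. The pattern string is taken from the diff above; the input tensor and the square 16x16 patch are assumed values for illustration and are not taken from the repository.

import torch
from einops import rearrange

# Same rearrange pattern as in ScreenAI.forward, applied to an assumed input.
img = torch.randn(1, 3, 224, 224)            # (batch, channels, height, width)
patches = rearrange(
    img,
    "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
    p1=16,                                    # assumed square patch size
    p2=16,
)

# 224/16 = 14 patches per side -> 196 patches, each flattened to 16*16*3 = 768 values
print(patches.shape)  # torch.Size([1, 196, 768])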

0 commit comments
