Skip to content

Commit 2ae8d3a

Browse files
committed
Merge branch 'main' of github.com:jwohlwend/boltz
2 parents f2901a9 + 3ad510a commit 2ae8d3a

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

src/boltz/model/layers/outer_product_mean.py

+1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def forward(self, m: Tensor, mask: Tensor, chunk_size: int = None) -> Tensor:
8585
z_out = z_out + z.to(m) @ sliced_weight_proj_o.T
8686
return z_out
8787
else:
88+
mask = mask[:, :, None, :] * mask[:, :, :, None]
8889
num_mask = mask.sum(1).clamp(min=1)
8990
z = torch.einsum("bsic,bsjd->bijcd", a.float(), b.float())
9091
z = z.reshape(*z.shape[:3], -1)

0 commit comments

Comments
 (0)