We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents f2901a9 + 3ad510a commit 2ae8d3aCopy full SHA for 2ae8d3a
src/boltz/model/layers/outer_product_mean.py
@@ -85,6 +85,7 @@ def forward(self, m: Tensor, mask: Tensor, chunk_size: int = None) -> Tensor:
85
z_out = z_out + z.to(m) @ sliced_weight_proj_o.T
86
return z_out
87
else:
88
+ mask = mask[:, :, None, :] * mask[:, :, :, None]
89
num_mask = mask.sum(1).clamp(min=1)
90
z = torch.einsum("bsic,bsjd->bijcd", a.float(), b.float())
91
z = z.reshape(*z.shape[:3], -1)
0 commit comments