Skip to content

Commit 573c1db

Browse files
authored
Merge pull request #149 from shenoynikhil/chunk-fix
2 parents 1ccc6b4 + 5f640a3 commit 573c1db

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

src/boltz/model/layers/outer_product_mean.py

+2
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def forward(self, m: Tensor, mask: Tensor, chunk_size: int = None) -> Tensor:
8383
z_out = z.to(m) @ sliced_weight_proj_o.T
8484
else:
8585
z_out = z_out + z.to(m) @ sliced_weight_proj_o.T
86+
87+
z_out = z_out + self.proj_o.bias # add bias
8688
return z_out
8789
else:
8890
mask = mask[:, :, None, :] * mask[:, :, :, None]

tests/model/layers/test_outer_product_mean.py

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def setUp(self):
2121
for name, param in self.layer.named_parameters():
2222
nn.init.normal_(param, mean=1., std=1.)
2323

24+
# Set to eval mode
25+
self.layer.eval()
2426

2527
def test_chunk(self):
2628
chunk_sizes = [16, 33, 64, 83, 100]

0 commit comments

Comments
 (0)