Skip to content

Commit 252a3b1

Browse files
mark14wufacebook-github-bot
authored andcommitted
Fix hardcoded shape in low_mem_dropout benchmark (#2475)
Summary: Pull Request resolved: #2475 Reviewed By: htyu Differential Revision: D63653081 Pulled By: xuzhao9 fbshipit-source-id: 8d840986779b6124cbccc2425c24e2b892d55ce4
1 parent 611bf70 commit 252a3b1

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchbenchmark/operators/low_mem_dropout/operator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def triton_dropout(self, p, x):
3838
n_elements = x.numel()
3939
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
4040

41-
x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
41+
x_keep = (torch.rand(size=(n_elements,)) > p).to(torch.int32).cuda()
4242

4343
def _inner():
4444
return _triton_dropout[grid](

0 commit comments

Comments
 (0)