Skip to content

Commit f4e589d

Browse files
committed
Fix broken tests
1 parent f00e291 commit f4e589d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/triton_call_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
@triton.jit
3333
def add_kernel(x_ptr, y_ptr, output_ptr,
34-
block_size: tl.constexpr, n_elements: tl.constexpr):
34+
block_size: tl.constexpr, n_elements: tl.constexpr):
3535
pid = tl.program_id(axis=0) # we use a 1d launch grid so axis is 0
3636
block_start = pid * block_size
3737
offsets = block_start + tl.arange(0, block_size)
@@ -98,15 +98,15 @@ class TritonKernelCallTest(parameterized.TestCase):
9898
])
9999
def test_add_vectors(self, size, dtype, block_size):
100100

101-
grid = lambda meta: (size // meta["BLOCK_SIZE"] + 1,)
101+
grid = lambda meta: (size // meta["block_size"] + 1,)
102102
k1, k2 = random.split(random.PRNGKey(0), 2)
103103
if dtype in {"float32", "float16", "float64"}:
104104
x, y = random.normal(k1, [size], dtype=dtype), random.normal(k2, [size], dtype=dtype)
105105
elif dtype in {"int32", "int64"}:
106106
x, y = random.randint(k1, [size], -100, 100, dtype=dtype), random.randint(k2, [size], -100, 100, dtype=dtype)
107107

108108
out = triton_call(x, y, kernel=add_kernel, out_shape=x,
109-
grid=grid, BLOCK_SIZE=block_size, n_elements=size)
109+
grid=grid, block_size=block_size, n_elements=size)
110110
expected = x + y
111111
np.testing.assert_allclose(out, expected)
112112

0 commit comments

Comments
 (0)