|
31 | 31 |
|
32 | 32 | @triton.jit
|
33 | 33 | def add_kernel(x_ptr, y_ptr, output_ptr,
|
34 |
| - block_size: tl.constexpr, n_elements: tl.constexpr): |
| 34 | + block_size: tl.constexpr, n_elements: tl.constexpr): |
35 | 35 | pid = tl.program_id(axis=0) # we use a 1d launch grid so axis is 0
|
36 | 36 | block_start = pid * block_size
|
37 | 37 | offsets = block_start + tl.arange(0, block_size)
|
@@ -98,15 +98,15 @@ class TritonKernelCallTest(parameterized.TestCase):
|
98 | 98 | ])
|
def test_add_vectors(self, size, dtype, block_size):
  """Checks that `triton_call`-dispatched `add_kernel` matches `x + y`.

  Args:
    size: number of elements in each input vector.
    dtype: element dtype name ("float16/32/64" or "int32/64").
    block_size: Triton block size forwarded to the kernel as a constexpr.
  """
  # Ceiling division so we launch exactly enough programs to cover `size`
  # elements. The previous `size // block_size + 1` launched one extra
  # program whenever size was an exact multiple of block_size
  # (NOTE(review): assumes add_kernel masks out-of-range offsets, which is
  # the standard pattern — the extra program was redundant either way).
  grid = lambda meta: (-(-size // meta["block_size"]),)
  k1, k2 = random.split(random.PRNGKey(0), 2)
  if dtype in {"float32", "float16", "float64"}:
    x = random.normal(k1, [size], dtype=dtype)
    y = random.normal(k2, [size], dtype=dtype)
  elif dtype in {"int32", "int64"}:
    x = random.randint(k1, [size], -100, 100, dtype=dtype)
    y = random.randint(k2, [size], -100, 100, dtype=dtype)
  else:
    # Previously this fell through silently, leaving x/y unbound and
    # raising an opaque NameError below; fail loudly instead.
    raise ValueError(f"Unsupported dtype in test parameterization: {dtype}")
  out = triton_call(x, y, kernel=add_kernel, out_shape=x,
                    grid=grid, block_size=block_size, n_elements=size)
  expected = x + y
  np.testing.assert_allclose(out, expected)
|
112 | 112 |
|
|
0 commit comments