Conversation

@tyb0807 (Contributor) commented on Jan 3, 2026

Fixes #600. Requires #667.

@tyb0807 requested review from ftynse and tgymnich on January 3, 2026 at 00:11
Comment on lines +167 to +183

@tkw.wave(constraints)
def matmul(
a: tkl.Memory[M, K, ADDRESS_SPACE, dtype],
b: tkl.Memory[N, K, ADDRESS_SPACE, dtype],
c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@tkw.iterate(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a)
b_reg = tkw.read(b)
acc = tkw.mma(a_reg, b_reg, acc)
return acc

tkw.write(repeat, c)

Contributor:
Can we use one from templates?
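
For illustration, a minimal sketch of what reusing a shared template might look like; the module path, helper name, and signature below are assumptions made for the example, not the repository's actual templates API:

# Hypothetical: reuse the GEMM kernel from the shared templates package instead
# of redefining it inline. Import path and helper signature are assumptions.
from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel  # hypothetical path

matmul, hyperparams = get_gemm_kernel((M, N, K), dynamic_dims=False)  # hypothetical signature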

# Apply Water PassManager lowering
lowered_mlir = apply_water_middle_end_passes(wave_dialect_mlir)

print(lowered_mlir)

Contributor:
Let's have FileCheck comments here so it doesn't look like forgotten debug output.
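
As a sketch, assuming the file is driven by lit/FileCheck as is usual for these tests, the directives could sit as comments next to the print; the checks below are placeholders (assumptions) and would need to be replaced with patterns taken from the pipeline's real output:

# Placeholder FileCheck directives; the intent is to verify the wave-level ops
# were lowered away, but the exact patterns must come from an actual run.
# CHECK-LABEL: @matmul
# CHECK-NOT: wave.
print(lowered_mlir)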

compile_to_mlir=True,
location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE),
enforce_locations=False,
print_mlir=True,

Contributor:
Do we need to print mlir?

b_tensor = device_randn(n, k, dtype=torch.float16) # Note: transposed in matmul
c_tensor = device_zeros(m, n, dtype=torch.float32)

# Expected result (CPU computation)

Contributor:
The code above creates tensors on the device, so why is this called a CPU computation?
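
A small sketch of two ways to resolve the mismatch, using only the helpers already present in the test: fix the comment, or make the reference genuinely host-side.

# Option 1: reference computed with torch.matmul on the same device tensors
expected = torch.matmul(a_tensor.float(), b_tensor.T.float())

# Option 2: an explicitly host-side reference, moved back for the comparison
expected = torch.matmul(a_tensor.cpu().float(), b_tensor.cpu().T.float()).to(c_tensor.device)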

c_tensor = device_zeros(m, n, dtype=torch.float32)

# Expected result (CPU computation)
expected = torch.matmul(a_tensor.float(), b_tensor.T.float())

Contributor:
It is a bad idea to compute expected values with a higher precision than actual values.
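
One possible adjustment, as a sketch (assuming the kernel multiplies fp16 operands and accumulates into fp32, as the tkl.f32 accumulator suggests): do the reference multiply at the input precision and only widen the result for the comparison. Accumulation order still differs from the MMA path, so this remains an approximation.

# fp16 multiply; widen only for the comparison against the fp32 output
expected = torch.matmul(a_tensor, b_tensor.T).to(torch.float32)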


compiled_e2e(a_tensor, b_tensor, c_tensor)

assert_close(c_tensor, expected, rtol=1e-3, atol=1e-3)

Contributor:
1e-3 looks a bit too lax; do we really need it?
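
For context, torch.testing.assert_close uses much tighter defaults for float32 (roughly rtol=1.3e-6, atol=1e-5). Whether those hold here depends on K and the fp16 inputs, so one option, sketched below with an assumed bound, is to start tighter and only loosen if the accumulation genuinely requires it:

# Assumption: 1e-4 may suffice once the reference precision matches the kernel;
# loosen only if accumulation over the K dimension actually needs it.
assert_close(c_tensor, expected, rtol=1e-4, atol=1e-4)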

Development

Successfully merging this pull request may close these issues:

Make matmul kernel lowered through water and run by IREE (#600)
