[WIP][Kernels] Contiguous Group GeMM #1036

Open: wants to merge 26 commits into base: main

Commits (26)
47b8799
first draft, not working
lessw2020 Mar 30, 2025
a4c1e7c
forward pass working, matches PyTorch ref
lessw2020 Mar 30, 2025
e394886
naming cleanup
lessw2020 Mar 30, 2025
812f06d
naming cleanup
lessw2020 Mar 30, 2025
9ed651e
remove tma version for now...
lessw2020 Mar 30, 2025
578cab4
start backwards..but forwards having numerics issues as well.
lessw2020 Mar 30, 2025
d7d2679
ensure group_index is calculated properly
lessw2020 Mar 30, 2025
a63865a
small test passing, med fails
lessw2020 Mar 31, 2025
24cd08b
all sizes passing - use debug.py
lessw2020 Mar 31, 2025
cfa5773
update pytorch reference. add alignedbench but not working
lessw2020 Mar 31, 2025
3bc93b6
add grid stride kernel but not passing on large sizes
lessw2020 Mar 31, 2025
8440a90
add demo.py to test token input prep
lessw2020 Apr 1, 2025
1646119
fixed demo input prep - now works with cg_forward
lessw2020 Apr 1, 2025
e680f0f
backward pass now working - dx, dw
lessw2020 Apr 1, 2025
e5311e9
full e2e MoE demo, triton restore not yet accurate
lessw2020 Apr 1, 2025
fc0c8ef
remove non working backwards kernels
lessw2020 Apr 1, 2025
4a5675b
get router gradients working for non triton prep path
lessw2020 Apr 2, 2025
e720e91
add first unit test suite
lessw2020 Apr 2, 2025
9f715ac
update early config prune to return min if none found, add additional…
lessw2020 Apr 2, 2025
323eebc
consolidate standard config, early config prune, cuda utils into sep …
lessw2020 Apr 2, 2025
6c4bdcd
first pass at sorting kernel, fails unit testing
lessw2020 Apr 3, 2025
07c7c8f
sort update
lessw2020 Apr 3, 2025
872638d
add cpp extension
lessw2020 Apr 4, 2025
1ff3603
remove cpu gpu sync, add unit test
lessw2020 Apr 4, 2025
21ac60a
3 unit test failures of 9
lessw2020 Apr 4, 2025
7207639
update to persistent forward with L2 caching optimization
lessw2020 Apr 10, 2025

725 changes: 725 additions & 0 deletions torchtitan/experiments/kernels/contiguous_group_gemm/backdebug.py

Large diffs are not rendered by default.

573 changes: 573 additions & 0 deletions torchtitan/experiments/kernels/contiguous_group_gemm/benchmark.py

Large diffs are not rendered by default.

665 changes: 665 additions & 0 deletions torchtitan/experiments/kernels/contiguous_group_gemm/cg_backward.py

Large diffs are not rendered by default.

450 changes: 450 additions & 0 deletions torchtitan/experiments/kernels/contiguous_group_gemm/cg_forward.py

Large diffs are not rendered by default.

@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


# Simple reference implementation for verification
def pytorch_reference(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = 128,
) -> torch.Tensor:
    """
    Reference implementation using PyTorch for verification.
    """
    M_total, K = inputs.shape
    num_experts, N, _ = expert_weights.shape

    output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)

    # Process each group
    for i in range(0, M_total, group_size_m):
        end_idx = min(i + group_size_m, M_total)

        # Get expert index for this group
        expert_idx = expert_indices[i].item()

        # Get expert weights
        expert_weight = expert_weights[expert_idx]

        # Compute output for this group
        output[i:end_idx] = torch.matmul(inputs[i:end_idx], expert_weight.T)

    return output
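
The reference above reads only the first entry of `expert_indices` in each group, so it implicitly assumes every row in a `group_size_m` block is routed to the same expert. The sketch below shows one way to exercise that assumption on CPU. It is not part of this PR: `make_contiguous_inputs` and `per_row_reference` are hypothetical helpers written for illustration, and the shapes are arbitrary small values. It builds group-aligned inputs, runs a copy of the loop-based reference, and compares it against a per-row gather plus `einsum`.

```python
import torch


def pytorch_reference(inputs, expert_weights, expert_indices, group_size_m=128):
    """Copy of the loop-based reference from the diff, kept here so the sketch is self-contained."""
    M_total, K = inputs.shape
    num_experts, N, _ = expert_weights.shape
    output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
    for i in range(0, M_total, group_size_m):
        end_idx = min(i + group_size_m, M_total)
        # Only the first row's expert index is read for the whole group.
        expert_idx = expert_indices[i].item()
        output[i:end_idx] = torch.matmul(
            inputs[i:end_idx], expert_weights[expert_idx].T
        )
    return output


def make_contiguous_inputs(num_groups, group_size_m, K, N, num_experts, seed=0):
    """Hypothetical helper: build inputs where each group of rows shares one expert."""
    gen = torch.Generator().manual_seed(seed)
    M_total = num_groups * group_size_m
    inputs = torch.randn(M_total, K, generator=gen)
    expert_weights = torch.randn(num_experts, N, K, generator=gen)
    # One expert id per group, repeated so every row in the group carries it.
    group_experts = torch.randint(0, num_experts, (num_groups,), generator=gen)
    expert_indices = group_experts.repeat_interleave(group_size_m)
    return inputs, expert_weights, expert_indices


def per_row_reference(inputs, expert_weights, expert_indices):
    """Hypothetical cross-check: gather a weight matrix per row, then contract over K."""
    per_row_weights = expert_weights[expert_indices]  # [M_total, N, K]
    return torch.einsum("mk,mnk->mn", inputs, per_row_weights)


inputs, expert_weights, expert_indices = make_contiguous_inputs(
    num_groups=4, group_size_m=8, K=16, N=32, num_experts=3
)
grouped = pytorch_reference(inputs, expert_weights, expert_indices, group_size_m=8)
per_row = per_row_reference(inputs, expert_weights, expert_indices)
max_abs_diff = (grouped - per_row).abs().max().item()
print(f"max abs diff between grouped and per-row results: {max_abs_diff:.3e}")
```

The per-row path materializes an `[M_total, N, K]` tensor, so it is only practical for small test shapes. If `expert_indices` were not constant within a group, the two results would diverge, which makes this comparison a cheap check on the input preparation step.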