diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 2befd78b41b..827283a9f5c 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -109,7 +109,9 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp # Only build int4mm shim when CUDA language/toolchain is available. if(CMAKE_CUDA_COMPILER) - list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu) + list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu + runtime/shims/sort.cu + ) endif() add_library(aoti_cuda_shims SHARED ${_aoti_cuda_shim_sources}) diff --git a/backends/cuda/benchmarks/benchmark_moe.py b/backends/cuda/benchmarks/benchmark_moe.py new file mode 100644 index 00000000000..3b3c672dc50 --- /dev/null +++ b/backends/cuda/benchmarks/benchmark_moe.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Benchmark the Triton fused MoE kernel against eager and torch.compile baselines. + +Measures latency across prompt lengths matching the Qwen3.5-35B-A3B model +(hidden_size=2048, num_experts=256, top_k=8, intermediate_size=512, +INT4 weight-only quantization with group_size=128). 
+ +Usage: + python benchmark_moe.py + python benchmark_moe.py --prompt-lengths 1,8,64,512 --num_iters 200 +""" + +import argparse +from functools import partial + +import torch +from triton.testing import do_bench + +import executorch.backends.cuda.triton.kernels # noqa: F401 — registers triton ops + + +# -- Qwen3.5-35B-A3B defaults ------------------------------------------------ + +DEFAULTS = dict( + num_experts=256, + top_k=8, + hidden_size=2048, + intermediate_size=512, + group_size=128, +) + +PROMPT_LENGTHS = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4095] + + +# -- Weight / input generation ----------------------------------------------- + + +def _make_int4_weights(E, N, K, group_size, device="cuda"): + """Generate random packed INT4 weights and per-group scales. + + Returns: + w: [E, N, K//2] int8 — two INT4 values packed per byte + scale: [E, N, K//group_size] bf16 + """ + vals = torch.randint(0, 16, (E, N, K), dtype=torch.uint8, device=device) + low = vals[:, :, 0::2] + high = vals[:, :, 1::2] + packed = (high << 4) | low + w = packed.to(torch.int8) + + scale = torch.randn(E, N, K // group_size, device=device, dtype=torch.bfloat16) * 0.01 + return w, scale + + +# -- Dequantization ---------------------------------------------------------- + + +def _dequant_int4(w_packed, scale, group_size): + """Unpack INT4 weights and dequantize. 
+ + w_packed: [E, N, K//2] int8 + scale: [E, N, K//group_size] bf16 + Returns: [E, N, K] bf16 + """ + w_uint8 = w_packed.to(torch.uint8) + low = (w_uint8 & 0xF).to(torch.float32) + high = ((w_uint8 >> 4) & 0xF).to(torch.float32) + E, N, Khalf = w_packed.shape + K = Khalf * 2 + vals = torch.empty(E, N, K, device=w_packed.device, dtype=torch.float32) + vals[:, :, 0::2] = low + vals[:, :, 1::2] = high + vals = vals - 8.0 + scale_expanded = scale.float().repeat_interleave(group_size, dim=2)[:, :, :K] + return (vals * scale_expanded).to(torch.bfloat16) + + +# -- Backends ----------------------------------------------------------------- + + +def _run_eager(hidden_states, w1, w1_scale, w2, w2_scale, + topk_weights, topk_ids, top_k, num_experts, group_size): + """Loop-based eager MoE — correctness reference only (not benchmarked).""" + M, K = hidden_states.shape + inter = w2.shape[2] * 2 + + w1_deq = _dequant_int4(w1, w1_scale, group_size) + w2_deq = _dequant_int4(w2, w2_scale, group_size) + + output = torch.zeros(M, K, device=hidden_states.device, dtype=torch.bfloat16) + for i in range(M): + for j in range(top_k): + expert_id = topk_ids[i, j].item() + weight = topk_weights[i, j] + x = hidden_states[i:i+1] @ w1_deq[expert_id].T + gate = x[:, :inter] + up = x[:, inter:] + x = torch.nn.functional.silu(gate) * up + x = x @ w2_deq[expert_id].T + output[i] += weight * x.squeeze(0) + return output + + +def _run_eager_vectorized(hidden_states, w1, w1_scale, w2, w2_scale, + topk_weights, topk_ids, top_k, num_experts, group_size): + """Vectorized eager — gather + bmm, no Python loops.""" + M, K = hidden_states.shape + inter = w2.shape[2] * 2 + + w1_deq = _dequant_int4(w1, w1_scale, group_size) + w2_deq = _dequant_int4(w2, w2_scale, group_size) + + flat_ids = topk_ids.reshape(-1) + hs_rep = hidden_states.unsqueeze(1).expand(-1, top_k, -1).reshape(M * top_k, K) + gemm1_out = torch.bmm( + hs_rep.unsqueeze(1), w1_deq[flat_ids].transpose(1, 2) + ).squeeze(1) + + gate = gemm1_out[:, 
:inter] + up = gemm1_out[:, inter:] + act = torch.nn.functional.silu(gate) * up + + gemm2_out = torch.bmm( + act.unsqueeze(1), w2_deq[flat_ids].transpose(1, 2) + ).squeeze(1) + + return (gemm2_out.view(M, top_k, K) * topk_weights.unsqueeze(-1)).sum(dim=1) + + +_compiled_fn = None + + +def _run_compiled(hidden_states, w1, w1_scale, w2, w2_scale, + topk_weights, topk_ids, top_k, num_experts, group_size): + global _compiled_fn + if _compiled_fn is None: + _compiled_fn = torch.compile(_run_eager_vectorized) + return _compiled_fn( + hidden_states, w1, w1_scale, w2, w2_scale, + topk_weights, topk_ids, top_k, num_experts, group_size, + ) + + +def _run_triton(hidden_states, w1, w1_scale, w2, w2_scale, + topk_weights, topk_ids, top_k, num_experts, group_size): + return torch.ops.triton.fused_moe( + hidden_states, w1, w1_scale, w2, w2_scale, + topk_weights, topk_ids, + top_k=top_k, num_experts=num_experts, group_size=group_size, + ) + + +def _run_triton_batched(hidden_states, w1, w1_scale, w2, w2_scale, + topk_weights, topk_ids, top_k, num_experts, group_size): + from executorch.backends.cuda.triton.kernels.fused_moe import fused_moe_batched + return fused_moe_batched( + hidden_states, w1, w1_scale, w2, w2_scale, + topk_weights, topk_ids, + top_k=top_k, num_experts=num_experts, group_size=group_size, + ) + + +BACKENDS = { + "eager_vec": ("Eager (vec)", _run_eager_vectorized), + "compile": ("Compile", _run_compiled), + "triton": ("Triton fused", _run_triton), + "triton_batched": ("Triton batched", _run_triton_batched), +} + +# Backends that dequantize all experts (OOM at large M with 256 experts) +_MAY_OOM = {"eager_vec", "compile"} + + +# -- Helpers ------------------------------------------------------------------ + + +def _max_abs_error(out, ref): + return (out.float() - ref.float()).abs().max().item() + + +def _bench_ms(fn, num_warmup, num_iters): + return do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median") + + +def _try_bench(run_fn, args, num_warmup, 
num_iters): + fn = partial(run_fn, **args) + try: + fn() + return _bench_ms(fn, num_warmup, num_iters) + except torch.OutOfMemoryError: + torch.cuda.empty_cache() + return None + + +# -- Main --------------------------------------------------------------------- + + +@torch.inference_mode() +def run_benchmark( + prompt_lengths, + num_experts, top_k, hidden_size, intermediate_size, group_size, + num_warmup, num_iters, +): + backends = [(name, *BACKENDS[name]) for name in BACKENDS] + + device_name = torch.cuda.get_device_name() + print() + print("=" * 100) + print(f"Fused MoE Benchmark — Qwen3.5-35B-A3B (W4A16)") + print(f" Device: {device_name}") + print(f" Experts: {num_experts}, Top-K: {top_k}, Hidden: {hidden_size}, " + f"Intermediate: {intermediate_size}, Group: {group_size}") + print(f" Warmup: {num_warmup}, Iters: {num_iters}") + print(f" Backends: {', '.join(label for _, label, _ in backends)}") + print("=" * 100) + + # Generate weights once (shared across prompt lengths) + w1, w1_scale = _make_int4_weights( + num_experts, 2 * intermediate_size, hidden_size, group_size + ) + w2, w2_scale = _make_int4_weights( + num_experts, hidden_size, intermediate_size, group_size + ) + + # Column layout: Shape | backend1 | backend2 | ... 
(dynamic widths) + col_specs = [("M (tokens)", "", 10)] + for _, label, _ in backends: + col_specs.append((label, "(ms)", max(8, len(label)))) + + col_widths = [max(len(h), len(u), mw) for h, u, mw in col_specs] + + header = " | ".join( + f"{h:<{w}}" if i == 0 else f"{h:>{w}}" + for i, ((h, _, _), w) in enumerate(zip(col_specs, col_widths)) + ) + units = " | ".join( + f"{'':>{w}}" if i == 0 else f"{u:>{w}}" + for i, ((_, u, _), w) in enumerate(zip(col_specs, col_widths)) + ) + print(header) + print(units) + print("-" * len(header)) + + for M in prompt_lengths: + hidden_states = torch.randn(M, hidden_size, device="cuda", dtype=torch.bfloat16) + router_logits = torch.randn(M, num_experts, device="cuda", dtype=torch.float32) + topk_w, topk_i = torch.topk(router_logits, top_k, dim=-1) + topk_w = torch.softmax(topk_w, dim=-1) + topk_i = topk_i.to(torch.int64) + + common_args = dict( + hidden_states=hidden_states, + w1=w1, w1_scale=w1_scale, + w2=w2, w2_scale=w2_scale, + topk_weights=topk_w, topk_ids=topk_i, + top_k=top_k, num_experts=num_experts, group_size=group_size, + ) + + # Correctness: triton vs loop-based eager reference. + # Only check at small M to avoid slow eager loop + OOM on large M. 
+ if M <= 64: + ref_out = _run_eager(**common_args) + tri_out = _run_triton(**common_args) + err = _max_abs_error(tri_out, ref_out) + assert err < 1.0e-1, ( + f"Triton vs eager mismatch at M={M}: " + f"max abs error {err:.3e} >= 1.0e-1" + ) + del ref_out, tri_out + + # Benchmark + times = {} + for name, _label, run_fn in backends: + times[name] = _try_bench(run_fn, common_args, num_warmup, num_iters) + + ci = 0 + row_parts = [f"{f'M={M}':<{col_widths[ci]}}"] + ci += 1 + for name, _, _ in backends: + t = times[name] + w = col_widths[ci] + row_parts.append(f"{t:>{w}.3f}" if t is not None else f"{'OOM':>{w}}") + ci += 1 + print(" | ".join(row_parts)) + + del hidden_states, topk_w, topk_i + torch.cuda.empty_cache() + + print("-" * len(header)) + print() + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark Triton fused MoE vs eager/compile baselines" + ) + parser.add_argument("--num-experts", type=int, default=DEFAULTS["num_experts"]) + parser.add_argument("--top-k", type=int, default=DEFAULTS["top_k"]) + parser.add_argument("--hidden-size", type=int, default=DEFAULTS["hidden_size"]) + parser.add_argument("--intermediate-size", type=int, default=DEFAULTS["intermediate_size"]) + parser.add_argument("--group-size", type=int, default=DEFAULTS["group_size"]) + parser.add_argument("--num_warmup", type=int, default=25) + parser.add_argument("--num_iters", type=int, default=100) + parser.add_argument( + "--prompt-lengths", + type=str, + default=None, + help="Comma-separated list of prompt lengths (default: standard sweep)", + ) + args = parser.parse_args() + + prompt_lengths = PROMPT_LENGTHS + if args.prompt_lengths: + prompt_lengths = [int(x.strip()) for x in args.prompt_lengths.split(",")] + + run_benchmark( + prompt_lengths=prompt_lengths, + num_experts=args.num_experts, + top_k=args.top_k, + hidden_size=args.hidden_size, + intermediate_size=args.intermediate_size, + group_size=args.group_size, + num_warmup=args.num_warmup, + 
num_iters=args.num_iters, + ) + + +if __name__ == "__main__": + main() diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 661b4f2b960..061b0d6a29a 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -145,6 +145,7 @@ def save_data_externally(cls) -> bool: def get_supported_fallback_kernels(cls) -> Dict[str, Any]: return { "at::_ops::_weight_int4pack_mm::call": None, + "at::_ops::sort_stable::call": None, } @classmethod diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index 678cc1d6932..726f89c8125 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -33,6 +33,7 @@ runtime.cxx_library( "shims/cuda_guard.cpp", "shims/int4mm.cu", "shims/memory.cpp", + "shims/sort.cu", "shims/tensor_attribute.cpp", ], headers = [ @@ -40,6 +41,7 @@ runtime.cxx_library( "shims/int4mm.cuh", "shims/int4mm.h", "shims/memory.h", + "shims/sort.h", "shims/tensor_attribute.h", "utils.h", ], diff --git a/backends/cuda/runtime/shims/sort.cu b/backends/cuda/runtime/shims/sort.cu new file mode 100644 index 00000000000..4979d94c0cb --- /dev/null +++ b/backends/cuda/runtime/shims/sort.cu @@ -0,0 +1,245 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
 */
+
+#include <cuda_runtime.h>
+#include <thrust/device_ptr.h>
+#include <thrust/execution_policy.h>
+#include <thrust/functional.h>
+#include <thrust/sort.h>
+
+// NOTE(review): project include paths reconstructed from runtime/TARGETS after
+// the angle-bracket contents were lost in transit — verify before landing.
+#include <executorch/backends/cuda/runtime/shims/cuda_guard.h>
+#include <executorch/backends/cuda/runtime/shims/memory.h>
+#include <executorch/backends/cuda/runtime/shims/sort.h>
+#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
+#include <executorch/backends/cuda/runtime/utils.h>
+
+namespace executorch::backends::cuda {
+
+namespace c10_slim = executorch::backends::aoti::slim::c10;
+
+namespace {
+
+__global__ void init_indices_kernel(
+    int64_t* data,
+    int64_t slice_size,
+    int64_t total) {
+  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+  if (idx < total) {
+    data[idx] = idx % slice_size;
+  }
+}
+
+template <typename T>
+void sort_slice_impl(
+    T* keys,
+    int64_t* values,
+    int64_t n,
+    bool descending,
+    bool stable,
+    cudaStream_t stream) {
+  auto k = thrust::device_pointer_cast(keys);
+  auto v = thrust::device_pointer_cast(values);
+  if (stable && descending) {
+    thrust::stable_sort_by_key(
+        thrust::cuda::par.on(stream), k, k + n, v, thrust::greater<T>());
+  } else if (stable) {
+    thrust::stable_sort_by_key(
+        thrust::cuda::par.on(stream), k, k + n, v);
+  } else if (descending) {
+    thrust::sort_by_key(
+        thrust::cuda::par.on(stream), k, k + n, v, thrust::greater<T>());
+  } else {
+    thrust::sort_by_key(thrust::cuda::par.on(stream), k, k + n, v);
+  }
+}
+
+} // namespace
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+AOTITorchError aoti_torch_cuda_sort_stable(
+    Tensor* self,
+    int32_t* stable,
+    int64_t dim,
+    int32_t descending,
+    Tensor** ret0,
+    Tensor** ret1) {
+  ET_CHECK_OR_RETURN_ERROR(
+      self != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda_sort_stable: self is null");
+  ET_CHECK_OR_RETURN_ERROR(
+      ret0 != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda_sort_stable: ret0 is null");
+  ET_CHECK_OR_RETURN_ERROR(
+      ret1 != nullptr,
+      InvalidArgument,
+      "aoti_torch_cuda_sort_stable: ret1 is null");
+
+  int64_t ndim = static_cast<int64_t>(self->dim());
+
+  if (dim < 0)
+    dim += ndim;
+  ET_CHECK_OR_RETURN_ERROR(
+      dim >= 0 && dim < ndim,
+      InvalidArgument,
+      "aoti_torch_cuda_sort_stable: dim out of range");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      self->is_contiguous(),
+      NotSupported,
+      "aoti_torch_cuda_sort_stable: non-contiguous input not supported");
+
+  int64_t sort_size = self->size(dim);
+  int64_t total_elements = static_cast<int64_t>(self->numel());
+  int64_t num_slices = (sort_size > 0) ? total_elements / sort_size : 0;
+
+  auto stream_result = getCurrentCUDAStream(0);
+  ET_CHECK_OR_RETURN_ERROR(
+      stream_result.ok(),
+      Internal,
+      "aoti_torch_cuda_sort_stable: failed to get CUDA stream");
+  cudaStream_t stream = stream_result.get();
+
+  // Contiguous strides for output tensors
+  auto input_sizes = self->sizes();
+  std::vector<int64_t> contig_strides(ndim);
+  if (ndim > 0) {
+    contig_strides[ndim - 1] = 1;
+    for (int64_t i = ndim - 2; i >= 0; --i) {
+      contig_strides[i] = contig_strides[i + 1] * input_sizes[i + 1];
+    }
+  }
+
+  int32_t dtype_val = static_cast<int32_t>(self->dtype());
+
+  // Allocate output values (same shape/dtype as input)
+  *ret0 = nullptr;
+  aoti_torch_empty_strided(
+      ndim,
+      input_sizes.data(),
+      contig_strides.data(),
+      dtype_val,
+      static_cast<int32_t>(c10_slim::DeviceType::CUDA),
+      0,
+      ret0);
+  ET_CHECK_OR_RETURN_ERROR(
+      *ret0 != nullptr,
+      Internal,
+      "aoti_torch_cuda_sort_stable: failed to allocate values tensor");
+
+  // Copy input data to output values
+  if (total_elements > 0) {
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
+        (*ret0)->data_ptr(),
+        self->data_ptr(),
+        self->nbytes(),
+        cudaMemcpyDeviceToDevice,
+        stream));
+  }
+
+  // Allocate output indices (same shape, int64 dtype)
+  *ret1 = nullptr;
+  aoti_torch_empty_strided(
+      ndim,
+      input_sizes.data(),
+      contig_strides.data(),
+      static_cast<int32_t>(c10_slim::ScalarType::Long),
+      static_cast<int32_t>(c10_slim::DeviceType::CUDA),
+      0,
+      ret1);
+  ET_CHECK_OR_RETURN_ERROR(
+      *ret1 != nullptr,
+      Internal,
+      "aoti_torch_cuda_sort_stable: failed to allocate indices tensor");
+
+  // Initialize indices: each slice gets 0, 1, ..., sort_size-1
+  if (total_elements > 0) {
+    int threads = 256;
+    int blocks = static_cast<int>((total_elements + threads - 1) / threads);
+    init_indices_kernel<<<blocks, threads, 0, stream>>>(
+        static_cast<int64_t*>((*ret1)->data_ptr()),
+        sort_size,
+        total_elements);
+    ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR();
+  }
+
+  if (sort_size <= 1) {
+    return Error::Ok;
+  }
+
+  bool is_stable = (stable != nullptr && *stable != 0);
+  bool desc = (descending != 0);
+
+  // Require sorting along a contiguous dimension (stride == 1)
+  int64_t dim_stride = self->stride(dim);
+  ET_CHECK_OR_RETURN_ERROR(
+      dim_stride == 1 || ndim == 1,
+      NotSupported,
+      "aoti_torch_cuda_sort_stable: sort along non-innermost dim "
+      "of multi-D tensor not yet supported");
+
+  auto self_dtype = self->dtype();
+
+  for (int64_t s = 0; s < num_slices; ++s) {
+    int64_t offset = s * sort_size;
+    int64_t* idx_ptr =
+        static_cast<int64_t*>((*ret1)->data_ptr()) + offset;
+
+    switch (self_dtype) {
+      case c10_slim::ScalarType::Long: {
+        sort_slice_impl(
+            static_cast<int64_t*>((*ret0)->data_ptr()) + offset,
+            idx_ptr,
+            sort_size,
+            desc,
+            is_stable,
+            stream);
+        break;
+      }
+      case c10_slim::ScalarType::Int: {
+        sort_slice_impl(
+            static_cast<int32_t*>((*ret0)->data_ptr()) + offset,
+            idx_ptr,
+            sort_size,
+            desc,
+            is_stable,
+            stream);
+        break;
+      }
+      case c10_slim::ScalarType::Float: {
+        sort_slice_impl(
+            static_cast<float*>((*ret0)->data_ptr()) + offset,
+            idx_ptr,
+            sort_size,
+            desc,
+            is_stable,
+            stream);
+        break;
+      }
+      default:
+        ET_LOG(
+            Error,
+            "aoti_torch_cuda_sort_stable: unsupported dtype %d",
+            static_cast<int>(self_dtype));
+        return Error::InvalidArgument;
+    }
+  }
+
+  ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR();
+  return Error::Ok;
+}
+
+#ifdef __cplusplus
+}
+#endif
+
+} // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/sort.h b/backends/cuda/runtime/shims/sort.h
new file mode 100644
index 00000000000..88efee834ca
--- /dev/null
+++ b/backends/cuda/runtime/shims/sort.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+
+ */ + +#pragma once + +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::backends::aoti::AOTITorchError; +using executorch::backends::aoti::Tensor; + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Sorts a tensor along a given dimension. + * + * @param self Input tensor to sort (any numeric dtype, CUDA device) + * @param stable Pointer to bool — if non-null and *stable != 0, uses stable sort + * @param dim Dimension along which to sort + * @param descending If non-zero, sort in descending order + * @param ret0 Output: sorted values tensor (same shape/dtype as self) + * @param ret1 Output: indices tensor (int64, same shape as self) + * @return AOTITorchError + */ +AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_sort_stable( + Tensor* self, + int32_t* stable, + int64_t dim, + int32_t descending, + Tensor** ret0, + Tensor** ret1); + +#ifdef __cplusplus +} +#endif + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/tests/test_fused_moe.py b/backends/cuda/tests/test_fused_moe.py index ee96a9d06e8..a84d48f0cac 100644 --- a/backends/cuda/tests/test_fused_moe.py +++ b/backends/cuda/tests/test_fused_moe.py @@ -30,6 +30,7 @@ from executorch.backends.cuda.cuda_partitioner import CudaPartitioner from executorch.backends.cuda.triton.kernels.fused_moe import ( fused_moe as triton_fused_moe, + fused_moe_batched as triton_fused_moe_batched, ) from executorch.exir import ( EdgeCompileConfig, @@ -332,6 +333,75 @@ def test_single_expert(self): rel = diff / (ref.float().abs().max().item() + 1e-10) self.assertLess(rel, 0.05, f"token {t}: relative diff {rel:.4f}") + def test_batched_correctness(self): + """Batched kernel matches reference across M values.""" + test_cases = [ + (42, 8, 64, 32, 4, 2, 32, "8tok_small"), + (7, 16, 64, 32, 8, 4, 32, "16tok_8exp_top4"), + (13, 32, 128, 64, 8, 2, 64, "32tok_gs64"), + (55, 64, 64, 32, 4, 2, 32, "64tok"), + (99, 128, 128, 64, 8, 2, 32, "128tok"), + ] + for seed, M, hidden, 
intermediate, num_experts, top_k, gs, desc in test_cases: + with self.subTest(desc=desc): + torch.manual_seed(seed) + x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") + w1_weight = torch.randn( + num_experts, 2 * intermediate, hidden, + dtype=torch.bfloat16, device="cuda", + ) + w2_weight = torch.randn( + num_experts, hidden, intermediate, + dtype=torch.bfloat16, device="cuda", + ) + w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs) + w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs) + w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda() + + scores = torch.randn(M, num_experts, device="cuda") + topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1) + topk_weights = topk_weights.softmax(dim=-1).float() + + out = triton_fused_moe_batched( + x, w1, w1s, w2, w2s, topk_weights, topk_ids, + top_k, num_experts, gs, + ) + + w1_dq = _dequantize_int4(w1.cpu(), w1s.cpu(), gs).cuda() + w2_dq = _dequantize_int4(w2.cpu(), w2s.cpu(), gs).cuda() + ref = _reference_moe(x, w1_dq, w2_dq, topk_weights, topk_ids, top_k) + + diff = (out.float() - ref.float()).abs().max().item() + rel = diff / (ref.float().abs().max().item() + 1e-10) + self.assertLess( + rel, 0.05, f"{desc}: relative diff {rel:.4f} (abs {diff:.6f})", + ) + + def test_batched_matches_fused(self): + """Batched kernel matches the existing fused_moe kernel at Qwen-scale dims.""" + E, top_k, K, inter, gs = 256, 8, 2048, 512, 128 + torch.manual_seed(42) + vals = torch.randint(0, 16, (E, 2 * inter, K), dtype=torch.uint8, device="cuda") + w1 = ((vals[:, :, 1::2] << 4) | vals[:, :, 0::2]).to(torch.int8) + w1s = torch.randn(E, 2 * inter, K // gs, device="cuda", dtype=torch.bfloat16) * 0.01 + vals = torch.randint(0, 16, (E, K, inter), dtype=torch.uint8, device="cuda") + w2 = ((vals[:, :, 1::2] << 4) | vals[:, :, 0::2]).to(torch.int8) + w2s = torch.randn(E, K, inter // gs, device="cuda", dtype=torch.bfloat16) * 0.01 + + for M in [16, 64, 256]: + with self.subTest(M=M): + x = torch.randn(M, K, 
device="cuda", dtype=torch.bfloat16) + logits = torch.randn(M, E, device="cuda", dtype=torch.float32) + tw, ti = torch.topk(logits, top_k, dim=-1) + tw = tw.softmax(dim=-1) + ti = ti.to(torch.int64) + + out_fused = triton_fused_moe(x, w1, w1s, w2, w2s, tw, ti, top_k, E, gs) + out_batched = triton_fused_moe_batched(x, w1, w1s, w2, w2s, tw, ti, top_k, E, gs) + + err = (out_fused.float() - out_batched.float()).abs().max().item() + self.assertLess(err, 0.5, f"M={M}: max abs error {err:.4e}") + def test_export_cuda(self): """Export succeeds and produces non-empty .pte.""" with tempfile.TemporaryDirectory() as tmpdir: diff --git a/backends/cuda/tests/test_sort_shim.py b/backends/cuda/tests/test_sort_shim.py new file mode 100644 index 00000000000..ce95ed326e3 --- /dev/null +++ b/backends/cuda/tests/test_sort_shim.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Test sort CUDA shim for AOTI export. + +The sort shim (sort.cu) provides aoti_torch_cuda_sort_stable, a thrust-based +fallback for aten::sort.stable that Inductor emits when it can't natively lower +sort. This is needed for ops like argsort that decompose to sort_stable. 
+ +Usage: + python -m pytest backends/cuda/tests/test_sort_shim.py -v +""" + +import os +import tempfile +import unittest + +import torch +import torch.nn as nn + +from executorch.backends.cuda.cuda_backend import CudaBackend +from executorch.backends.cuda.cuda_partitioner import CudaPartitioner +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.exir.passes import MemoryPlanningPass +from torch.export import export + + +class SortModel(nn.Module): + """Model that uses sort (via argsort) for export testing.""" + + def forward(self, x): + # argsort decomposes to sort_stable in Inductor + return x.argsort(dim=-1) + + +class SortStableModel(nn.Module): + """Model that uses torch.sort directly.""" + + def forward(self, x): + values, indices = torch.sort(x, dim=-1, stable=True) + return values, indices + + +class TestSortShim(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA is not available") + + def test_argsort_export(self): + """argsort exports and produces .pte via AOTI with sort shim.""" + model = SortModel().eval() + x = torch.randn(4, 8, dtype=torch.float32, device="cuda") + + with torch.no_grad(): + ep = export(model, (x,), strict=True) + + with tempfile.TemporaryDirectory() as tmpdir: + specs = [CudaBackend.generate_method_name_compile_spec("forward")] + et_prog = to_edge_transform_and_lower( + ep, + partitioner=[CudaPartitioner(specs)], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, _skip_dim_order=True + ), + ) + et_program = et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + do_quant_fusion_and_const_prop=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + pte_path = os.path.join(tmpdir, "sort_model.pte") + with open(pte_path, "wb") as f: + et_program.write_to_file(f) + + self.assertTrue(os.path.exists(pte_path)) + 
self.assertGreater(os.path.getsize(pte_path), 0) + + def test_sort_stable_export(self): + """torch.sort(stable=True) exports and produces .pte via AOTI with sort shim.""" + model = SortStableModel().eval() + x = torch.randn(4, 8, dtype=torch.float32, device="cuda") + + with torch.no_grad(): + ep = export(model, (x,), strict=True) + + with tempfile.TemporaryDirectory() as tmpdir: + specs = [CudaBackend.generate_method_name_compile_spec("forward")] + et_prog = to_edge_transform_and_lower( + ep, + partitioner=[CudaPartitioner(specs)], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, _skip_dim_order=True + ), + ) + et_program = et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + do_quant_fusion_and_const_prop=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + pte_path = os.path.join(tmpdir, "sort_stable_model.pte") + with open(pte_path, "wb") as f: + et_program.write_to_file(f) + + self.assertTrue(os.path.exists(pte_path)) + self.assertGreater(os.path.getsize(pte_path), 0) + + def test_sort_fallback_registered(self): + """sort_stable is registered as a supported fallback kernel.""" + fallbacks = CudaBackend.get_supported_fallback_kernels() + self.assertIn("at::_ops::sort_stable::call", fallbacks) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/cuda/triton/kernels/__init__.py b/backends/cuda/triton/kernels/__init__.py index e7af2bdaf84..61a56f9de50 100644 --- a/backends/cuda/triton/kernels/__init__.py +++ b/backends/cuda/triton/kernels/__init__.py @@ -4,12 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from executorch.backends.cuda.triton.kernels.fused_moe import fused_moe +from executorch.backends.cuda.triton.kernels.fused_moe import ( + fused_moe, + fused_moe_batched, + fused_moe_batched_gemm, + moe_align_block_size, +) from executorch.backends.cuda.triton.kernels.sdpa import sdpa from executorch.backends.cuda.triton.kernels.topk import topk __all__ = [ "fused_moe", + "fused_moe_batched", + "fused_moe_batched_gemm", + "moe_align_block_size", "sdpa", "topk", ] diff --git a/backends/cuda/triton/kernels/fused_moe.py b/backends/cuda/triton/kernels/fused_moe.py index 98a86698bc4..8eee02f03d2 100644 --- a/backends/cuda/triton/kernels/fused_moe.py +++ b/backends/cuda/triton/kernels/fused_moe.py @@ -144,16 +144,29 @@ def _fused_moe_kernel( b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0) b = (b >> b_shifter) & 0xF - # Load per-group scales [BLOCK_SIZE_K, BLOCK_SIZE_N] - scale_ptrs = ( - B_scale - + expert_id * stride_bse - + offs_n[None, :] * stride_bsn - + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk - ) - b_scale = tl.load( - scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0 - ).to(tl.float32) + # Load per-group scales and dequantize + if BLOCK_SIZE_K <= group_size: + # All K values in this tile share one scale group — load [1, N] + group_idx = (BLOCK_SIZE_K * k_step) // group_size + scale_ptrs = ( + B_scale + + expert_id * stride_bse + + offs_n[None, :] * stride_bsn + + group_idx * stride_bsk + ) + b_scale = tl.load( + scale_ptrs, mask=n_mask[None, :], other=0.0 + ).to(tl.float32) + else: + scale_ptrs = ( + B_scale + + expert_id * stride_bse + + offs_n[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk + ) + b_scale = tl.load( + scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0 + ).to(tl.float32) # Dequantize and accumulate: vector-matrix multiply b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type) @@ -252,15 +265,27 @@ def 
_fused_moe_silu_kernel( b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0) b = (b >> b_shifter) & 0xF - scale_ptrs = ( - B_scale - + expert_id * stride_bse - + offs_n[None, :] * stride_bsn - + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk - ) - b_scale = tl.load( - scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0 - ).to(tl.float32) + if BLOCK_SIZE_K <= group_size: + group_idx = (BLOCK_SIZE_K * k_step) // group_size + scale_ptrs = ( + B_scale + + expert_id * stride_bse + + offs_n[None, :] * stride_bsn + + group_idx * stride_bsk + ) + b_scale = tl.load( + scale_ptrs, mask=n_mask[None, :], other=0.0 + ).to(tl.float32) + else: + scale_ptrs = ( + B_scale + + expert_id * stride_bse + + offs_n[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk + ) + b_scale = tl.load( + scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0 + ).to(tl.float32) b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type) acc += tl.sum(a[:, None].to(compute_type) * b_dequant, axis=0) @@ -334,6 +359,7 @@ def fused_moe( def grid1(meta): return (num_pairs * triton.cdiv(N1, meta["BLOCK_SIZE_N"]),) + # Weight layout: [E, N, K//2]. wrap_triton(_fused_moe_kernel)[grid1]( hidden_states, w1, @@ -410,3 +436,544 @@ def _fused_moe_fake( group_size: int, ) -> torch.Tensor: return torch.empty_like(hidden_states) + + +# --------------------------------------------------------------------------- +# Batched prefill MoE — token sorting + tl.dot (tensor cores) +# --------------------------------------------------------------------------- + +# Fixed BLOCK_M for the batched kernel. Not autotuned because the token +# sorting layout depends on it. 16 is the minimum for tl.dot and wastes +# the least padding with typical Qwen3.5 expert load (~30 tokens/expert). 
# Fixed BLOCK_M for the batched kernel. Not autotuned because the token
# sorting layout depends on it. 16 is the minimum for tl.dot and wastes
# the least padding with typical Qwen3.5 expert load (~30 tokens/expert).
_BATCHED_BLOCK_M = 16


def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sort token-expert pairs by expert and pad to block_size boundaries.

    Given router output topk_ids [M, top_k], produces a flat array of pair
    indices grouped by expert with each expert's block padded to a multiple
    of block_size. Padding slots use sentinel value M*top_k which maps to
    a zero-row appended by the caller.

    All output shapes depend only on (M, top_k, num_experts, block_size) —
    no data-dependent shapes — so this is compatible with torch.export /
    symbolic tracing.

    Returns:
        sorted_token_ids: [max_num_tokens_padded] int64
        expert_ids: [max_num_expert_blocks] int64
        num_tokens_post_padded: scalar int64 tensor
    """
    M, top_k = topk_ids.shape
    num_pairs = M * top_k
    device = topk_ids.device
    sentinel = num_pairs  # out-of-bounds index -> zero padding row

    # Worst-case output size: each active expert wastes at most block_size
    # padding slots, so num_pairs + num_experts * block_size is a safe
    # static upper bound on total slots.
    max_num_tokens_padded = num_pairs + num_experts * block_size
    max_num_expert_blocks = max_num_tokens_padded // block_size

    flat_ids = topk_ids.reshape(-1)  # [num_pairs] expert id per pair

    # Per-expert token counts via one_hot+sum (bincount lacks AOTI c-shim)
    one_hot = torch.nn.functional.one_hot(flat_ids, num_classes=num_experts)
    tokens_per_expert = one_hot.sum(0)
    padded_per_expert = (
        (tokens_per_expert + block_size - 1) // block_size
    ) * block_size

    # Prefix sum for expert offsets in the output array
    expert_offsets = torch.zeros(
        num_experts + 1, dtype=torch.int64, device=device
    )
    expert_offsets[1:] = padded_per_expert.cumsum(0)
    num_tokens_post_padded = expert_offsets[num_experts]  # scalar tensor

    # Counting sort without argsort/bincount (both need AOTI fallbacks):
    # within_expert_rank[i] = number of pairs j < i routed to the same
    # expert, computed as an exclusive per-expert prefix sum of one_hot.
    # The sort is stable in original pair order by construction.
    within_expert_rank = (
        (one_hot.cumsum(0) - one_hot)
        .gather(1, flat_ids.unsqueeze(1))
        .squeeze(1)
    )
    dest = expert_offsets[flat_ids] + within_expert_rank

    # Pre-allocate at max size, filled with sentinel; scatter each pair
    # index into its destination slot.
    sorted_token_ids = torch.full(
        (max_num_tokens_padded,), sentinel, dtype=torch.int64, device=device
    ).scatter(
        0, dest, torch.arange(num_pairs, dtype=torch.int64, device=device)
    )

    # Map each block to the expert owning its slot range: the expert of
    # block b is the number of expert regions that end at or before the
    # block's first slot. Blocks past num_tokens_post_padded are clamped
    # to a valid expert id; their slots are all-sentinel, so they only
    # produce zero rows downstream.
    block_starts = (
        torch.arange(max_num_expert_blocks, dtype=torch.int64, device=device)
        * block_size
    )
    expert_ids = (
        (block_starts[:, None] >= expert_offsets[None, 1:]).sum(1)
    ).clamp(max=num_experts - 1)

    # NOTE(review): the tail of this function was garbled in the paste;
    # reconstructed from the documented contract above — verify upstream.
    return sorted_token_ids, expert_ids, num_tokens_post_padded
# NOTE(review): the autotune config lists and the GEMM1 kernel prologue
# were garbled in the paste; reconstructed from the structurally identical
# GEMM2 kernel and the call site in fused_moe_batched_gemm — verify the
# exact configs against upstream.
_BATCHED_GEMM1_CONFIGS = [
    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=4),
    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=4),
    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_warps=4),
]

_BATCHED_GEMM2_CONFIGS = [
    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=4),
    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=4),
    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_warps=4),
]


@triton.autotune(configs=_BATCHED_GEMM1_CONFIGS, key=["N", "K"])
@triton.jit
def _fused_moe_batched_kernel(
    # Pointers
    A,  # [M + 1, K] bf16 activations (last row is the zero padding row)
    B,  # [E, N, K//2] int8 packed INT4 weights
    C,  # [max_padded, N] bf16 output in sorted order
    B_scale,  # [E, N, K//group_size] bf16 scales
    sorted_token_ids,  # [max_padded] int64 pair indices (sentinel-padded)
    expert_ids,  # [num_expert_blocks] int64
    # Dimensions
    N: tl.constexpr,
    K: tl.constexpr,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    # Config
    top_k: tl.constexpr,
    group_size: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    compute_type: tl.constexpr,
):
    """Batched GEMM1: gathered activations @ INT4 w1 via tl.dot.

    Each program computes one [BLOCK_M, BLOCK_N] output tile for one
    expert block, gathering activation rows through sorted_token_ids
    and writing the result in sorted order.
    """
    pid = tl.program_id(0)
    num_n_blocks = tl.cdiv(N, BLOCK_SIZE_N)
    expert_block_idx = pid // num_n_blocks
    n_block = pid % num_n_blocks

    expert_id = tl.load(expert_ids + expert_block_idx).to(tl.int64)

    # M-block in sorted order; sentinel pair ids map to the zero row
    offs_m = expert_block_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    pair_ids = tl.load(sorted_token_ids + offs_m)
    token_ids = pair_ids // top_k  # pair -> token for activation lookup

    # N offsets
    offs_n = n_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    n_mask = offs_n < N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # A pointers: gathered rows [BLOCK_M, K]
    a_ptrs = A + token_ids[:, None] * stride_am + offs_k[None, :] * stride_ak

    # B pointers: [expert_id, offs_n, offs_k//2]
    b_ptrs = (
        B
        + expert_id * stride_be
        + (offs_k[:, None] // 2) * stride_bk
        + offs_n[None, :] * stride_bn
    )
    b_shifter = (offs_k[:, None] % 2) * 4

    # 2D accumulator [BLOCK_M, BLOCK_N]
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_remaining = K - k_step * BLOCK_SIZE_K
        k_mask = offs_k < k_remaining

        # Load A tile [BLOCK_M, BLOCK_K] — gathered via token_ids
        a = tl.load(a_ptrs, mask=k_mask[None, :], other=0.0)

        # Load B tile [BLOCK_K, BLOCK_N] and unpack INT4
        b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0)
        b = (b >> b_shifter) & 0xF

        # Per-group scales. When the K-tile fits inside one quantization
        # group, one scalar scale per column suffices.
        if BLOCK_SIZE_K <= group_size:
            group_idx = (BLOCK_SIZE_K * k_step) // group_size
            scale_ptrs = (
                B_scale
                + expert_id * stride_bse
                + offs_n[None, :] * stride_bsn
                + group_idx * stride_bsk
            )
            b_scale = tl.load(
                scale_ptrs, mask=n_mask[None, :], other=0.0
            ).to(tl.float32)
        else:
            scale_ptrs = (
                B_scale
                + expert_id * stride_bse
                + offs_n[None, :] * stride_bsn
                + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size)
                * stride_bsk
            )
            b_scale = tl.load(
                scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
            ).to(tl.float32)

        # Dequantize: (uint4 - 8) * scale
        b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type)

        # Tensor core matmul: [BLOCK_M, BLOCK_K] @ [BLOCK_K, BLOCK_N]
        acc += tl.dot(a.to(compute_type), b_dequant)

        # Advance K pointers
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk

    # Write output in sorted order [BLOCK_M, BLOCK_N]
    c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(compute_type), mask=n_mask[None, :])
@triton.autotune(configs=_BATCHED_GEMM2_CONFIGS, key=["N", "K"])
@triton.jit
def _fused_moe_silu_batched_kernel(
    # Pointers
    A,  # [num_tokens_post_padded, 2*inter] bf16 GEMM1 output (sorted order)
    B,  # [E, N, K//2] int8 packed INT4 weights
    C,  # [M*top_k + 1, N] bf16 output (scatter to original pair order)
    B_scale,  # [E, N, K//group_size] bf16 scales
    sorted_token_ids,  # [num_tokens_post_padded] int64 pair indices
    expert_ids,  # [num_expert_blocks] int64
    topk_weights,  # [M*top_k] float32 router weights (flat)
    # Dimensions
    N: tl.constexpr,
    K: tl.constexpr,  # intermediate_size
    num_pairs,  # M * top_k (for clamping sentinel weight lookups)
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    # Config
    top_k: tl.constexpr,
    group_size: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    compute_type: tl.constexpr,
):
    """Batched GEMM2 with fused SiLU and scatter-back.

    Reads gate+up from GEMM1 output (sorted order), applies SiLU(gate)*up,
    multiplies by INT4 w2 weights, applies router weights, and scatters
    output to original pair positions.
    """
    pid = tl.program_id(0)
    num_n_blocks = tl.cdiv(N, BLOCK_SIZE_N)
    expert_block_idx = pid // num_n_blocks
    n_block = pid % num_n_blocks

    expert_id = tl.load(expert_ids + expert_block_idx).to(tl.int64)

    # M-block in sorted order
    offs_m = expert_block_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    pair_ids = tl.load(sorted_token_ids + offs_m)

    # N offsets
    offs_n = n_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    n_mask = offs_n < N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # A pointers: gate at [0, K), up at [K, 2K) — contiguous in sorted order
    a_gate_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    a_up_ptrs = a_gate_ptrs + K * stride_ak

    # B pointers: [expert_id, offs_n, offs_k//2]
    b_ptrs = (
        B
        + expert_id * stride_be
        + (offs_k[:, None] // 2) * stride_bk
        + offs_n[None, :] * stride_bn
    )
    b_shifter = (offs_k[:, None] % 2) * 4

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_remaining = K - k_step * BLOCK_SIZE_K
        k_mask = offs_k < k_remaining

        # Load gate and up tiles [BLOCK_M, BLOCK_K], apply SiLU
        gate = tl.load(a_gate_ptrs, mask=k_mask[None, :], other=0.0).to(tl.float32)
        up = tl.load(a_up_ptrs, mask=k_mask[None, :], other=0.0)
        a = (gate * tl.sigmoid(gate) * up).to(compute_type)

        # Load and dequantize INT4 weights [BLOCK_K, BLOCK_N]
        b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0)
        b = (b >> b_shifter) & 0xF

        if BLOCK_SIZE_K <= group_size:
            group_idx = (BLOCK_SIZE_K * k_step) // group_size
            scale_ptrs = (
                B_scale
                + expert_id * stride_bse
                + offs_n[None, :] * stride_bsn
                + group_idx * stride_bsk
            )
            b_scale = tl.load(
                scale_ptrs, mask=n_mask[None, :], other=0.0
            ).to(tl.float32)
        else:
            scale_ptrs = (
                B_scale
                + expert_id * stride_bse
                + offs_n[None, :] * stride_bsn
                + ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size)
                * stride_bsk
            )
            b_scale = tl.load(
                scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
            ).to(tl.float32)

        b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type)

        # Tensor core matmul: [BLOCK_M, BLOCK_K] @ [BLOCK_K, BLOCK_N]
        acc += tl.dot(a, b_dequant)

        a_gate_ptrs += BLOCK_SIZE_K * stride_ak
        a_up_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk

    # Apply router weights per row.
    # Clamp sentinel pair_ids to a valid index for the weight load
    safe_pair_ids = tl.minimum(pair_ids, num_pairs - 1)
    weights = tl.load(topk_weights + safe_pair_ids)
    # Zero out sentinel rows (pair_ids >= num_pairs means padding)
    is_valid = pair_ids < num_pairs
    weights = tl.where(is_valid, weights, 0.0)
    acc = acc * weights[:, None]

    # Scatter to original pair order: write at pair_ids positions.
    # Sentinel pair_ids write to the extra row at end (ignored)
    scatter_ids = tl.where(is_valid, pair_ids, num_pairs)
    c_ptrs = C + scatter_ids[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(compute_type), mask=n_mask[None, :])


# ---------------------------------------------------------------------------
# Batched triton_op wrapper
# ---------------------------------------------------------------------------


@triton_op("triton::fused_moe_batched_gemm", mutates_args={})
def fused_moe_batched_gemm(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w1_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    num_experts: int,
    group_size: int,
) -> torch.Tensor:
    """Batched GEMM1 + GEMM2+SiLU with token sorting + tensor-core GEMMs."""
    M, K = hidden_states.shape
    N1 = w1.shape[1]  # 2 * intermediate_size
    intermediate = N1 // 2
    N2 = w2.shape[1]  # hidden_size
    num_pairs = M * top_k
    BLOCK_M = _BATCHED_BLOCK_M

    sorted_token_ids, expert_ids, _ = moe_align_block_size(
        topk_ids, BLOCK_M, num_experts
    )
    max_padded = sorted_token_ids.shape[0]
    num_expert_blocks = expert_ids.shape[0]

    # Append one zero row so sentinel pair ids gather zeros (row M).
    hidden_padded = torch.cat(
        [
            hidden_states,
            torch.zeros(1, K, dtype=hidden_states.dtype, device=hidden_states.device),
        ],
        dim=0,
    )

    topk_weights_flat = topk_weights.reshape(-1)

    cache1 = torch.empty(
        max_padded, N1,
        dtype=hidden_states.dtype, device=hidden_states.device,
    )

    def grid1(meta):
        return (num_expert_blocks * triton.cdiv(N1, meta["BLOCK_SIZE_N"]),)

    # Weight layout: [E, N, K//2].
    wrap_triton(_fused_moe_batched_kernel)[grid1](
        hidden_padded,
        w1,
        cache1,
        w1_scale,
        sorted_token_ids,
        expert_ids,
        N=N1,
        K=K,
        stride_am=hidden_padded.stride(0),
        stride_ak=hidden_padded.stride(1),
        stride_be=w1.stride(0),
        stride_bk=w1.stride(2),
        stride_bn=w1.stride(1),
        stride_cm=cache1.stride(0),
        stride_cn=cache1.stride(1),
        stride_bse=w1_scale.stride(0),
        stride_bsk=w1_scale.stride(2),
        stride_bsn=w1_scale.stride(1),
        top_k=top_k,
        group_size=group_size,
        BLOCK_SIZE_M=BLOCK_M,
        compute_type=tl.bfloat16,
    )

    # Extra row at index num_pairs absorbs sentinel scatter writes.
    out_buf = torch.zeros(
        num_pairs + 1, N2,
        dtype=hidden_states.dtype, device=hidden_states.device,
    )

    def grid2(meta):
        return (num_expert_blocks * triton.cdiv(N2, meta["BLOCK_SIZE_N"]),)

    wrap_triton(_fused_moe_silu_batched_kernel)[grid2](
        cache1,
        w2,
        out_buf,
        w2_scale,
        sorted_token_ids,
        expert_ids,
        topk_weights_flat,
        N=N2,
        K=intermediate,
        num_pairs=num_pairs,
        stride_am=cache1.stride(0),
        stride_ak=cache1.stride(1),
        stride_be=w2.stride(0),
        stride_bk=w2.stride(2),
        stride_bn=w2.stride(1),
        stride_cm=out_buf.stride(0),
        stride_cn=out_buf.stride(1),
        stride_bse=w2_scale.stride(0),
        stride_bsk=w2_scale.stride(2),
        stride_bsn=w2_scale.stride(1),
        top_k=top_k,
        group_size=group_size,
        BLOCK_SIZE_M=BLOCK_M,
        compute_type=tl.bfloat16,
    )

    # Drop the sentinel row and combine the top_k expert outputs per token.
    return out_buf[:num_pairs].view(M, top_k, N2).sum(dim=1)
@fused_moe_batched_gemm.register_fake
def _fused_moe_batched_gemm_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w1_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    num_experts: int,
    group_size: int,
) -> torch.Tensor:
    # Output shape/dtype match the input activations: [M, hidden_size].
    return torch.empty_like(hidden_states)


def fused_moe_batched(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w1_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    num_experts: int,
    group_size: int,
) -> torch.Tensor:
    """Convenience wrapper for benchmarking (same as fused_moe_batched_gemm)."""
    return fused_moe_batched_gemm(
        hidden_states, w1, w1_scale, w2, w2_scale,
        topk_weights, topk_ids,
        top_k, num_experts, group_size,
    )
Both methods share mutable state buffers (KV cache, conv_state, recurrent_state) via share_mutable_buffers=True. The model uses @@ -412,7 +418,8 @@ def export_and_lower(model, config, args): # -O0 compiles ~8x faster than -O1 with no measurable runtime impact. inductor_config.aot_inductor.compile_wrapper_opt_level = "O0" - # --- Decode method (T=1, static shape) --- + # --- Decode method (T=1, static shape, vec-mat MoE kernel) --- + _set_batched_moe(model, False) print("Exporting decode method...") decode_tokens = torch.tensor([[0]], dtype=torch.long) decode_pos = torch.tensor([0], dtype=torch.long) @@ -424,10 +431,14 @@ def export_and_lower(model, config, args): ) print("Decode export successful!") - # --- Prefill method (T>=2, dynamic shape) --- + # --- Prefill method (T>=2, dynamic shape, batched tensor-core MoE kernel) --- + # Example T must equal max_seq_len-1 so AOTI compiles kernels for the + # full range of sequence lengths. + _set_batched_moe(model, True) print("Exporting prefill method...") - prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long) - prefill_pos = torch.tensor([0, 1], dtype=torch.long) + example_seq_len = config.max_seq_len - 1 + prefill_tokens = torch.zeros((1, example_seq_len), dtype=torch.long) + prefill_pos = torch.arange(example_seq_len, dtype=torch.long) seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1) prefill_dynamic_shapes = ( {1: seq_dim}, # tokens diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py index 751915fb123..560b5fcfb49 100644 --- a/examples/models/qwen3_5_moe/model.py +++ b/examples/models/qwen3_5_moe/model.py @@ -479,6 +479,7 @@ def __init__(self, config): self.intermediate_size = config.moe_intermediate_size self.hidden_size = config.hidden_size self.group_size = 32 + self.use_batched_moe = False self.w1_weight = nn.Parameter( torch.empty( @@ -496,6 +497,19 @@ def __init__(self, config): ) def forward(self, x, expert_weights, expert_indices, top_k): + if 
self.use_batched_moe: + return torch.ops.triton.fused_moe_batched_gemm( + x, + self.w1, + self.w1_scale, + self.w2, + self.w2_scale, + expert_weights, + expert_indices, + top_k, + self.num_experts, + self.group_size, + ) return torch.ops.triton.fused_moe( x, self.w1,