Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,10 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
runtime/shims/cuda_guard.cpp
)

# Only build int4mm shim when CUDA language/toolchain is available.
# Only build CUDA-specific shims when CUDA language/toolchain is available.
if(CMAKE_CUDA_COMPILER)
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu)
list(APPEND _aoti_cuda_shim_sources runtime/shims/randint.cu)
endif()

add_library(aoti_cuda_shims SHARED ${_aoti_cuda_shim_sources})
Expand Down Expand Up @@ -150,7 +151,8 @@ endif()
# retention.
if(_cuda_is_msvc_toolchain)
target_link_libraries(
aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart ${CMAKE_DL_LIBS}
aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand
${CMAKE_DL_LIBS}
)
# Link object library directly so symbols are pulled exactly once while
# avoiding duplicate static/object inclusion and interface leakage.
Expand All @@ -160,7 +162,7 @@ else()
aoti_cuda_shims
PRIVATE cuda_platform
PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
CUDA::cudart ${CMAKE_DL_LIBS}
CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
)
endif()

Expand Down
4 changes: 2 additions & 2 deletions backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def save_data_externally(cls) -> bool:
def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
return {
"at::_ops::_weight_int4pack_mm::call": None,
"aoti_torch_cuda_randint_low_out": None,
}

@classmethod
Expand All @@ -170,8 +171,7 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]
mode = spec.value.decode("utf-8").upper()
if mode not in ["ON", "OFF"]:
raise ValueError(
f"Invalid triton_kernel_mode: {mode}. "
f"Expected 'ON' or 'OFF'."
f"Invalid triton_kernel_mode: {mode}. Expected 'ON' or 'OFF'."
)
triton_kernel_mode = mode
passes = [MoveCondPredicateToCpuPass()]
Expand Down
81 changes: 52 additions & 29 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,16 @@ class ET_EXPERIMENTAL CudaBackend final
return (DelegateHandle*)handle; // Return the handle post-processing
}

// Once per execution
// Execute the AOTI-compiled CUDA kernel for one inference step.
//
// Currently supports both CPU and CUDA memory for IO tensors:
// - Inputs: detected via cudaPointerGetAttributes; CUDA data is wrapped
// in-place (no copy), CPU data is copied to GPU via from_etensor().
// - Outputs: either copied to ETensor's backing memory (CPU or CUDA),
// or the ETensor is rewired to point at GPU memory (skip-copy mode).
//
// TODO: Once the device tensor pipeline is fully adopted, all IO tensors
// will reside in CUDA memory. Remove the CPU fallback paths.
Error execute(
BackendExecutionContext& context,
DelegateHandle* handle_,
Expand All @@ -405,14 +414,17 @@ class ET_EXPERIMENTAL CudaBackend final
n_outputs,
args.size())

// Verify device info on all memory-planned, ET-driven IO tensors.
// All input and output tensors should have device_type = CUDA, which
// is set during serialization by PropagateDevicePass based on the
// target_device compile spec from CudaPartitioner.
// Verify device metadata on all IO tensors.
// All tensors should have device_type = CUDA, set during serialization
// by PropagateDevicePass based on the target_device compile spec from
// CudaPartitioner.
//
// Note: At this stage, the tensor memory is still on CPU. The device_type
// is metadata indicating where the tensor *should* reside. The backend
// is responsible for copying data to the actual CUDA device.
// Note: device_type is metadata — the actual memory location may be
// either CPU (legacy path with H2D copy ops) or CUDA (when device
// memory planning is enabled via enable_non_cpu_memory_planning,
// which allocates delegate IO in CUDA memory). The backend detects
// the actual location via cudaPointerGetAttributes and handles both
// cases.
for (size_t i = 0; i < n_inputs + n_outputs; i++) {
auto* tensor = &(args[i]->toTensor());
auto device_type = tensor->unsafeGetTensorImpl()->device_type();
Expand All @@ -425,34 +437,37 @@ class ET_EXPERIMENTAL CudaBackend final
static_cast<int>(device_type));
}

// NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy
// optimization. We need to create GPU copies for CUDA kernel execution
// using SlimTensor.
// Convert ExecuTorch tensors to SlimTensors for AOTI kernel execution.
// Input data may be in CPU or CUDA memory — the backend detects and
// handles both cases automatically (see memory model comment above).
std::vector<SlimTensor*> gpu_inputs(n_inputs);
std::vector<SlimTensor*> gpu_outputs(n_outputs);

// Process input tensors: convert ETensor (CPU) to SlimTensor (GPU)
for (size_t i = 0; i < n_inputs; i++) {
auto* cpu_tensor = &(args[i]->toTensor());
auto* input_tensor = &(args[i]->toTensor());

// Check if input data is already on GPU (skip-copy optimization for
// inputs) This can happen when the caller has pre-staged data on GPU
// Detect if input data is already in CUDA memory. This occurs when:
// - Device memory planning is enabled (enable_non_cpu_memory_planning),
// which allocates delegate IO in CUDA memory
// - The input is a skip-copy output from a previous method execution
// When detected, the data is wrapped directly — no H2D copy needed.
cudaPointerAttributes attributes{};
const void* data_ptr = cpu_tensor->const_data_ptr();
const void* data_ptr = input_tensor->const_data_ptr();
if (data_ptr != nullptr) {
cudaError_t err = cudaPointerGetAttributes(&attributes, data_ptr);
if (err == cudaSuccess && attributes.type == cudaMemoryTypeDevice) {
// Data is already on GPU - wrap it directly without copy
auto sizes = cpu_tensor->sizes();
auto strides = cpu_tensor->strides();
auto sizes = input_tensor->sizes();
auto strides = input_tensor->strides();
std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
std::vector<int64_t> strides_vec(strides.begin(), strides.end());

gpu_inputs[i] = new SlimTensor(slim::from_blob(
const_cast<void*>(data_ptr),
slim::makeArrayRef(sizes_vec),
slim::makeArrayRef(strides_vec),
static_cast<slim::c10::ScalarType>(cpu_tensor->scalar_type()),
static_cast<slim::c10::ScalarType>(input_tensor->scalar_type()),
DEFAULT_CUDA_DEVICE,
0 // storage_offset
));
Expand All @@ -461,19 +476,22 @@ class ET_EXPERIMENTAL CudaBackend final
}
}

// Data is on CPU - use from_etensor to copy to GPU
// Data is in CPU memory (legacy path) — copy to GPU via from_etensor.
// TODO: Remove this path once all callers use the device tensor pipeline.
gpu_inputs[i] = new SlimTensor(
from_etensor(*cpu_tensor, CPU_DEVICE, DEFAULT_CUDA_DEVICE));
from_etensor(*input_tensor, CPU_DEVICE, DEFAULT_CUDA_DEVICE));
}

// Process output tensors: create GPU SlimTensors for kernel output.
// Save pre-run handles to detect orphans after run().
// Allocate GPU SlimTensors for kernel outputs. These are always
// freshly allocated on GPU regardless of the input memory mode.
// Save pre-run handles to detect orphans after run() (the AOTI
// runtime may replace output handles with its own allocations).
std::vector<SlimTensor*> pre_run_outputs(n_outputs, nullptr);
for (size_t i = 0; i < n_outputs; i++) {
auto* cpu_output_tensor = &(args[i + n_inputs]->toTensor());
auto sizes = cpu_output_tensor->sizes();
auto strides = cpu_output_tensor->strides();
auto scalar_type = cpu_output_tensor->scalar_type();
auto* output_tensor = &(args[i + n_inputs]->toTensor());
auto sizes = output_tensor->sizes();
auto strides = output_tensor->strides();
auto scalar_type = output_tensor->scalar_type();

std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
std::vector<int64_t> strides_vec(strides.begin(), strides.end());
Expand Down Expand Up @@ -536,13 +554,18 @@ class ET_EXPERIMENTAL CudaBackend final

const bool copy_outputs = !should_skip_copy_for_method(handle->method_name);

// Output disposition: copy to ETensor backing memory or keep on GPU.
// When copy_outputs is true (default), results are copied to the
// ETensor's memory (which may be CPU or CUDA planned memory).
// When false (skip-copy optimization), the ETensor is rewired to
// point at the GPU SlimTensor's memory directly.
if (copy_outputs) {
for (size_t i = 0; i < n_outputs; i++) {
auto* cpu_output_tensor = &(args[i + n_inputs]->toTensor());
auto* output_tensor = &(args[i + n_inputs]->toTensor());
ET_CHECK_OK_OR_RETURN_ERROR(
copy_slimtensor_to_etensor_async(
gpu_outputs[i], cpu_output_tensor, cuda_stream),
"Failed to copy GPU output %zu back to CPU ETensor",
gpu_outputs[i], output_tensor, cuda_stream),
"Failed to copy GPU output %zu back to ETensor",
i);
delete gpu_outputs[i];
gpu_outputs[i] = nullptr;
Expand Down
108 changes: 108 additions & 0 deletions backends/cuda/runtime/shims/randint.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cuda_runtime.h>
#include <curand.h>

#include <executorch/backends/cuda/runtime/shims/randint.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>

#include <cstdint>
#include <ctime>

namespace executorch::backends::cuda {

using executorch::runtime::Error;

namespace {

// Map cuRAND uniform doubles in (0, 1] onto int64 values in [low, high).
__global__ void uniform_to_randint_kernel(
    int64_t* out,
    const double* uniform,
    int64_t numel,
    int64_t low,
    int64_t range) {
  const int64_t i =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i >= numel) {
    return;
  }
  // uniform[i] lies in (0, 1], so uniform[i] * range lies in (0, range].
  // Truncation yields [0, range]; the upper bound is reached only when the
  // sample is exactly 1.0, so clamp back into [0, range - 1] before
  // shifting by low to land in [low, high - 1].
  int64_t scaled = static_cast<int64_t>(uniform[i] * range);
  if (scaled >= range) {
    scaled = range - 1;
  }
  out[i] = low + scaled;
}

// Returns the process-wide cuRAND generator, creating it on first use.
//
// The generator is built inside a C++11 "magic static" initializer, which
// the runtime guarantees to execute exactly once even when multiple threads
// race on the first call. (The previous check-then-create pattern on a
// static pointer could create duplicate generators and leak them.)
//
// Returns nullptr if creation or seeding fails; callers must check.
curandGenerator_t get_or_create_generator() {
  static curandGenerator_t gen = []() -> curandGenerator_t {
    curandGenerator_t g = nullptr;
    curandStatus_t status = curandCreateGenerator(&g, CURAND_RNG_PSEUDO_DEFAULT);
    if (status != CURAND_STATUS_SUCCESS) {
      ET_LOG(
          Error,
          "get_or_create_generator: curandCreateGenerator failed (%d)",
          static_cast<int>(status));
      return nullptr;
    }
    // Seed with wall-clock time so repeated process runs differ.
    status = curandSetPseudoRandomGeneratorSeed(
        g, static_cast<unsigned long long>(time(nullptr)));
    if (status != CURAND_STATUS_SUCCESS) {
      ET_LOG(
          Error,
          "get_or_create_generator: curandSetPseudoRandomGeneratorSeed failed (%d)",
          static_cast<int>(status));
      curandDestroyGenerator(g);
      return nullptr;
    }
    return g;
  }();
  return gen;
}

} // anonymous namespace

extern "C" {

AOTITorchError aoti_torch_cuda_randint_low_out(
    SlimTensor* out,
    int64_t low,
    int64_t high,
    const int64_t* size,
    int64_t size_len_) {
  // Fills `out` (pre-allocated on CUDA, int64 elements) with random
  // integers in [low, high). Returns Error::Ok on success.
  ET_CHECK_OR_RETURN_ERROR(
      out != nullptr,
      InvalidArgument,
      "aoti_torch_cuda_randint_low_out: out tensor is null");

  ET_CHECK_OR_RETURN_ERROR(
      high > low,
      InvalidArgument,
      "aoti_torch_cuda_randint_low_out: requires high > low");

  // Element count from the size array; an empty tensor is a no-op.
  int64_t numel = 1;
  for (int64_t i = 0; i < size_len_; i++) {
    numel *= size[i];
  }
  if (numel == 0) {
    return Error::Ok;
  }

  const int64_t range = high - low;
  int64_t* out_data = static_cast<int64_t*>(out->data_ptr());

  // Temporary device buffer for the uniform doubles cuRAND produces.
  double* d_uniform = nullptr;
  cudaError_t alloc_err =
      cudaMalloc(&d_uniform, static_cast<size_t>(numel) * sizeof(double));
  ET_CHECK_OR_RETURN_ERROR(
      alloc_err == cudaSuccess,
      Internal,
      "aoti_torch_cuda_randint_low_out: cudaMalloc failed (%d)",
      static_cast<int>(alloc_err));

  // Generate uniform doubles in (0, 1]. Generator creation can fail
  // (returns nullptr) and generation itself can fail; both must release
  // the scratch buffer before returning.
  auto gen = get_or_create_generator();
  if (gen == nullptr) {
    cudaFree(d_uniform);
    ET_CHECK_OR_RETURN_ERROR(
        false,
        Internal,
        "aoti_torch_cuda_randint_low_out: cuRAND generator unavailable");
  }
  curandStatus_t gen_err = curandGenerateUniformDouble(gen, d_uniform, numel);
  if (gen_err != CURAND_STATUS_SUCCESS) {
    cudaFree(d_uniform);
    ET_CHECK_OR_RETURN_ERROR(
        false,
        Internal,
        "aoti_torch_cuda_randint_low_out: curandGenerateUniformDouble failed (%d)",
        static_cast<int>(gen_err));
  }

  // Transform to integers in [low, high).
  constexpr int kThreads = 256;
  int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
  uniform_to_randint_kernel<<<blocks, kThreads>>>(
      out_data, d_uniform, numel, low, range);
  // Capture any launch-configuration error before releasing the buffer.
  cudaError_t launch_err = cudaGetLastError();

  cudaFree(d_uniform);

  ET_CHECK_OR_RETURN_ERROR(
      launch_err == cudaSuccess,
      Internal,
      "aoti_torch_cuda_randint_low_out: kernel launch failed (%d)",
      static_cast<int>(launch_err));

  return Error::Ok;
}

} // extern "C"

} // namespace executorch::backends::cuda
43 changes: 43 additions & 0 deletions backends/cuda/runtime/shims/randint.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/aoti/common_shims_slim.h>
#include <executorch/backends/aoti/export.h>

namespace executorch::backends::cuda {

using executorch::backends::aoti::AOTITorchError;
using SlimTensor = executorch::backends::aoti::slim::SlimTensor;

extern "C" {

/**
 * Fills a pre-allocated CUDA tensor with random integers in [low, high).
 *
 * Used by AOTI-generated code when the model calls torch.randint
 * (presumably also ops that lower to randint — verify against the AOTI
 * fallback-kernel registration in cuda_backend.py).
 *
 * NOTE(review): the implementation writes int64 elements through the
 * output's data pointer — confirm callers always pass a Long tensor.
 *
 * @param out Pre-allocated output tensor on CUDA (must not be null).
 * @param low Lower bound (inclusive) of the random range.
 * @param high Upper bound (exclusive) of the random range.
 * @param size Pointer to array of output dimension sizes.
 * @param size_len_ Number of dimensions.
 * @return AOTITorchError error code (Error::Ok on success).
 */
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out(
    SlimTensor* out,
    int64_t low,
    int64_t high,
    const int64_t* size,
    int64_t size_len_);

} // extern "C"

} // namespace executorch::backends::cuda
19 changes: 9 additions & 10 deletions examples/models/qwen3_5_moe/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,24 @@ list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas)
executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)

# Extensions
list(
APPEND
link_libraries
extension_llm_runner
extension_module
extension_data_loader
extension_tensor
extension_flat_tensor
list(APPEND link_libraries extension_module extension_data_loader
extension_tensor extension_flat_tensor
)

# CUDA backend (required)
find_package(CUDAToolkit REQUIRED)
list(APPEND link_libraries aoti_cuda_backend)
list(APPEND link_libraries aoti_cuda_backend CUDA::cudart)
executorch_target_link_options_shared_lib(aoti_cuda_backend)

# Tokenizer
list(APPEND link_libraries tokenizers::tokenizers)

add_executable(qwen3_5_moe_runner main.cpp)
add_executable(
qwen3_5_moe_runner
main.cpp ${EXECUTORCH_ROOT}/runtime/core/device_allocator.cpp
${EXECUTORCH_ROOT}/runtime/core/device_memory_buffer.cpp
${EXECUTORCH_ROOT}/backends/cuda/runtime/cuda_allocator.cpp
)
target_include_directories(
qwen3_5_moe_runner PUBLIC ${_common_include_directories}
)
Expand Down
Loading
Loading