
Commit 78738f4

Merge pull request #6 from jax-ml/multidevice

Lazily load CUDA modules before launch

2 parents: f4e589d + 47498b7

7 files changed: +255 -121 lines

examples/matrix_multiplication.py (-1 line)

@@ -20,7 +20,6 @@
 
 import jax
 import jax.numpy as jnp
-import math
 
 m=512
 n=512

jax_triton/triton_call.py (+25 -18 lines)

@@ -26,20 +26,21 @@
 from jax.interpreters import mlir
 from jax import tree_util
 from jax._src import util
+from jax._src.lib import xla_bridge as xb
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import mhlo
 import numpy as np
 import torch
 import triton
 import triton.language as tl
 
-from jax_triton import custom_call
+from jax_triton import triton_kernel_call
 
 os.environ["TRITON_CACHE_DIR"] = ""
 map, unsafe_map = util.safe_map, map
 zip, unsafe_zip = util.safe_zip, zip
 
-xc.register_custom_call_target("triton_call", custom_call.get_custom_call(), platform="CUDA")
+xc.register_custom_call_target("triton_kernel_call", triton_kernel_call.get_custom_call(), platform="CUDA")
 
 def get_triton_type(obj: Any) -> str:
   type_map = {
@@ -79,8 +80,6 @@ def get_triton_python_ir(aval):
 
 def compile(triton_function, constants, *, key, device=0, num_warps=4, num_stages=2):
   def lower(*args):
-    arg_types = [get_triton_python_ir(a) for a in args]
-    attributes = {i: 16 for i in range(len(args))}
     triton_function._warmup(arg_types=arg_types, device=device,
         attributes=attributes, constants=constants, num_warps=num_warps,
         num_stages=num_stages, key=key, is_manual_warmup=True,
@@ -133,18 +132,24 @@ def aval_to_layout(aval):
   arange = np.arange(aval.ndim, dtype='int64')[::-1].copy()
   return ir.DenseIntElementsAttr.get(arange, type=ir.IndexType.get())
 
-def emit_triton_call(triton_func, avals_in, avals_out, grid, num_warps, num_stages,
+def emit_triton_call(ctx, triton_func, grid, num_warps, num_stages,
                      dump_binary_path: Optional[str], **metaparams):
   metadata = {triton_func.arg_names.index(k): v for k, v in metaparams.items()}
-  compile(triton_func, metadata, num_warps=num_warps, num_stages=num_stages, key="foo")(*avals_in, *avals_out)
-  loaded_binary = triton_func.bin_cache["foo"]
-  kernel_ptr = loaded_binary.kernel
-  shared_mem = loaded_binary.shared_mem
+  all_args = [*ctx.avals_in, *ctx.avals_out]
+  arg_types = [get_triton_python_ir(a) for a in all_args]
+  attributes = {i: 16 for i in range(len(all_args))}
+  # TODO(sharadmv): handle multiple devices, right now we assume device 0 which
+  # is fine when we have multiple of the same GPU but this won't work in
+  # general.
+  binary = triton_func._compile(arg_types=arg_types, device=0,
+      attributes=attributes, constants=metadata, num_warps=num_warps,
+      num_stages=num_stages, extern_libs={})
+  name, asm, shared_mem = binary.name, binary.asm, binary.shared_mem
   if dump_binary_path is not None:
     binary = dict(
-        asm=loaded_binary.asm,
+        asm=asm,
         shared_mem=shared_mem,
-        name=loaded_binary.bin.name)
+        name=name)
     with open(dump_binary_path, "wb") as fp:
       pickle.dump(binary, fp)
 
@@ -158,22 +163,24 @@ def emit_triton_call(triton_func, avals_in, avals_out, grid, num_warps, num_stag
     grid_1, grid_2 = grid_[1], grid_[2]
   else:
     assert False
-  arity = len(avals_in) + len(avals_out)
-  descriptor = custom_call.make_triton_call_descriptor(kernel_ptr, shared_mem, grid_0, grid_1, grid_2, num_warps, arity)
-  return descriptor
+  arity = len(ctx.avals_in) + len(ctx.avals_out)
+  descriptor, keepalive = triton_kernel_call.make_triton_call_descriptor(
+      name, asm, shared_mem, grid_0, grid_1, grid_2, num_warps, arity)
+  return descriptor, keepalive
 
 def triton_call_lowering(ctx, *args, kernel, out_shapes, grid, num_warps=4, num_stages=2,
                          dump_binary_path: Optional[str], **metaparams):
   out_type = ir.TupleType.get_tuple([
       ir.RankedTensorType.get(out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype))
       for out_shape in out_shapes])
   i32_type = ir.IntegerType.get_signless(32)
-  descriptor = emit_triton_call(kernel, ctx.avals_in, ctx.avals_out, grid,
-                                num_warps, num_stages, dump_binary_path,
-                                **metaparams)
+  descriptor, keepalive = emit_triton_call(ctx, kernel, grid,
+                                           num_warps, num_stages, dump_binary_path,
+                                           **metaparams)
+  ctx.module_context.add_keepalive(keepalive)
   out = mhlo.CustomCallOp(
       [out_type], args,
-      call_target_name=ir.StringAttr.get("triton_call"),
+      call_target_name=ir.StringAttr.get("triton_kernel_call"),
       has_side_effect=ir.BoolAttr.get(False),
       backend_config=ir.StringAttr.get(descriptor),
       api_version=ir.IntegerAttr.get(i32_type, 1),
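
Note how the two halves of the change fit together: the lowering compiles the kernel eagerly, encodes a pointer to the resulting executable as a decimal string in backend_config, and the C++ custom-call target (added below) parses it back with std::strtoull; the keepalive capsule handed to ctx.module_context.add_keepalive owns the object so the pointer stays valid for the life of the compiled module. A minimal, self-contained sketch of that round trip, with a placeholder Executable struct standing in for the real TritonExecutable:

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>

// Stand-in for TritonExecutable; only the descriptor round trip is shown.
struct Executable { int payload; };

int main() {
  auto* exe = new Executable{42};
  // Producer side: backend_config carries only opaque bytes, so the heap
  // pointer is serialized as its decimal string representation.
  std::string opaque = std::to_string(reinterpret_cast<std::uint64_t>(exe));
  // Consumer side (the custom-call entry point): parse the string back into
  // a pointer and dispatch on the recovered object.
  auto* recovered = reinterpret_cast<Executable*>(
      std::strtoull(opaque.c_str(), nullptr, 0));
  std::cout << recovered->payload << "\n";  // prints 42
  delete exe;  // in the real code, the py::capsule keepalive does this
  return 0;
}

Without the keepalive, nothing in the serialized module would own the executable, and the decimal string in backend_config would be a dangling pointer by launch time.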

lib/custom_call.cc (-100 lines)

This file was deleted.

lib/triton_kernel_call.cc (+135 lines, new file)

/* Copyright 2022 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "triton_kernel_call.h"

#include <cassert>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <utility>

#include <pybind11/pybind11.h>
#include "cuda.h"

namespace py = pybind11;

namespace jax_triton {

const int TRITON_MAX_N_SHARED_BYTES = 49152;
const int TRITON_MAX_SHARED_OPTIN = 49152;

void TritonExecutable::launch(CUstream stream, void** buffers) {
  CUdevice dev;
  CUcontext ctx;
  // Set the current context to the stream context so we can query the stream
  // device.
  cuStreamGetCtx(stream, &ctx);
  cuCtxSetCurrent(ctx);
  // Only load the kernel if it hasn't already been loaded for this device.
  cuCtxGetDevice(&dev);
  CUfunction kernel = load(dev);
  // Pack the buffer pointers into an 8-byte-aligned parameter blob that
  // cuLaunchKernel consumes via CU_LAUNCH_PARAM_BUFFER_POINTER.
  std::string params;
  params.resize(8 * arity);
  char* params_ptr = &params[0];
  for (uint32_t i = 0; i < arity; i++) {
    params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
    std::memcpy(params_ptr, &buffers[i], 8);
    params_ptr += 8;
  }
  size_t params_size = static_cast<size_t>(params_ptr - &params[0]);
  void* config[] = {
      CU_LAUNCH_PARAM_BUFFER_POINTER,
      static_cast<void*>(const_cast<char*>(params.data())),
      CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
      CU_LAUNCH_PARAM_END
  };
  CUresult result = cuLaunchKernel(kernel, grid_0, grid_1, grid_2,
                                   num_warps * 32, 1, 1, shared_mem, stream,
                                   nullptr, config);
  if (result != 0) {
    std::cout << "Failed launch: " << result << std::endl;
  }
}

CUfunction TritonExecutable::load(CUdevice device) {
  const std::lock_guard<std::mutex> lock(mut);
  if (is_loaded(device)) {
    return kernels[device];
  }
  // Mimics Triton kernel loading: prefer the cubin, fall back to PTX.
  std::string assembly;
  auto iter = asm_map.find("cubin");
  if (iter != asm_map.end())
    assembly = py::cast<std::string>(asm_map["cubin"]);
  else {
    assert(asm_map.contains("ptx"));
    assembly = py::cast<std::string>(asm_map["ptx"]);
  }
  CUfunction fun;
  CUmodule mod;
  cuModuleLoadData(&mod, assembly.c_str());
  cuModuleGetFunction(&fun, mod, name.c_str());
  int n_regs = 0;
  int n_spills = 0;
  cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
  cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
  n_spills /= 4;
  // Opt in to extra dynamic shared memory when the kernel needs more than
  // the default 48 KiB and the device supports it.
  int shared_optin;
  cuDeviceGetAttribute(&shared_optin,
                       CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
                       device);
  if (shared_mem > TRITON_MAX_N_SHARED_BYTES &&
      shared_optin > TRITON_MAX_SHARED_OPTIN) {
    cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
    int shared_total, shared_static;
    cuDeviceGetAttribute(
        &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
        device);
    cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES,
                       fun);
    cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                       shared_optin - shared_static);
  }
  kernels[device] = fun;
  return fun;
}

void do_custom_call(CUstream stream, void** buffers,
                    char* opaque, size_t opaque_len) {
  uint64_t descriptor = std::strtoull(opaque, NULL, 0);
  TritonExecutable* executable = TritonExecutable::from_descriptor(descriptor);
  executable->launch(stream, buffers);
}

std::pair<std::string, py::object> MakeTritonExecutable(
    std::string name, asm_map_t asm_map, uint32_t shared_mem, uint32_t grid_0,
    uint32_t grid_1, uint32_t grid_2, uint32_t num_warps, uint32_t arity) {
  auto triton_call = std::make_unique<TritonExecutable>(
      name, asm_map, shared_mem, grid_0, grid_1, grid_2, num_warps, arity);
  std::string descriptor =
      std::to_string(reinterpret_cast<uint64_t>(triton_call.get()));
  // The capsule owns the executable; it is the keepalive returned to Python.
  py::capsule callback_capsule(triton_call.release(), [](void* ptr) {
    delete reinterpret_cast<TritonExecutable*>(ptr);
  });
  return std::make_pair(descriptor, py::object(std::move(callback_capsule)));
}

template <typename T>
pybind11::capsule EncapsulateFunction(T* fn) {
  return pybind11::capsule(reinterpret_cast<void*>(fn),
                           "xla._CUSTOM_CALL_TARGET");
}

PYBIND11_MODULE(triton_kernel_call, m) {
  m.def("make_triton_call_descriptor", &MakeTritonExecutable);
  m.def("get_custom_call", []() {
    return EncapsulateFunction(do_custom_call);
  });
}

}  // namespace jax_triton
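
The lazy loading that gives the commit its title lives in TritonExecutable::load: launch resolves the device from the stream's context at call time, and load builds the CUDA module at most once per device under a mutex, with later launches hitting the kernels cache. Stripped of the CUDA specifics, this is ordinary guarded memoization; a standalone sketch with placeholder Device/Kernel types (illustrative names, not the project's API):

#include <map>
#include <mutex>

// Placeholder handles standing in for CUdevice / CUfunction.
using Device = int;
using Kernel = const void*;

class PerDeviceCache {
 public:
  // First call for a device pays the load cost; later calls return the
  // cached handle. The mutex makes concurrent launches safe.
  Kernel Load(Device device) {
    std::lock_guard<std::mutex> lock(mut_);
    auto it = kernels_.find(device);
    if (it != kernels_.end()) return it->second;
    Kernel kernel = LoadOnDevice(device);  // stands in for cuModuleLoadData
                                           // + cuModuleGetFunction
    kernels_[device] = kernel;
    return kernel;
  }

 private:
  Kernel LoadOnDevice(Device device) {
    // Expensive, device-specific work would happen here.
    static const int dummy = 0;
    return &dummy;
  }
  std::mutex mut_;
  std::map<Device, Kernel> kernels_;
};

int main() {
  PerDeviceCache cache;
  Kernel a = cache.Load(/*device=*/0);  // loads
  Kernel b = cache.Load(/*device=*/0);  // cache hit: a == b
  (void)a; (void)b;
  return 0;
}

Because nothing is loaded until the first launch on a given device, one descriptor can serve launches on several GPUs, which is presumably what the source branch name multidevice refers to.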
