from jax.interpreters import mlir
from jax import tree_util
from jax._src import util
+from jax._src.lib import xla_bridge as xb
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
import numpy as np
import torch
import triton
import triton.language as tl

-from jax_triton import custom_call
+from jax_triton import triton_kernel_call

os.environ["TRITON_CACHE_DIR"] = ""
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

-xc.register_custom_call_target("triton_call", custom_call.get_custom_call(), platform="CUDA")
+xc.register_custom_call_target("triton_kernel_call", triton_kernel_call.get_custom_call(), platform="CUDA")


def get_triton_type(obj: Any) -> str:
  type_map = {
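Note: the `type_map` table is cut off at this point in the diff. For reference, a minimal sketch of the kind of dtype-to-Triton-type mapping `get_triton_type` performs; the entries, the pointer spelling, and the helper name below are assumptions for illustration, not the file's actual table:

import numpy as np

# Hypothetical sketch: map a JAX aval's dtype to a Triton type string.
# Array arguments reach the kernel as pointers, hence the leading "*".
_SKETCH_TYPE_MAP = {
    np.dtype("float32"): "fp32",
    np.dtype("float16"): "fp16",
    np.dtype("int32"): "i32",
    np.dtype("int64"): "i64",
}

def get_triton_type_sketch(aval) -> str:
  # aval is abstract (e.g. a ShapedArray); only its dtype matters here.
  return f"*{_SKETCH_TYPE_MAP[np.dtype(aval.dtype)]}"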
@@ -79,8 +80,6 @@ def get_triton_python_ir(aval):
def compile(triton_function, constants, *, key, device=0, num_warps=4, num_stages=2):
  def lower(*args):
-    arg_types = [get_triton_python_ir(a) for a in args]
-    attributes = {i: 16 for i in range(len(args))}
    triton_function._warmup(arg_types=arg_types, device=device,
                            attributes=attributes, constants=constants, num_warps=num_warps,
                            num_stages=num_stages, key=key, is_manual_warmup=True,
@@ -133,18 +132,24 @@ def aval_to_layout(aval):
  arange = np.arange(aval.ndim, dtype='int64')[::-1].copy()
  return ir.DenseIntElementsAttr.get(arange, type=ir.IndexType.get())

-def emit_triton_call(triton_func, avals_in, avals_out, grid, num_warps, num_stages,
+def emit_triton_call(ctx, triton_func, grid, num_warps, num_stages,
                     dump_binary_path: Optional[str], **metaparams):
  metadata = {triton_func.arg_names.index(k): v for k, v in metaparams.items()}
-  compile(triton_func, metadata, num_warps=num_warps, num_stages=num_stages, key="foo")(*avals_in, *avals_out)
-  loaded_binary = triton_func.bin_cache["foo"]
-  kernel_ptr = loaded_binary.kernel
-  shared_mem = loaded_binary.shared_mem
+  all_args = [*ctx.avals_in, *ctx.avals_out]
+  arg_types = [get_triton_python_ir(a) for a in all_args]
+  attributes = {i: 16 for i in range(len(all_args))}
+  # TODO(sharadmv): handle multiple devices, right now we assume device 0 which
+  # is fine when we have multiple of the same GPU but this won't work in
+  # general.
+  binary = triton_func._compile(arg_types=arg_types, device=0,
+      attributes=attributes, constants=metadata, num_warps=num_warps,
+      num_stages=num_stages, extern_libs={})
+  name, asm, shared_mem = binary.name, binary.asm, binary.shared_mem
  if dump_binary_path is not None:
    binary = dict(
-        asm=loaded_binary.asm,
+        asm=asm,
        shared_mem=shared_mem,
-        name=loaded_binary.bin.name)
+        name=name)
    with open(dump_binary_path, "wb") as fp:
      pickle.dump(binary, fp)
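Note: the next hunk shows only the tail of the grid handling inside `emit_triton_call` (by that point `grid_` has already been resolved). A hypothetical reconstruction of that normalization, assuming a grid may be given as an int, a tuple of up to three ints, or a callable of the metaparams; the helper name is invented for illustration:

def normalize_grid_sketch(grid, metaparams):
  # A grid can be a callable of the metaparameters, a bare int, or a
  # tuple/list of one to three ints; missing dimensions default to 1.
  grid_ = grid(metaparams) if callable(grid) else grid
  if isinstance(grid_, int):
    grid_ = (grid_,)
  grid_0 = grid_[0]
  if len(grid_) == 1:
    grid_1, grid_2 = 1, 1
  elif len(grid_) == 2:
    grid_1, grid_2 = grid_[1], 1
  elif len(grid_) == 3:
    grid_1, grid_2 = grid_[1], grid_[2]
  else:
    assert False, "grid must have between one and three dimensions"
  return grid_0, grid_1, grid_2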
@@ -158,22 +163,24 @@ def emit_triton_call(triton_func, avals_in, avals_out, grid, num_warps, num_stag
    grid_1, grid_2 = grid_[1], grid_[2]
  else:
    assert False
-  arity = len(avals_in) + len(avals_out)
-  descriptor = custom_call.make_triton_call_descriptor(kernel_ptr, shared_mem, grid_0, grid_1, grid_2, num_warps, arity)
-  return descriptor
+  arity = len(ctx.avals_in) + len(ctx.avals_out)
+  descriptor, keepalive = triton_kernel_call.make_triton_call_descriptor(
+      name, asm, shared_mem, grid_0, grid_1, grid_2, num_warps, arity)
+  return descriptor, keepalive


def triton_call_lowering(ctx, *args, kernel, out_shapes, grid, num_warps=4, num_stages=2,
                         dump_binary_path: Optional[str], **metaparams):
  out_type = ir.TupleType.get_tuple([
      ir.RankedTensorType.get(out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype))
      for out_shape in out_shapes])
  i32_type = ir.IntegerType.get_signless(32)
-  descriptor = emit_triton_call(kernel, ctx.avals_in, ctx.avals_out, grid,
-                                num_warps, num_stages, dump_binary_path,
-                                **metaparams)
+  descriptor, keepalive = emit_triton_call(ctx, kernel, grid,
+                                           num_warps, num_stages, dump_binary_path,
+                                           **metaparams)
+  ctx.module_context.add_keepalive(keepalive)
  out = mhlo.CustomCallOp(
      [out_type], args,
-      call_target_name=ir.StringAttr.get("triton_call"),
+      call_target_name=ir.StringAttr.get("triton_kernel_call"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(descriptor),
      api_version=ir.IntegerAttr.get(i32_type, 1),
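For context, an end-to-end sketch of the path this lowering serves. It assumes the public `triton_call` wrapper at this revision accepts `out_shapes` (matching `triton_call_lowering` above) and forwards extra keyword arguments such as `block_size` as metaparams; the kernel and the block/grid numbers are illustrative, adapted from the usual jax-triton add example:

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, block_size: tl.constexpr):
  # Each program instance handles one contiguous block of elements.
  pid = tl.program_id(axis=0)
  offsets = pid * block_size + tl.arange(0, block_size)
  x = tl.load(x_ptr + offsets)
  y = tl.load(y_ptr + offsets)
  tl.store(out_ptr + offsets, x + y)

def add(x, y):
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  # grid gives the number of program instances; block_size travels through
  # **metaparams and becomes the kernel's tl.constexpr argument.
  return jt.triton_call(x, y, kernel=add_kernel, out_shapes=[out_shape],
                        grid=(x.size // 8,), block_size=8)

x = jnp.arange(8, dtype=jnp.float32)
print(jax.jit(add)(x, x))

Under jit, the call above is what reaches `triton_call_lowering`: the kernel is compiled once at lowering time, the resulting descriptor is serialized into the CustomCallOp's backend_config, and the keepalive object pins the compiled kernel for the lifetime of the XLA executable.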