@@ -2,17 +2,76 @@ use std::ffi::CString;
22
33use llvm:: Linkage :: * ;
44use rustc_abi:: Align ;
5- use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods ;
5+ use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
66use rustc_middle:: ty:: offload_meta:: OffloadMetadata ;
77
8- use crate :: builder:: SBuilder ;
8+ use crate :: builder:: Builder ;
9+ use crate :: common:: CodegenCx ;
910use crate :: llvm:: AttributePlace :: Function ;
10- use crate :: llvm:: { self , BasicBlock , Linkage , Type , Value } ;
11+ use crate :: llvm:: { self , Linkage , Type , Value } ;
1112use crate :: { SimpleCx , attributes} ;
1213
14+ // LLVM kernel-independent globals required for offloading
15+ pub ( crate ) struct OffloadGlobals < ' ll > {
16+ pub launcher_fn : & ' ll llvm:: Value ,
17+ pub launcher_ty : & ' ll llvm:: Type ,
18+
19+ pub bin_desc : & ' ll llvm:: Type ,
20+
21+ pub kernel_args_ty : & ' ll llvm:: Type ,
22+
23+ pub offload_entry_ty : & ' ll llvm:: Type ,
24+
25+ pub begin_mapper : & ' ll llvm:: Value ,
26+ pub end_mapper : & ' ll llvm:: Value ,
27+ pub mapper_fn_ty : & ' ll llvm:: Type ,
28+
29+ pub ident_t_global : & ' ll llvm:: Value ,
30+
31+ pub register_lib : & ' ll llvm:: Value ,
32+ pub unregister_lib : & ' ll llvm:: Value ,
33+ pub init_rtls : & ' ll llvm:: Value ,
34+ }
35+
36+ impl < ' ll > OffloadGlobals < ' ll > {
37+ pub ( crate ) fn declare ( cx : & CodegenCx < ' ll , ' _ > ) -> Self {
38+ let ( launcher_fn, launcher_ty) = generate_launcher ( cx) ;
39+ let kernel_args_ty = KernelArgsTy :: new_decl ( cx) ;
40+ let offload_entry_ty = TgtOffloadEntry :: new_decl ( cx) ;
41+ let ( begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers ( cx) ;
42+ let ident_t_global = generate_at_one ( cx) ;
43+
44+ let tptr = cx. type_ptr ( ) ;
45+ let ti32 = cx. type_i32 ( ) ;
46+ let tgt_bin_desc_ty = vec ! [ ti32, tptr, tptr, tptr] ;
47+ let bin_desc = cx. type_named_struct ( "struct.__tgt_bin_desc" ) ;
48+ cx. set_struct_body ( bin_desc, & tgt_bin_desc_ty, false ) ;
49+
50+ let register_lib = declare_offload_fn ( & cx, "__tgt_register_lib" , mapper_fn_ty) ;
51+ let unregister_lib = declare_offload_fn ( & cx, "__tgt_unregister_lib" , mapper_fn_ty) ;
52+ let init_ty = cx. type_func ( & [ ] , cx. type_void ( ) ) ;
53+ let init_rtls = declare_offload_fn ( cx, "__tgt_init_all_rtls" , init_ty) ;
54+
55+ OffloadGlobals {
56+ launcher_fn,
57+ launcher_ty,
58+ bin_desc,
59+ kernel_args_ty,
60+ offload_entry_ty,
61+ begin_mapper,
62+ end_mapper,
63+ mapper_fn_ty,
64+ ident_t_global,
65+ register_lib,
66+ unregister_lib,
67+ init_rtls,
68+ }
69+ }
70+ }
71+
1372// ; Function Attrs: nounwind
1473// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
15- fn generate_launcher < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> ( & ' ll llvm:: Value , & ' ll llvm:: Type ) {
74+ fn generate_launcher < ' ll > ( cx : & CodegenCx < ' ll , ' _ > ) -> ( & ' ll llvm:: Value , & ' ll llvm:: Type ) {
1675 let tptr = cx. type_ptr ( ) ;
1776 let ti64 = cx. type_i64 ( ) ;
1877 let ti32 = cx. type_i32 ( ) ;
@@ -30,7 +89,7 @@ fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm
3089// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
3190// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
3291// offloaded was defined.
33- fn generate_at_one < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll llvm:: Value {
92+ pub ( crate ) fn generate_at_one < ' ll > ( cx : & CodegenCx < ' ll , ' _ > ) -> & ' ll llvm:: Value {
3493 let unknown_txt = ";unknown;unknown;0;0;;" ;
3594 let c_entry_name = CString :: new ( unknown_txt) . unwrap ( ) ;
3695 let c_val = c_entry_name. as_bytes_with_nul ( ) ;
@@ -68,7 +127,7 @@ pub(crate) struct TgtOffloadEntry {
68127}
69128
70129impl TgtOffloadEntry {
71- pub ( crate ) fn new_decl < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll llvm:: Type {
130+ pub ( crate ) fn new_decl < ' ll > ( cx : & CodegenCx < ' ll , ' _ > ) -> & ' ll llvm:: Type {
72131 let offload_entry_ty = cx. type_named_struct ( "struct.__tgt_offload_entry" ) ;
73132 let tptr = cx. type_ptr ( ) ;
74133 let ti64 = cx. type_i64 ( ) ;
@@ -82,7 +141,7 @@ impl TgtOffloadEntry {
82141 }
83142
84143 fn new < ' ll > (
85- cx : & ' ll SimpleCx < ' _ > ,
144+ cx : & CodegenCx < ' ll , ' _ > ,
86145 region_id : & ' ll Value ,
87146 llglobal : & ' ll Value ,
88147 ) -> [ & ' ll Value ; 9 ] {
@@ -126,7 +185,7 @@ impl KernelArgsTy {
126185 const OFFLOAD_VERSION : u64 = 3 ;
127186 const FLAGS : u64 = 0 ;
128187 const TRIPCOUNT : u64 = 0 ;
129- fn new_decl < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll Type {
188+ fn new_decl < ' ll > ( cx : & CodegenCx < ' ll , ' _ > ) -> & ' ll Type {
130189 let kernel_arguments_ty = cx. type_named_struct ( "struct.__tgt_kernel_arguments" ) ;
131190 let tptr = cx. type_ptr ( ) ;
132191 let ti64 = cx. type_i64 ( ) ;
@@ -140,8 +199,8 @@ impl KernelArgsTy {
140199 kernel_arguments_ty
141200 }
142201
143- fn new < ' ll > (
144- cx : & ' ll SimpleCx < ' _ > ,
202+ fn new < ' ll , ' tcx > (
203+ cx : & CodegenCx < ' ll , ' tcx > ,
145204 num_args : u64 ,
146205 memtransfer_types : & ' ll Value ,
147206 geps : [ & ' ll Value ; 3 ] ,
@@ -171,15 +230,16 @@ impl KernelArgsTy {
171230}
172231
173232// Contains LLVM values needed to manage offloading for a single kernel.
174- pub ( crate ) struct OffloadKernelData < ' ll > {
233+ #[ derive( Copy , Clone ) ]
234+ pub ( crate ) struct OffloadKernelGlobals < ' ll > {
175235 pub offload_sizes : & ' ll llvm:: Value ,
176236 pub memtransfer_types : & ' ll llvm:: Value ,
177237 pub region_id : & ' ll llvm:: Value ,
178238 pub offload_entry : & ' ll llvm:: Value ,
179239}
180240
181241fn gen_tgt_data_mappers < ' ll > (
182- cx : & ' ll SimpleCx < ' _ > ,
242+ cx : & CodegenCx < ' ll , ' _ > ,
183243) -> ( & ' ll llvm:: Value , & ' ll llvm:: Value , & ' ll llvm:: Value , & ' ll llvm:: Type ) {
184244 let tptr = cx. type_ptr ( ) ;
185245 let ti64 = cx. type_i64 ( ) ;
@@ -241,12 +301,18 @@ pub(crate) fn add_global<'ll>(
241301// mapped to/from the gpu. It also returns a region_id with the name of this kernel, to be
242302// concatenated into the list of region_ids.
243303pub ( crate ) fn gen_define_handling < ' ll > (
244- cx : & SimpleCx < ' ll > ,
245- offload_entry_ty : & ' ll llvm:: Type ,
304+ cx : & CodegenCx < ' ll , ' _ > ,
246305 metadata : & [ OffloadMetadata ] ,
247- types : & [ & Type ] ,
248- symbol : & str ,
249- ) -> OffloadKernelData < ' ll > {
306+ types : & [ & ' ll Type ] ,
307+ symbol : String ,
308+ offload_globals : & OffloadGlobals < ' ll > ,
309+ ) -> OffloadKernelGlobals < ' ll > {
310+ if let Some ( entry) = cx. offload_kernel_cache . borrow ( ) . get ( & symbol) {
311+ return * entry;
312+ }
313+
314+ let offload_entry_ty = offload_globals. offload_entry_ty ;
315+
250316 // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
251317 // reference) types.
252318 let ptr_meta = types. iter ( ) . zip ( metadata) . filter_map ( |( & x, meta) | match cx. type_kind ( x) {
@@ -274,7 +340,7 @@ pub(crate) fn gen_define_handling<'ll>(
274340 let initializer = cx. get_const_i8 ( 0 ) ;
275341 let region_id = add_unnamed_global ( & cx, & name, initializer, WeakAnyLinkage ) ;
276342
277- let c_entry_name = CString :: new ( symbol) . unwrap ( ) ;
343+ let c_entry_name = CString :: new ( symbol. clone ( ) ) . unwrap ( ) ;
278344 let c_val = c_entry_name. as_bytes_with_nul ( ) ;
279345 let offload_entry_name = format ! ( ".offloading.entry_name.{symbol}" ) ;
280346
@@ -298,11 +364,16 @@ pub(crate) fn gen_define_handling<'ll>(
298364 let c_section_name = CString :: new ( "llvm_offload_entries" ) . unwrap ( ) ;
299365 llvm:: set_section ( offload_entry, & c_section_name) ;
300366
301- OffloadKernelData { offload_sizes, memtransfer_types, region_id, offload_entry }
367+ let result =
368+ OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id, offload_entry } ;
369+
370+ cx. offload_kernel_cache . borrow_mut ( ) . insert ( symbol, result) ;
371+
372+ result
302373}
303374
304375fn declare_offload_fn < ' ll > (
305- cx : & ' ll SimpleCx < ' _ > ,
376+ cx : & CodegenCx < ' ll , ' _ > ,
306377 name : & str ,
307378 ty : & ' ll llvm:: Type ,
308379) -> & ' ll llvm:: Value {
@@ -335,28 +406,28 @@ fn declare_offload_fn<'ll>(
335406// 4. set insert point after kernel call.
336407// 5. generate all the GEPS and stores, to be used in 6)
337408// 6. generate __tgt_target_data_end calls to move data from the GPU
338- pub ( crate ) fn gen_call_handling < ' ll > (
339- cx : & SimpleCx < ' ll > ,
340- bb : & BasicBlock ,
341- offload_data : & OffloadKernelData < ' ll > ,
409+ pub ( crate ) fn gen_call_handling < ' ll , ' tcx > (
410+ builder : & mut Builder < ' _ , ' ll , ' tcx > ,
411+ offload_data : & OffloadKernelGlobals < ' ll > ,
342412 args : & [ & ' ll Value ] ,
343413 types : & [ & Type ] ,
344414 metadata : & [ OffloadMetadata ] ,
415+ offload_globals : & OffloadGlobals < ' ll > ,
345416) {
346- let OffloadKernelData { offload_sizes, offload_entry, memtransfer_types, region_id } =
417+ let cx = builder. cx ;
418+ let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
347419 offload_data;
348- let ( tgt_decl, tgt_target_kernel_ty) = generate_launcher ( & cx) ;
349- // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
350- let tptr = cx. type_ptr ( ) ;
351- let ti32 = cx. type_i32 ( ) ;
352- let tgt_bin_desc_ty = vec ! [ ti32, tptr, tptr, tptr] ;
353- let tgt_bin_desc = cx. type_named_struct ( "struct.__tgt_bin_desc" ) ;
354- cx. set_struct_body ( tgt_bin_desc, & tgt_bin_desc_ty, false ) ;
355420
356- let tgt_kernel_decl = KernelArgsTy :: new_decl ( & cx ) ;
357- let ( begin_mapper_decl , _ , end_mapper_decl , fn_ty ) = gen_tgt_data_mappers ( & cx ) ;
421+ let tgt_decl = offload_globals . launcher_fn ;
422+ let tgt_target_kernel_ty = offload_globals . launcher_ty ;
358423
359- let mut builder = SBuilder :: build ( cx, bb) ;
424+ // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
425+ let tgt_bin_desc = offload_globals. bin_desc ;
426+
427+ let tgt_kernel_decl = offload_globals. kernel_args_ty ;
428+ let begin_mapper_decl = offload_globals. begin_mapper ;
429+ let end_mapper_decl = offload_globals. end_mapper ;
430+ let fn_ty = offload_globals. mapper_fn_ty ;
360431
361432 let num_args = types. len ( ) as u64 ;
362433 let ip = unsafe { llvm:: LLVMRustGetInsertPoint ( & builder. llbuilder ) } ;
@@ -378,9 +449,8 @@ pub(crate) fn gen_call_handling<'ll>(
378449 // Step 0)
379450 // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
380451 // %6 = alloca %struct.__tgt_bin_desc, align 8
381- let llfn = unsafe { llvm:: LLVMGetBasicBlockParent ( bb) } ;
382452 unsafe {
383- llvm:: LLVMRustPositionBuilderPastAllocas ( & builder. llbuilder , llfn) ;
453+ llvm:: LLVMRustPositionBuilderPastAllocas ( & builder. llbuilder , builder . llfn ( ) ) ;
384454 }
385455 let tgt_bin_desc_alloca = builder. direct_alloca ( tgt_bin_desc, Align :: EIGHT , "EmptyDesc" ) ;
386456
@@ -413,16 +483,16 @@ pub(crate) fn gen_call_handling<'ll>(
413483 }
414484
415485 let mapper_fn_ty = cx. type_func ( & [ cx. type_ptr ( ) ] , cx. type_void ( ) ) ;
416- let register_lib_decl = declare_offload_fn ( & cx , "__tgt_register_lib" , mapper_fn_ty ) ;
417- let unregister_lib_decl = declare_offload_fn ( & cx , "__tgt_unregister_lib" , mapper_fn_ty ) ;
486+ let register_lib_decl = offload_globals . register_lib ;
487+ let unregister_lib_decl = offload_globals . unregister_lib ;
418488 let init_ty = cx. type_func ( & [ ] , cx. type_void ( ) ) ;
419- let init_rtls_decl = declare_offload_fn ( cx , "__tgt_init_all_rtls" , init_ty ) ;
489+ let init_rtls_decl = offload_globals . init_rtls ;
420490
421491 // FIXME(offload): Later we want to add them to the wrapper code, rather than our main function.
422492 // call void @__tgt_register_lib(ptr noundef %6)
423- builder. call ( mapper_fn_ty, register_lib_decl, & [ tgt_bin_desc_alloca] , None ) ;
493+ builder. call ( mapper_fn_ty, None , None , register_lib_decl, & [ tgt_bin_desc_alloca] , None , None ) ;
424494 // call void @__tgt_init_all_rtls()
425- builder. call ( init_ty, init_rtls_decl, & [ ] , None ) ;
495+ builder. call ( init_ty, None , None , init_rtls_decl, & [ ] , None , None ) ;
426496
427497 for i in 0 ..num_args {
428498 let idx = cx. get_const_i32 ( i) ;
@@ -437,15 +507,15 @@ pub(crate) fn gen_call_handling<'ll>(
437507
438508 // For now we have a very simplistic indexing scheme into our
439509 // offload_{baseptrs,ptrs,sizes}. We will probably improve this along with our gpu frontend pr.
440- fn get_geps < ' a , ' ll > (
441- builder : & mut SBuilder < ' a , ' ll > ,
442- cx : & ' ll SimpleCx < ' ll > ,
510+ fn get_geps < ' ll , ' tcx > (
511+ builder : & mut Builder < ' _ , ' ll , ' tcx > ,
443512 ty : & ' ll Type ,
444513 ty2 : & ' ll Type ,
445514 a1 : & ' ll Value ,
446515 a2 : & ' ll Value ,
447516 a4 : & ' ll Value ,
448517 ) -> [ & ' ll Value ; 3 ] {
518+ let cx = builder. cx ;
449519 let i32_0 = cx. get_const_i32 ( 0 ) ;
450520
451521 let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
@@ -454,30 +524,29 @@ pub(crate) fn gen_call_handling<'ll>(
454524 [ gep1, gep2, gep3]
455525 }
456526
457- fn generate_mapper_call < ' a , ' ll > (
458- builder : & mut SBuilder < ' a , ' ll > ,
459- cx : & ' ll SimpleCx < ' ll > ,
527+ fn generate_mapper_call < ' ll , ' tcx > (
528+ builder : & mut Builder < ' _ , ' ll , ' tcx > ,
460529 geps : [ & ' ll Value ; 3 ] ,
461530 o_type : & ' ll Value ,
462531 fn_to_call : & ' ll Value ,
463532 fn_ty : & ' ll Type ,
464533 num_args : u64 ,
465534 s_ident_t : & ' ll Value ,
466535 ) {
536+ let cx = builder. cx ;
467537 let nullptr = cx. const_null ( cx. type_ptr ( ) ) ;
468538 let i64_max = cx. get_const_i64 ( u64:: MAX ) ;
469539 let num_args = cx. get_const_i32 ( num_args) ;
470540 let args =
471541 vec ! [ s_ident_t, i64_max, num_args, geps[ 0 ] , geps[ 1 ] , geps[ 2 ] , o_type, nullptr, nullptr] ;
472- builder. call ( fn_ty, fn_to_call, & args, None ) ;
542+ builder. call ( fn_ty, None , None , fn_to_call, & args, None , None ) ;
473543 }
474544
475545 // Step 2)
476- let s_ident_t = generate_at_one ( & cx ) ;
477- let geps = get_geps ( & mut builder, & cx , ty, ty2, a1, a2, a4) ;
546+ let s_ident_t = offload_globals . ident_t_global ;
547+ let geps = get_geps ( builder, ty, ty2, a1, a2, a4) ;
478548 generate_mapper_call (
479- & mut builder,
480- & cx,
549+ builder,
481550 geps,
482551 memtransfer_types,
483552 begin_mapper_decl,
@@ -504,14 +573,13 @@ pub(crate) fn gen_call_handling<'ll>(
504573 region_id,
505574 a5,
506575 ] ;
507- builder. call ( tgt_target_kernel_ty, tgt_decl, & args, None ) ;
576+ builder. call ( tgt_target_kernel_ty, None , None , tgt_decl, & args, None , None ) ;
508577 // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
509578
510579 // Step 4)
511- let geps = get_geps ( & mut builder, & cx , ty, ty2, a1, a2, a4) ;
580+ let geps = get_geps ( builder, ty, ty2, a1, a2, a4) ;
512581 generate_mapper_call (
513- & mut builder,
514- & cx,
582+ builder,
515583 geps,
516584 memtransfer_types,
517585 end_mapper_decl,
@@ -520,7 +588,5 @@ pub(crate) fn gen_call_handling<'ll>(
520588 s_ident_t,
521589 ) ;
522590
523- builder. call ( mapper_fn_ty, unregister_lib_decl, & [ tgt_bin_desc_alloca] , None ) ;
524-
525- drop ( builder) ;
591+ builder. call ( mapper_fn_ty, None , None , unregister_lib_decl, & [ tgt_bin_desc_alloca] , None , None ) ;
526592}
0 commit comments