Skip to content

Commit c9cf689

Browse files
committed
Split runtime global logic and cache kernel specific one
1 parent dfe1b8c commit c9cf689

File tree

6 files changed

+159
-84
lines changed

6 files changed

+159
-84
lines changed

compiler/rustc_codegen_llvm/src/base.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@ use rustc_middle::dep_graph;
2323
use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrs, SanitizerFnAttrs};
2424
use rustc_middle::mir::mono::Visibility;
2525
use rustc_middle::ty::TyCtxt;
26-
use rustc_session::config::DebugInfo;
26+
use rustc_session::config::{DebugInfo, Offload};
2727
use rustc_span::Symbol;
2828
use rustc_target::spec::SanitizerSet;
2929

3030
use super::ModuleLlvm;
3131
use crate::attributes;
3232
use crate::builder::Builder;
33+
use crate::builder::gpu_offload::OffloadGlobals;
3334
use crate::context::CodegenCx;
3435
use crate::llvm::{self, Value};
3536

@@ -85,6 +86,13 @@ pub(crate) fn compile_codegen_unit(
8586
let llvm_module = ModuleLlvm::new(tcx, cgu_name.as_str());
8687
{
8788
let mut cx = CodegenCx::new(tcx, cgu, &llvm_module);
89+
90+
if cx.sess().opts.unstable_opts.offload.contains(&Offload::Enable)
91+
&& !cx.sess().target.is_like_gpu
92+
{
93+
cx.offload_globals.replace(Some(OffloadGlobals::declare(&cx)));
94+
}
95+
8896
let mono_items = cx.codegen_unit.items_in_deterministic_order(cx.tcx);
8997
for &(mono_item, data) in &mono_items {
9098
mono_item.predefine::<Builder<'_, '_, '_>>(

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 126 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,76 @@ use std::ffi::CString;
22

33
use llvm::Linkage::*;
44
use rustc_abi::Align;
5-
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
5+
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
66
use rustc_middle::ty::offload_meta::OffloadMetadata;
77

8-
use crate::builder::SBuilder;
8+
use crate::builder::Builder;
9+
use crate::common::CodegenCx;
910
use crate::llvm::AttributePlace::Function;
10-
use crate::llvm::{self, BasicBlock, Linkage, Type, Value};
11+
use crate::llvm::{self, Linkage, Type, Value};
1112
use crate::{SimpleCx, attributes};
1213

14+
// LLVM kernel-independent globals required for offloading
15+
pub(crate) struct OffloadGlobals<'ll> {
16+
pub launcher_fn: &'ll llvm::Value,
17+
pub launcher_ty: &'ll llvm::Type,
18+
19+
pub bin_desc: &'ll llvm::Type,
20+
21+
pub kernel_args_ty: &'ll llvm::Type,
22+
23+
pub offload_entry_ty: &'ll llvm::Type,
24+
25+
pub begin_mapper: &'ll llvm::Value,
26+
pub end_mapper: &'ll llvm::Value,
27+
pub mapper_fn_ty: &'ll llvm::Type,
28+
29+
pub ident_t_global: &'ll llvm::Value,
30+
31+
pub register_lib: &'ll llvm::Value,
32+
pub unregister_lib: &'ll llvm::Value,
33+
pub init_rtls: &'ll llvm::Value,
34+
}
35+
36+
impl<'ll> OffloadGlobals<'ll> {
37+
pub(crate) fn declare(cx: &CodegenCx<'ll, '_>) -> Self {
38+
let (launcher_fn, launcher_ty) = generate_launcher(cx);
39+
let kernel_args_ty = KernelArgsTy::new_decl(cx);
40+
let offload_entry_ty = TgtOffloadEntry::new_decl(cx);
41+
let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx);
42+
let ident_t_global = generate_at_one(cx);
43+
44+
let tptr = cx.type_ptr();
45+
let ti32 = cx.type_i32();
46+
let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
47+
let bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
48+
cx.set_struct_body(bin_desc, &tgt_bin_desc_ty, false);
49+
50+
let register_lib = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
51+
let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
52+
let init_ty = cx.type_func(&[], cx.type_void());
53+
let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
54+
55+
OffloadGlobals {
56+
launcher_fn,
57+
launcher_ty,
58+
bin_desc,
59+
kernel_args_ty,
60+
offload_entry_ty,
61+
begin_mapper,
62+
end_mapper,
63+
mapper_fn_ty,
64+
ident_t_global,
65+
register_lib,
66+
unregister_lib,
67+
init_rtls,
68+
}
69+
}
70+
}
71+
1372
// ; Function Attrs: nounwind
1473
// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
15-
fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
74+
fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
1675
let tptr = cx.type_ptr();
1776
let ti64 = cx.type_i64();
1877
let ti32 = cx.type_i32();
@@ -30,7 +89,7 @@ fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm
3089
// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
3190
// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
3291
// offloaded was defined.
33-
fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
92+
pub(crate) fn generate_at_one<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Value {
3493
let unknown_txt = ";unknown;unknown;0;0;;";
3594
let c_entry_name = CString::new(unknown_txt).unwrap();
3695
let c_val = c_entry_name.as_bytes_with_nul();
@@ -68,7 +127,7 @@ pub(crate) struct TgtOffloadEntry {
68127
}
69128

70129
impl TgtOffloadEntry {
71-
pub(crate) fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
130+
pub(crate) fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Type {
72131
let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
73132
let tptr = cx.type_ptr();
74133
let ti64 = cx.type_i64();
@@ -82,7 +141,7 @@ impl TgtOffloadEntry {
82141
}
83142

84143
fn new<'ll>(
85-
cx: &'ll SimpleCx<'_>,
144+
cx: &CodegenCx<'ll, '_>,
86145
region_id: &'ll Value,
87146
llglobal: &'ll Value,
88147
) -> [&'ll Value; 9] {
@@ -126,7 +185,7 @@ impl KernelArgsTy {
126185
const OFFLOAD_VERSION: u64 = 3;
127186
const FLAGS: u64 = 0;
128187
const TRIPCOUNT: u64 = 0;
129-
fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll Type {
188+
fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll Type {
130189
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
131190
let tptr = cx.type_ptr();
132191
let ti64 = cx.type_i64();
@@ -140,8 +199,8 @@ impl KernelArgsTy {
140199
kernel_arguments_ty
141200
}
142201

143-
fn new<'ll>(
144-
cx: &'ll SimpleCx<'_>,
202+
fn new<'ll, 'tcx>(
203+
cx: &CodegenCx<'ll, 'tcx>,
145204
num_args: u64,
146205
memtransfer_types: &'ll Value,
147206
geps: [&'ll Value; 3],
@@ -171,15 +230,16 @@ impl KernelArgsTy {
171230
}
172231

173232
// Contains LLVM values needed to manage offloading for a single kernel.
174-
pub(crate) struct OffloadKernelData<'ll> {
233+
#[derive(Copy, Clone)]
234+
pub(crate) struct OffloadKernelGlobals<'ll> {
175235
pub offload_sizes: &'ll llvm::Value,
176236
pub memtransfer_types: &'ll llvm::Value,
177237
pub region_id: &'ll llvm::Value,
178238
pub offload_entry: &'ll llvm::Value,
179239
}
180240

181241
fn gen_tgt_data_mappers<'ll>(
182-
cx: &'ll SimpleCx<'_>,
242+
cx: &CodegenCx<'ll, '_>,
183243
) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
184244
let tptr = cx.type_ptr();
185245
let ti64 = cx.type_i64();
@@ -241,12 +301,18 @@ pub(crate) fn add_global<'ll>(
241301
// mapped to/from the gpu. It also returns a region_id with the name of this kernel, to be
242302
// concatenated into the list of region_ids.
243303
pub(crate) fn gen_define_handling<'ll>(
244-
cx: &SimpleCx<'ll>,
245-
offload_entry_ty: &'ll llvm::Type,
304+
cx: &CodegenCx<'ll, '_>,
246305
metadata: &[OffloadMetadata],
247-
types: &[&Type],
248-
symbol: &str,
249-
) -> OffloadKernelData<'ll> {
306+
types: &[&'ll Type],
307+
symbol: String,
308+
offload_globals: &OffloadGlobals<'ll>,
309+
) -> OffloadKernelGlobals<'ll> {
310+
if let Some(entry) = cx.offload_kernel_cache.borrow().get(&symbol) {
311+
return *entry;
312+
}
313+
314+
let offload_entry_ty = offload_globals.offload_entry_ty;
315+
250316
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
251317
// reference) types.
252318
let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) {
@@ -274,7 +340,7 @@ pub(crate) fn gen_define_handling<'ll>(
274340
let initializer = cx.get_const_i8(0);
275341
let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage);
276342

277-
let c_entry_name = CString::new(symbol).unwrap();
343+
let c_entry_name = CString::new(symbol.clone()).unwrap();
278344
let c_val = c_entry_name.as_bytes_with_nul();
279345
let offload_entry_name = format!(".offloading.entry_name.{symbol}");
280346

@@ -298,11 +364,16 @@ pub(crate) fn gen_define_handling<'ll>(
298364
let c_section_name = CString::new("llvm_offload_entries").unwrap();
299365
llvm::set_section(offload_entry, &c_section_name);
300366

301-
OffloadKernelData { offload_sizes, memtransfer_types, region_id, offload_entry }
367+
let result =
368+
OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id, offload_entry };
369+
370+
cx.offload_kernel_cache.borrow_mut().insert(symbol, result);
371+
372+
result
302373
}
303374

304375
fn declare_offload_fn<'ll>(
305-
cx: &'ll SimpleCx<'_>,
376+
cx: &CodegenCx<'ll, '_>,
306377
name: &str,
307378
ty: &'ll llvm::Type,
308379
) -> &'ll llvm::Value {
@@ -335,28 +406,28 @@ fn declare_offload_fn<'ll>(
335406
// 4. set insert point after kernel call.
336407
// 5. generate all the GEPS and stores, to be used in 6)
337408
// 6. generate __tgt_target_data_end calls to move data from the GPU
338-
pub(crate) fn gen_call_handling<'ll>(
339-
cx: &SimpleCx<'ll>,
340-
bb: &BasicBlock,
341-
offload_data: &OffloadKernelData<'ll>,
409+
pub(crate) fn gen_call_handling<'ll, 'tcx>(
410+
builder: &mut Builder<'_, 'll, 'tcx>,
411+
offload_data: &OffloadKernelGlobals<'ll>,
342412
args: &[&'ll Value],
343413
types: &[&Type],
344414
metadata: &[OffloadMetadata],
415+
offload_globals: &OffloadGlobals<'ll>,
345416
) {
346-
let OffloadKernelData { offload_sizes, offload_entry, memtransfer_types, region_id } =
417+
let cx = builder.cx;
418+
let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
347419
offload_data;
348-
let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
349-
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
350-
let tptr = cx.type_ptr();
351-
let ti32 = cx.type_i32();
352-
let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
353-
let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
354-
cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);
355420

356-
let tgt_kernel_decl = KernelArgsTy::new_decl(&cx);
357-
let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
421+
let tgt_decl = offload_globals.launcher_fn;
422+
let tgt_target_kernel_ty = offload_globals.launcher_ty;
358423

359-
let mut builder = SBuilder::build(cx, bb);
424+
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
425+
let tgt_bin_desc = offload_globals.bin_desc;
426+
427+
let tgt_kernel_decl = offload_globals.kernel_args_ty;
428+
let begin_mapper_decl = offload_globals.begin_mapper;
429+
let end_mapper_decl = offload_globals.end_mapper;
430+
let fn_ty = offload_globals.mapper_fn_ty;
360431

361432
let num_args = types.len() as u64;
362433
let ip = unsafe { llvm::LLVMRustGetInsertPoint(&builder.llbuilder) };
@@ -378,9 +449,8 @@ pub(crate) fn gen_call_handling<'ll>(
378449
// Step 0)
379450
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
380451
// %6 = alloca %struct.__tgt_bin_desc, align 8
381-
let llfn = unsafe { llvm::LLVMGetBasicBlockParent(bb) };
382452
unsafe {
383-
llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, llfn);
453+
llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, builder.llfn());
384454
}
385455
let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
386456

@@ -413,16 +483,16 @@ pub(crate) fn gen_call_handling<'ll>(
413483
}
414484

415485
let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
416-
let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
417-
let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
486+
let register_lib_decl = offload_globals.register_lib;
487+
let unregister_lib_decl = offload_globals.unregister_lib;
418488
let init_ty = cx.type_func(&[], cx.type_void());
419-
let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
489+
let init_rtls_decl = offload_globals.init_rtls;
420490

421491
// FIXME(offload): Later we want to add them to the wrapper code, rather than our main function.
422492
// call void @__tgt_register_lib(ptr noundef %6)
423-
builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None);
493+
builder.call(mapper_fn_ty, None, None, register_lib_decl, &[tgt_bin_desc_alloca], None, None);
424494
// call void @__tgt_init_all_rtls()
425-
builder.call(init_ty, init_rtls_decl, &[], None);
495+
builder.call(init_ty, None, None, init_rtls_decl, &[], None, None);
426496

427497
for i in 0..num_args {
428498
let idx = cx.get_const_i32(i);
@@ -437,15 +507,15 @@ pub(crate) fn gen_call_handling<'ll>(
437507

438508
// For now we have a very simplistic indexing scheme into our
439509
// offload_{baseptrs,ptrs,sizes}. We will probably improve this along with our gpu frontend pr.
440-
fn get_geps<'a, 'll>(
441-
builder: &mut SBuilder<'a, 'll>,
442-
cx: &'ll SimpleCx<'ll>,
510+
fn get_geps<'ll, 'tcx>(
511+
builder: &mut Builder<'_, 'll, 'tcx>,
443512
ty: &'ll Type,
444513
ty2: &'ll Type,
445514
a1: &'ll Value,
446515
a2: &'ll Value,
447516
a4: &'ll Value,
448517
) -> [&'ll Value; 3] {
518+
let cx = builder.cx;
449519
let i32_0 = cx.get_const_i32(0);
450520

451521
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
@@ -454,30 +524,29 @@ pub(crate) fn gen_call_handling<'ll>(
454524
[gep1, gep2, gep3]
455525
}
456526

457-
fn generate_mapper_call<'a, 'll>(
458-
builder: &mut SBuilder<'a, 'll>,
459-
cx: &'ll SimpleCx<'ll>,
527+
fn generate_mapper_call<'ll, 'tcx>(
528+
builder: &mut Builder<'_, 'll, 'tcx>,
460529
geps: [&'ll Value; 3],
461530
o_type: &'ll Value,
462531
fn_to_call: &'ll Value,
463532
fn_ty: &'ll Type,
464533
num_args: u64,
465534
s_ident_t: &'ll Value,
466535
) {
536+
let cx = builder.cx;
467537
let nullptr = cx.const_null(cx.type_ptr());
468538
let i64_max = cx.get_const_i64(u64::MAX);
469539
let num_args = cx.get_const_i32(num_args);
470540
let args =
471541
vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
472-
builder.call(fn_ty, fn_to_call, &args, None);
542+
builder.call(fn_ty, None, None, fn_to_call, &args, None, None);
473543
}
474544

475545
// Step 2)
476-
let s_ident_t = generate_at_one(&cx);
477-
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
546+
let s_ident_t = offload_globals.ident_t_global;
547+
let geps = get_geps(builder, ty, ty2, a1, a2, a4);
478548
generate_mapper_call(
479-
&mut builder,
480-
&cx,
549+
builder,
481550
geps,
482551
memtransfer_types,
483552
begin_mapper_decl,
@@ -504,14 +573,13 @@ pub(crate) fn gen_call_handling<'ll>(
504573
region_id,
505574
a5,
506575
];
507-
builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
576+
builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None);
508577
// %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
509578

510579
// Step 4)
511-
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
580+
let geps = get_geps(builder, ty, ty2, a1, a2, a4);
512581
generate_mapper_call(
513-
&mut builder,
514-
&cx,
582+
builder,
515583
geps,
516584
memtransfer_types,
517585
end_mapper_decl,
@@ -520,7 +588,5 @@ pub(crate) fn gen_call_handling<'ll>(
520588
s_ident_t,
521589
);
522590

523-
builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
524-
525-
drop(builder);
591+
builder.call(mapper_fn_ty, None, None, unregister_lib_decl, &[tgt_bin_desc_alloca], None, None);
526592
}

0 commit comments

Comments
 (0)