Skip to content

Commit 9b2522b

Browse files
committed
fix(abi): preserve array alignment and apply 16-byte workaround generically
This fixes an IllegalAddress error when passing 16-byte aligned arrays (e.g. [u128; 4]) or tuples as kernel parameters. Previously, arrays indiscriminately forced PassMode::Direct(ArgAttributes::new()), discarding any alignment metadata computed by rustc. Furthermore, the 16-byte PassMode::Cast workaround was only applied to ADTs. This change abstracts the 16-byte workaround to apply to any aggregate type (arrays, ADTs, tuples) with >=16-byte alignment, and properly preserves alignment metadata for arrays when PassMode::Direct is used.
1 parent 103a8d5 commit 9b2522b

1 file changed

Lines changed: 30 additions & 32 deletions

File tree

  • crates/rustc_codegen_nvvm/src

crates/rustc_codegen_nvvm/src/abi.rs

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -53,38 +53,36 @@ pub(crate) fn readjust_fn_abi<'tcx>(
5353
arg.mode = PassMode::Pair(ptr_attrs, ArgAttributes::new());
5454
}
5555

56-
if arg.layout.ty.is_array() && !matches!(arg.mode, PassMode::Direct { .. }) {
57-
arg.mode = PassMode::Direct(ArgAttributes::new());
58-
}
59-
60-
// pass all adts directly as values, ptx wants them to be passed all by value, but rustc's
61-
// ptx-kernel abi seems to be wrong, and it's unstable.
62-
if arg.layout.ty.is_adt() {
63-
let align = arg.layout.align.abi.bytes();
64-
if align >= 16 && align.is_power_of_two() {
65-
let unit = Reg {
66-
kind: RegKind::Integer,
67-
size: Size::from_bytes(16),
68-
};
69-
let cast = CastTarget {
70-
prefix: Default::default(),
71-
rest: rustc_target::callconv::Uniform {
72-
unit,
73-
total: arg.layout.size,
74-
is_consecutive: false,
75-
},
76-
rest_offset: Some(Size::ZERO),
77-
attrs: ArgAttributes::new(),
78-
};
79-
arg.mode = PassMode::Cast {
80-
cast: Box::new(cast),
81-
pad_i32: false,
82-
};
83-
} else if !matches!(arg.mode, PassMode::Direct { .. }) {
84-
let mut attrs = ArgAttributes::new();
85-
attrs.pointee_align = Some(arg.layout.align.abi);
86-
arg.mode = PassMode::Direct(attrs);
87-
}
56+
let is_aggregate = arg.layout.ty.is_array()
57+
|| arg.layout.ty.is_adt()
58+
|| matches!(arg.layout.ty.kind(), TyKind::Tuple(_));
59+
let align = arg.layout.align.abi.bytes();
60+
61+
if is_aggregate && align >= 16 && align.is_power_of_two() {
62+
let unit = Reg {
63+
kind: RegKind::Integer,
64+
size: Size::from_bytes(16),
65+
};
66+
let cast = CastTarget {
67+
prefix: Default::default(),
68+
rest: rustc_target::callconv::Uniform {
69+
unit,
70+
total: arg.layout.size,
71+
is_consecutive: false,
72+
},
73+
rest_offset: Some(Size::ZERO),
74+
attrs: ArgAttributes::new(),
75+
};
76+
arg.mode = PassMode::Cast {
77+
cast: Box::new(cast),
78+
pad_i32: false,
79+
};
80+
} else if (arg.layout.ty.is_array() || arg.layout.ty.is_adt())
81+
&& !matches!(arg.mode, PassMode::Direct { .. })
82+
{
83+
let mut attrs = ArgAttributes::new();
84+
attrs.pointee_align = Some(arg.layout.align.abi);
85+
arg.mode = PassMode::Direct(attrs);
8886
}
8987
arg
9088
};

0 commit comments

Comments
 (0)