Skip to content

Improve the calculation of location sizes for arrays and structs (revives #513) #798

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
37 changes: 36 additions & 1 deletion crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ impl<'tcx> CodegenCx<'tcx> {
Decoration::Location,
std::iter::once(Operand::LiteralInt32(*location)),
);
*location += 1;
*location += self.location_count_of_type(value_spirv_type);
}

// Emit the `OpVariable` with its *Result* ID set to `var`.
Expand All @@ -565,6 +565,41 @@ impl<'tcx> CodegenCx<'tcx> {
}
}

fn location_count_of_type(&self, ty: Word) -> u32 {
match self.lookup_type(ty) {
// Arrays take up multiple locations.
SpirvType::Array { count, element } => {
self.builder
.lookup_const_u64(count)
.expect("Array type has invalid count value") as u32
* self.location_count_of_type(element)
}
// Structs take up one location per field.
SpirvType::Adt { field_types, .. } => {
let mut size = 0;

for field_type in field_types {
size += self.location_count_of_type(field_type);
}

size
}
SpirvType::Vector { element, count } => {
// 3 or 4 component vectors take up 2 locations if they have a 64-bit scalar type.
if count > 2 {
match self.lookup_type(element) {
SpirvType::Float(64) | SpirvType::Integer(64, _) => 2,
_ => 1,
}
} else {
1
}
}
SpirvType::Matrix { element, count } => count * self.location_count_of_type(element),
_ => 1,
}
}

// Booleans are only allowed in some storage classes. Error if they're in others.
// Integers and f64s must be decorated with `#[spirv(flat)]`.
fn check_for_bad_types(
Expand Down
31 changes: 31 additions & 0 deletions tests/ui/dis/array_location_calculation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// build-pass
// compile-flags: -C llvm-args=--disassemble-globals
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
// normalize-stderr-test "OpExtension .SPV_KHR_vulkan_memory_model.\n" -> ""
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
// normalize-stderr-test "OpMemberName %12 0 .0.\n" -> ""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's up with this rewrite rule? It seems like it's probably hiding a bug that should be fixed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is here because the ordering of the OpMemberNames gets altered between runs:

running 1 test
diff of stderr:

 OpCapability Float64
 OpCapability Int16
 OpCapability Int64
 OpCapability Int8
 OpCapability ShaderClockKHR
 OpCapability Shader
 OpExtension "SPV_KHR_shader_clock"
 OpMemoryModel Logical Simple
 OpEntryPoint Fragment %1 "main" %2 %3 %4 %5 %6 %7 %8 %9
 OpExecutionMode %1 OriginUpperLeft
 %10 = OpString "$OPSTRING_FILENAME/array_location_calculation.rs"
 OpMemberName %11 0 "x_axis"
 OpMemberName %11 1 "y_axis"
 OpMemberName %11 2 "z_axis"
 OpName %11 "spirv_std::glam::core::storage::Columns3<spirv_std::glam::XYZ<f32>>"
 OpMemberName %12 0 "0"
 OpName %12 "spirv_std::glam::Mat3"
 OpName %13 "array_location_calculation::main"
 OpName %2 "one"
 OpName %3 "two"
 OpName %4 "three"
 OpName %5 "four"
 OpName %6 "five"
 OpName %7 "six"
 OpName %8 "seven"
 OpName %9 "eight"
 OpMemberName %11 0 "x_axis"
 OpMemberName %11 1 "y_axis"
 OpMemberName %11 2 "z_axis"
-OpMemberName %12 0 "0"
-OpMemberName %12 0 "0"
 OpMemberName %11 0 "x_axis"
 OpMemberName %11 1 "y_axis"
 OpMemberName %11 2 "z_axis"
+OpMemberName %12 0 "0"
+OpMemberName %12 0 "0"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, that seems like a bug.


use spirv_std::{
self as _,
glam::{DVec3, IVec4, Mat3, Vec4},
};

#[spirv(matrix)]
pub struct Mat4x3 {
pub col_0: Vec4,
pub col_1: Vec4,
pub col_2: Vec4,
}

#[spirv(fragment)]
pub fn main(
one: [f32; 7],
two: [f32; 3],
three: Mat3,
four: DVec3,
five: IVec4,
six: f32,
seven: Mat4x3,
eight: u32,
) {
}
79 changes: 79 additions & 0 deletions tests/ui/dis/array_location_calculation.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
OpCapability Float64
OpCapability Int16
OpCapability Int64
OpCapability Int8
OpCapability ShaderClockKHR
OpCapability Shader
OpExtension "SPV_KHR_shader_clock"
OpMemoryModel Logical Simple
OpEntryPoint Fragment %1 "main" %2 %3 %4 %5 %6 %7 %8 %9
OpExecutionMode %1 OriginUpperLeft
%10 = OpString "$OPSTRING_FILENAME/array_location_calculation.rs"
OpMemberName %11 0 "x_axis"
OpMemberName %11 1 "y_axis"
OpMemberName %11 2 "z_axis"
OpName %11 "spirv_std::glam::core::storage::Columns3<spirv_std::glam::XYZ<f32>>"
OpName %12 "spirv_std::glam::Mat3"
OpName %13 "array_location_calculation::main"
OpName %2 "one"
OpName %3 "two"
OpName %4 "three"
OpName %5 "four"
OpName %6 "five"
OpName %7 "six"
OpName %8 "seven"
OpName %9 "eight"
OpMemberName %11 0 "x_axis"
OpMemberName %11 1 "y_axis"
OpMemberName %11 2 "z_axis"
OpMemberName %11 0 "x_axis"
OpMemberName %11 1 "y_axis"
OpMemberName %11 2 "z_axis"
OpDecorate %14 ArrayStride 4
OpDecorate %15 ArrayStride 4
OpMemberDecorate %11 0 Offset 0
OpMemberDecorate %11 1 Offset 16
OpMemberDecorate %11 2 Offset 32
OpMemberDecorate %12 0 Offset 0
OpDecorate %2 Location 0
OpDecorate %3 Location 7
OpDecorate %4 Location 10
OpDecorate %5 Location 13
OpDecorate %6 Location 15
OpDecorate %7 Location 16
OpDecorate %8 Location 17
OpDecorate %9 Location 20
%16 = OpTypeVoid
%17 = OpTypeFloat 32
%18 = OpTypeInt 32 0
%19 = OpConstant %18 7
%14 = OpTypeArray %17 %19
%20 = OpConstant %18 3
%15 = OpTypeArray %17 %20
%21 = OpTypeVector %17 3
%11 = OpTypeStruct %21 %21 %21
%12 = OpTypeStruct %11
%22 = OpTypeFloat 64
%23 = OpTypeVector %22 3
%24 = OpTypeInt 32 1
%25 = OpTypeVector %24 4
%26 = OpTypeVector %17 4
%27 = OpTypeMatrix %26 3
%28 = OpTypeFunction %16 %14 %15 %12 %23 %25 %17 %27 %18
%29 = OpTypeFunction %16
%30 = OpTypePointer Input %14
%2 = OpVariable %30 Input
%31 = OpTypePointer Input %15
%3 = OpVariable %31 Input
%32 = OpTypePointer Input %12
%4 = OpVariable %32 Input
%33 = OpTypePointer Input %23
%5 = OpVariable %33 Input
%34 = OpTypePointer Input %25
%6 = OpVariable %34 Input
%35 = OpTypePointer Input %17
%7 = OpVariable %35 Input
%36 = OpTypePointer Input %27
%8 = OpVariable %36 Input
%37 = OpTypePointer Input %18
%9 = OpVariable %37 Input