diff --git a/CHANGELOG.md b/CHANGELOG.md index 951238009ae..187b89e43dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,6 +64,7 @@ Naga now infers the correct binding layout when a resource appears only in an as - Apply necessary automatic conversions to the `value` argument of `textureStore`. By @jimblandy in [#7567](https://github.com/gfx-rs/wgpu/pull/7567). - Properly apply WGSL's automatic conversions to the arguments to texture sampling functions. By @jimblandy in [#7548](https://github.com/gfx-rs/wgpu/pull/7548). - Properly evaluate `abs(most negative abstract int)`. By @jimblandy in [#7507](https://github.com/gfx-rs/wgpu/pull/7507). +- Generate vectorized code for `[un]pack4x{I,U}8[Clamp]` on SPIR-V and MSL 2.1+. By @robamler in [#7664](https://github.com/gfx-rs/wgpu/pull/7664). #### DX12 diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index f05e5c233aa..2810884d299 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1497,6 +1497,58 @@ impl Writer { Ok(()) } + /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`. + fn put_pack4x8( + &mut self, + arg: Handle<crate::Expression>, + context: &ExpressionContext<'_>, + was_signed: bool, + clamp_bounds: Option<(&str, &str)>, + ) -> Result<(), Error> { + let write_arg = |this: &mut Self| -> BackendResult { + if let Some((min, max)) = clamp_bounds { + // Clamping with scalar bounds works (component-wise) even for packed_[u]char4. + write!(this.out, "{NAMESPACE}::clamp(")?; + this.put_expression(arg, context, true)?; + write!(this.out, ", {min}, {max})")?; + } else { + this.put_expression(arg, context, true)?; + } + Ok(()) + }; + + if context.lang_version >= (2, 1) { + let packed_type = if was_signed { + "packed_char4" + } else { + "packed_uchar4" + }; + // Metal uses little endian byte order, which matches what WGSL expects here.
+ write!(self.out, "as_type<uint>({packed_type}(")?; + write_arg(self)?; + write!(self.out, "))")?; + } else { + // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars. + if was_signed { + write!(self.out, "uint(")?; + } + write!(self.out, "(")?; + write_arg(self)?; + write!(self.out, "[0] & 0xFF) | ((")?; + write_arg(self)?; + write!(self.out, "[1] & 0xFF) << 8) | ((")?; + write_arg(self)?; + write!(self.out, "[2] & 0xFF) << 16) | ((")?; + write_arg(self)?; + write!(self.out, "[3] & 0xFF) << 24)")?; + if was_signed { + write!(self.out, ")")?; + } + } + + Ok(()) + } + /// Emit code for the isign expression. /// fn put_isign( @@ -2437,53 +2489,41 @@ impl Writer { write!(self.out, "{fun_name}")?; self.put_call_parameters(iter::once(arg), context)?; } - fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => { - let was_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp); - let clamp_bounds = match fun { - Mf::Pack4xI8Clamp => Some(("-128", "127")), - Mf::Pack4xU8Clamp => Some(("0", "255")), - _ => None, - }; - if was_signed { - write!(self.out, "uint(")?; - } - let write_arg = |this: &mut Self| -> BackendResult { - if let Some((min, max)) = clamp_bounds { - write!(this.out, "{NAMESPACE}::clamp(")?; - this.put_expression(arg, context, true)?; - write!(this.out, ", {min}, {max})")?; - } else { - this.put_expression(arg, context, true)?; - } - Ok(()) - }; - write!(self.out, "(")?; - write_arg(self)?; - write!(self.out, "[0] & 0xFF) | ((")?; - write_arg(self)?; - write!(self.out, "[1] & 0xFF) << 8) | ((")?; - write_arg(self)?; - write!(self.out, "[2] & 0xFF) << 16) | ((")?; - write_arg(self)?; - write!(self.out, "[3] & 0xFF) << 24)")?; - if was_signed { - write!(self.out, ")")?; - } + Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?, + Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?, + Mf::Pack4xI8Clamp => { + self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
+ } + Mf::Pack4xU8Clamp => { + self.put_pack4x8(arg, context, false, Some(("0", "255")))? } fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => { - write!(self.out, "(")?; - if matches!(fun, Mf::Unpack4xU8) { - write!(self.out, "u")?; + let sign_prefix = if matches!(fun, Mf::Unpack4xU8) { + "u" + } else { + "" + }; + + if context.lang_version >= (2, 1) { + // Metal uses little endian byte order, which matches what WGSL expects here. + write!( + self.out, + "{sign_prefix}int4(as_type<packed_{sign_prefix}char4>(" + )?; + self.put_expression(arg, context, true)?; + write!(self.out, "))")?; + } else { + // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars. + write!(self.out, "({sign_prefix}int4(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " >> 8, ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " >> 16, ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " >> 24) << 24 >> 24)")?; } - write!(self.out, "int4(")?; - self.put_expression(arg, context, true)?; - write!(self.out, ", ")?; - self.put_expression(arg, context, true)?; - write!(self.out, " >> 8, ")?; - self.put_expression(arg, context, true)?; - write!(self.out, " >> 16, ")?; - self.put_expression(arg, context, true)?; - write!(self.out, " >> 24) << 24 >> 24)")?; } Mf::QuantizeToF16 => { match *context.resolve_type(arg) { @@ -3226,14 +3266,20 @@ impl Writer { self.need_bake_expressions.insert(arg); self.need_bake_expressions.insert(arg1.unwrap()); } - crate::MathFunction::FirstLeadingBit - | crate::MathFunction::Pack4xI8 + crate::MathFunction::FirstLeadingBit => { + self.need_bake_expressions.insert(arg); + } + crate::MathFunction::Pack4xI8 | crate::MathFunction::Pack4xU8 | crate::MathFunction::Pack4xI8Clamp | crate::MathFunction::Pack4xU8Clamp | crate::MathFunction::Unpack4xI8 | crate::MathFunction::Unpack4xU8 => { - self.need_bake_expressions.insert(arg); + // On MSL < 2.1, we emit a polyfill
for these functions that uses the + // argument multiple times. This is no longer necessary on MSL >= 2.1. + if context.lang_version < (2, 1) { + self.need_bake_expressions.insert(arg); + } } crate::MathFunction::ExtractBits => { // Only argument 1 is re-used. diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 92edbcb05c4..8c1c8c4caa2 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1552,105 +1552,29 @@ impl BlockContext<'_> { Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16), Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16), fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => { - let (int_type, is_signed) = match fun { - Mf::Pack4xI8 | Mf::Pack4xI8Clamp => (crate::ScalarKind::Sint, true), - Mf::Pack4xU8 | Mf::Pack4xU8Clamp => (crate::ScalarKind::Uint, false), - _ => unreachable!(), - }; + let is_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp); let should_clamp = matches!(fun, Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp); - let uint_type_id = - self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32)); - - let int_type_id = - self.get_numeric_type_id(NumericType::Scalar(crate::Scalar { - kind: int_type, - width: 4, - })); - - let mut last_instruction = Instruction::new(spirv::Op::Nop); - - let zero = self.writer.get_constant_scalar(crate::Literal::U32(0)); - let mut preresult = zero; - block - .body - .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed))); - - let eight = self.writer.get_constant_scalar(crate::Literal::U32(8)); - const VEC_LENGTH: u8 = 4; - for i in 0..u32::from(VEC_LENGTH) { - let offset = - self.writer.get_constant_scalar(crate::Literal::U32(i * 8)); - let mut extracted = self.gen_id(); - block.body.push(Instruction::binary( - spirv::Op::CompositeExtract, - int_type_id, - extracted, - arg0_id, - i, - )); - if is_signed { - let casted = self.gen_id(); - block.body.push(Instruction::unary( - spirv::Op::Bitcast, - uint_type_id, 
- casted, - extracted, - )); - extracted = casted; - } - if should_clamp { - let (min, max, clamp_op) = if is_signed { - ( - crate::Literal::I32(-128), - crate::Literal::I32(127), - spirv::GLOp::SClamp, - ) - } else { - ( - crate::Literal::U32(0), - crate::Literal::U32(255), - spirv::GLOp::UClamp, - ) - }; - let [min, max] = - [min, max].map(|lit| self.writer.get_constant_scalar(lit)); - let clamp_id = self.gen_id(); - block.body.push(Instruction::ext_inst( - self.writer.gl450_ext_inst_id, - clamp_op, - result_type_id, - clamp_id, - &[extracted, min, max], - )); - - extracted = clamp_id; - } - let is_last = i == u32::from(VEC_LENGTH - 1); - if is_last { - last_instruction = Instruction::quaternary( - spirv::Op::BitFieldInsert, + let last_instruction = + if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() { + self.write_pack4x8_optimized( + block, result_type_id, + arg0_id, id, - preresult, - extracted, - offset, - eight, + is_signed, + should_clamp, ) } else { - let new_preresult = self.gen_id(); - block.body.push(Instruction::quaternary( - spirv::Op::BitFieldInsert, + self.write_pack4x8_polyfill( + block, result_type_id, - new_preresult, - preresult, - extracted, - offset, - eight, - )); - preresult = new_preresult; - } - } + arg0_id, + id, + is_signed, + should_clamp, + ) + }; MathOp::Custom(last_instruction) } @@ -1660,59 +1584,28 @@ impl BlockContext<'_> { Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16), Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16), fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => { - let (int_type, extract_op, is_signed) = match fun { - Mf::Unpack4xI8 => { - (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract, true) - } - Mf::Unpack4xU8 => { - (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract, false) - } - _ => unreachable!(), - }; + let is_signed = matches!(fun, Mf::Unpack4xI8); - let sint_type_id = - self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32)); - - let eight = 
self.writer.get_constant_scalar(crate::Literal::U32(8)); - let int_type_id = - self.get_numeric_type_id(NumericType::Scalar(crate::Scalar { - kind: int_type, - width: 4, - })); - block - .body - .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed)); - let arg_id = if is_signed { - let new_arg_id = self.gen_id(); - block.body.push(Instruction::unary( - spirv::Op::Bitcast, - sint_type_id, - new_arg_id, - arg0_id, - )); - new_arg_id - } else { - arg0_id - }; - - const VEC_LENGTH: u8 = 4; - let parts: [_; VEC_LENGTH as usize] = - core::array::from_fn(|_| self.gen_id()); - for (i, part_id) in parts.into_iter().enumerate() { - let index = self - .writer - .get_constant_scalar(crate::Literal::U32(i as u32 * 8)); - block.body.push(Instruction::ternary( - extract_op, - int_type_id, - part_id, - arg_id, - index, - eight, - )); - } + let last_instruction = + if self.writer.require_all(&[spirv::Capability::Int8]).is_ok() { + self.write_unpack4x8_optimized( + block, + result_type_id, + arg0_id, + id, + is_signed, + ) + } else { + self.write_unpack4x8_polyfill( + block, + result_type_id, + arg0_id, + id, + is_signed, + ) + }; - MathOp::Custom(Instruction::composite_construct(result_type_id, id, &parts)) + MathOp::Custom(last_instruction) } }; @@ -2721,6 +2614,288 @@ impl BlockContext<'_> { } } + /// Emit code for `pack4x{I,U}8[Clamp]` if capability "Int8" is available. 
+ fn write_pack4x8_optimized( + &mut self, + block: &mut Block, + result_type_id: u32, + arg0_id: u32, + id: u32, + is_signed: bool, + should_clamp: bool, + ) -> Instruction { + let int_type = if is_signed { + crate::ScalarKind::Sint + } else { + crate::ScalarKind::Uint + }; + let wide_vector_type = NumericType::Vector { + size: crate::VectorSize::Quad, + scalar: crate::Scalar { + kind: int_type, + width: 4, + }, + }; + let wide_vector_type_id = self.get_numeric_type_id(wide_vector_type); + let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector { + size: crate::VectorSize::Quad, + scalar: crate::Scalar { + kind: crate::ScalarKind::Uint, + width: 1, + }, + }); + + let mut wide_vector = arg0_id; + if should_clamp { + let (min, max, clamp_op) = if is_signed { + ( + crate::Literal::I32(-128), + crate::Literal::I32(127), + spirv::GLOp::SClamp, + ) + } else { + ( + crate::Literal::U32(0), + crate::Literal::U32(255), + spirv::GLOp::UClamp, + ) + }; + let [min, max] = [min, max].map(|lit| { + let scalar = self.writer.get_constant_scalar(lit); + self.writer.get_constant_composite( + LookupType::Local(LocalType::Numeric(wide_vector_type)), + &[scalar; 4], + ) + }); + + let clamp_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + clamp_op, + wide_vector_type_id, + clamp_id, + &[wide_vector, min, max], + )); + + wide_vector = clamp_id; + } + + let packed_vector = self.gen_id(); + block.body.push(Instruction::unary( + spirv::Op::UConvert, // We truncate, so `UConvert` and `SConvert` behave identically. + packed_vector_type_id, + packed_vector, + wide_vector, + )); + + // The SPIR-V spec [1] defines the bit order for bit casting between a vector + // and a scalar precisely as required by the WGSL spec [2]. 
+ // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast + // [2]: https://www.w3.org/TR/WGSL/#pack4xI8-builtin + Instruction::unary(spirv::Op::Bitcast, result_type_id, id, packed_vector) + } + + /// Emit code for `pack4x{I,U}8[Clamp]` if capability "Int8" is not available. + fn write_pack4x8_polyfill( + &mut self, + block: &mut Block, + result_type_id: u32, + arg0_id: u32, + id: u32, + is_signed: bool, + should_clamp: bool, + ) -> Instruction { + let int_type = if is_signed { + crate::ScalarKind::Sint + } else { + crate::ScalarKind::Uint + }; + let uint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32)); + let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar { + kind: int_type, + width: 4, + })); + + let mut last_instruction = Instruction::new(spirv::Op::Nop); + + let zero = self.writer.get_constant_scalar(crate::Literal::U32(0)); + let mut preresult = zero; + block + .body + .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed))); + + let eight = self.writer.get_constant_scalar(crate::Literal::U32(8)); + const VEC_LENGTH: u8 = 4; + for i in 0..u32::from(VEC_LENGTH) { + let offset = self.writer.get_constant_scalar(crate::Literal::U32(i * 8)); + let mut extracted = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::CompositeExtract, + int_type_id, + extracted, + arg0_id, + i, + )); + if is_signed { + let casted = self.gen_id(); + block.body.push(Instruction::unary( + spirv::Op::Bitcast, + uint_type_id, + casted, + extracted, + )); + extracted = casted; + } + if should_clamp { + let (min, max, clamp_op) = if is_signed { + ( + crate::Literal::I32(-128), + crate::Literal::I32(127), + spirv::GLOp::SClamp, + ) + } else { + ( + crate::Literal::U32(0), + crate::Literal::U32(255), + spirv::GLOp::UClamp, + ) + }; + let [min, max] = [min, max].map(|lit| self.writer.get_constant_scalar(lit)); + + let clamp_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + 
self.writer.gl450_ext_inst_id, + clamp_op, + result_type_id, + clamp_id, + &[extracted, min, max], + )); + + extracted = clamp_id; + } + let is_last = i == u32::from(VEC_LENGTH - 1); + if is_last { + last_instruction = Instruction::quaternary( + spirv::Op::BitFieldInsert, + result_type_id, + id, + preresult, + extracted, + offset, + eight, + ) + } else { + let new_preresult = self.gen_id(); + block.body.push(Instruction::quaternary( + spirv::Op::BitFieldInsert, + result_type_id, + new_preresult, + preresult, + extracted, + offset, + eight, + )); + preresult = new_preresult; + } + } + last_instruction + } + + /// Emit code for `unpack4x{I,U}8` if capability "Int8" is available. + fn write_unpack4x8_optimized( + &mut self, + block: &mut Block, + result_type_id: u32, + arg0_id: u32, + id: u32, + is_signed: bool, + ) -> Instruction { + let (int_type, convert_op) = if is_signed { + (crate::ScalarKind::Sint, spirv::Op::SConvert) + } else { + (crate::ScalarKind::Uint, spirv::Op::UConvert) + }; + + let packed_vector_type_id = self.get_numeric_type_id(NumericType::Vector { + size: crate::VectorSize::Quad, + scalar: crate::Scalar { + kind: int_type, + width: 1, + }, + }); + + // The SPIR-V spec [1] defines the bit order for bit casting between a vector + // and a scalar precisely as required by the WGSL spec [2]. + // [1]: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast + // [2]: https://www.w3.org/TR/WGSL/#unpack4xI8-builtin + let packed_vector = self.gen_id(); + block.body.push(Instruction::unary( + spirv::Op::Bitcast, + packed_vector_type_id, + packed_vector, + arg0_id, + )); + + Instruction::unary(convert_op, result_type_id, id, packed_vector) + } + + /// Emit code for `unpack4x{I,U}8` if capability "Int8" is not available.
+ fn write_unpack4x8_polyfill( + &mut self, + block: &mut Block, + result_type_id: u32, + arg0_id: u32, + id: u32, + is_signed: bool, + ) -> Instruction { + let (int_type, extract_op) = if is_signed { + (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract) + } else { + (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract) + }; + + let sint_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32)); + + let eight = self.writer.get_constant_scalar(crate::Literal::U32(8)); + let int_type_id = self.get_numeric_type_id(NumericType::Scalar(crate::Scalar { + kind: int_type, + width: 4, + })); + block + .body + .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed)); + let arg_id = if is_signed { + let new_arg_id = self.gen_id(); + block.body.push(Instruction::unary( + spirv::Op::Bitcast, + sint_type_id, + new_arg_id, + arg0_id, + )); + new_arg_id + } else { + arg0_id + }; + + const VEC_LENGTH: u8 = 4; + let parts: [_; VEC_LENGTH as usize] = core::array::from_fn(|_| self.gen_id()); + for (i, part_id) in parts.into_iter().enumerate() { + let index = self + .writer + .get_constant_scalar(crate::Literal::U32(i as u32 * 8)); + block.body.push(Instruction::ternary( + extract_op, + int_type_id, + part_id, + arg_id, + index, + eight, + )); + } + + Instruction::composite_construct(result_type_id, id, &parts) + } + /// Generate one or more SPIR-V blocks for `naga_block`. /// /// Use `label_id` as the label for the SPIR-V entry point block. 
diff --git a/naga/tests/in/wgsl/bits-optimized-msl.toml b/naga/tests/in/wgsl/bits-optimized-msl.toml new file mode 100644 index 00000000000..9409d2ac77c --- /dev/null +++ b/naga/tests/in/wgsl/bits-optimized-msl.toml @@ -0,0 +1,4 @@ +targets = "METAL" + +[msl] +lang_version = [2, 1] diff --git a/naga/tests/in/wgsl/bits-optimized-msl.wgsl b/naga/tests/in/wgsl/bits-optimized-msl.wgsl new file mode 100644 index 00000000000..a77266ad343 --- /dev/null +++ b/naga/tests/in/wgsl/bits-optimized-msl.wgsl @@ -0,0 +1,69 @@ +// Keep in sync with `bits_downlevel` and `bits_downlevel_webgl` + +@compute @workgroup_size(1) +fn main() { + var i = 0; + var i2 = vec2(0); + var i3 = vec3(0); + var i4 = vec4(0); + var u = 0u; + var u2 = vec2(0u); + var u3 = vec3(0u); + var u4 = vec4(0u); + var f2 = vec2(0.0); + var f4 = vec4(0.0); + u = pack4x8snorm(f4); + u = pack4x8unorm(f4); + u = pack2x16snorm(f2); + u = pack2x16unorm(f2); + u = pack2x16float(f2); + u = pack4xI8(i4); + u = pack4xU8(u4); + u = pack4xI8Clamp(i4); + u = pack4xU8Clamp(u4); + f4 = unpack4x8snorm(u); + f4 = unpack4x8unorm(u); + f2 = unpack2x16snorm(u); + f2 = unpack2x16unorm(u); + f2 = unpack2x16float(u); + i4 = unpack4xI8(u); + u4 = unpack4xU8(u); + i = insertBits(i, i, 5u, 10u); + i2 = insertBits(i2, i2, 5u, 10u); + i3 = insertBits(i3, i3, 5u, 10u); + i4 = insertBits(i4, i4, 5u, 10u); + u = insertBits(u, u, 5u, 10u); + u2 = insertBits(u2, u2, 5u, 10u); + u3 = insertBits(u3, u3, 5u, 10u); + u4 = insertBits(u4, u4, 5u, 10u); + i = extractBits(i, 5u, 10u); + i2 = extractBits(i2, 5u, 10u); + i3 = extractBits(i3, 5u, 10u); + i4 = extractBits(i4, 5u, 10u); + u = extractBits(u, 5u, 10u); + u2 = extractBits(u2, 5u, 10u); + u3 = extractBits(u3, 5u, 10u); + u4 = extractBits(u4, 5u, 10u); + i = firstTrailingBit(i); + u2 = firstTrailingBit(u2); + i3 = firstLeadingBit(i3); + u3 = firstLeadingBit(u3); + i = firstLeadingBit(i); + u = firstLeadingBit(u); + i = countOneBits(i); + i2 = countOneBits(i2); + i3 = countOneBits(i3); + i4 = 
countOneBits(i4); + u = countOneBits(u); + u2 = countOneBits(u2); + u3 = countOneBits(u3); + u4 = countOneBits(u4); + i = reverseBits(i); + i2 = reverseBits(i2); + i3 = reverseBits(i3); + i4 = reverseBits(i4); + u = reverseBits(u); + u2 = reverseBits(u2); + u3 = reverseBits(u3); + u4 = reverseBits(u4); +} diff --git a/naga/tests/out/msl/wgsl-bits-optimized-msl.msl b/naga/tests/out/msl/wgsl-bits-optimized-msl.msl new file mode 100644 index 00000000000..e33ed65f463 --- /dev/null +++ b/naga/tests/out/msl/wgsl-bits-optimized-msl.msl @@ -0,0 +1,137 @@ +// language: metal2.1 +#include <metal_stdlib> +#include <simd/simd.h> + +using metal::uint; + + +kernel void main_( +) { + int i = 0; + metal::int2 i2_ = metal::int2(0); + metal::int3 i3_ = metal::int3(0); + metal::int4 i4_ = metal::int4(0); + uint u = 0u; + metal::uint2 u2_ = metal::uint2(0u); + metal::uint3 u3_ = metal::uint3(0u); + metal::uint4 u4_ = metal::uint4(0u); + metal::float2 f2_ = metal::float2(0.0); + metal::float4 f4_ = metal::float4(0.0); + metal::float4 _e28 = f4_; + u = metal::pack_float_to_snorm4x8(_e28); + metal::float4 _e30 = f4_; + u = metal::pack_float_to_unorm4x8(_e30); + metal::float2 _e32 = f2_; + u = metal::pack_float_to_snorm2x16(_e32); + metal::float2 _e34 = f2_; + u = metal::pack_float_to_unorm2x16(_e34); + metal::float2 _e36 = f2_; + u = as_type<uint>(half2(_e36)); + metal::int4 _e38 = i4_; + u = as_type<uint>(packed_char4(_e38)); + metal::uint4 _e40 = u4_; + u = as_type<uint>(packed_uchar4(_e40)); + metal::int4 _e42 = i4_; + u = as_type<uint>(packed_char4(metal::clamp(_e42, -128, 127))); + metal::uint4 _e44 = u4_; + u = as_type<uint>(packed_uchar4(metal::clamp(_e44, 0, 255))); + uint _e46 = u; + f4_ = metal::unpack_snorm4x8_to_float(_e46); + uint _e48 = u; + f4_ = metal::unpack_unorm4x8_to_float(_e48); + uint _e50 = u; + f2_ = metal::unpack_snorm2x16_to_float(_e50); + uint _e52 = u; + f2_ = metal::unpack_unorm2x16_to_float(_e52); + uint _e54 = u; + f2_ = float2(as_type<half2>(_e54)); + uint _e56 = u; + i4_ = int4(as_type<packed_char4>(_e56)); + uint _e58 = u; + u4_ =
uint4(as_type<packed_uchar4>(_e58)); + int _e60 = i; + int _e61 = i; + i = metal::insert_bits(_e60, _e61, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int2 _e65 = i2_; + metal::int2 _e66 = i2_; + i2_ = metal::insert_bits(_e65, _e66, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int3 _e70 = i3_; + metal::int3 _e71 = i3_; + i3_ = metal::insert_bits(_e70, _e71, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int4 _e75 = i4_; + metal::int4 _e76 = i4_; + i4_ = metal::insert_bits(_e75, _e76, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + uint _e80 = u; + uint _e81 = u; + u = metal::insert_bits(_e80, _e81, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint2 _e85 = u2_; + metal::uint2 _e86 = u2_; + u2_ = metal::insert_bits(_e85, _e86, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint3 _e90 = u3_; + metal::uint3 _e91 = u3_; + u3_ = metal::insert_bits(_e90, _e91, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint4 _e95 = u4_; + metal::uint4 _e96 = u4_; + u4_ = metal::insert_bits(_e95, _e96, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + int _e100 = i; + i = metal::extract_bits(_e100, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int2 _e104 = i2_; + i2_ = metal::extract_bits(_e104, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int3 _e108 = i3_; + i3_ = metal::extract_bits(_e108, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int4 _e112 = i4_; + i4_ = metal::extract_bits(_e112, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + uint _e116 = u; + u = metal::extract_bits(_e116, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint2 _e120 = u2_; + u2_ = metal::extract_bits(_e120, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); +
metal::uint3 _e124 = u3_; + u3_ = metal::extract_bits(_e124, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint4 _e128 = u4_; + u4_ = metal::extract_bits(_e128, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + int _e132 = i; + i = (((metal::ctz(_e132) + 1) % 33) - 1); + metal::uint2 _e134 = u2_; + u2_ = (((metal::ctz(_e134) + 1) % 33) - 1); + metal::int3 _e136 = i3_; + i3_ = metal::select(31 - metal::clz(metal::select(_e136, ~_e136, _e136 < 0)), int3(-1), _e136 == 0 || _e136 == -1); + metal::uint3 _e138 = u3_; + u3_ = metal::select(31 - metal::clz(_e138), uint3(-1), _e138 == 0 || _e138 == -1); + int _e140 = i; + i = metal::select(31 - metal::clz(metal::select(_e140, ~_e140, _e140 < 0)), int(-1), _e140 == 0 || _e140 == -1); + uint _e142 = u; + u = metal::select(31 - metal::clz(_e142), uint(-1), _e142 == 0 || _e142 == -1); + int _e144 = i; + i = metal::popcount(_e144); + metal::int2 _e146 = i2_; + i2_ = metal::popcount(_e146); + metal::int3 _e148 = i3_; + i3_ = metal::popcount(_e148); + metal::int4 _e150 = i4_; + i4_ = metal::popcount(_e150); + uint _e152 = u; + u = metal::popcount(_e152); + metal::uint2 _e154 = u2_; + u2_ = metal::popcount(_e154); + metal::uint3 _e156 = u3_; + u3_ = metal::popcount(_e156); + metal::uint4 _e158 = u4_; + u4_ = metal::popcount(_e158); + int _e160 = i; + i = metal::reverse_bits(_e160); + metal::int2 _e162 = i2_; + i2_ = metal::reverse_bits(_e162); + metal::int3 _e164 = i3_; + i3_ = metal::reverse_bits(_e164); + metal::int4 _e166 = i4_; + i4_ = metal::reverse_bits(_e166); + uint _e168 = u; + u = metal::reverse_bits(_e168); + metal::uint2 _e170 = u2_; + u2_ = metal::reverse_bits(_e170); + metal::uint3 _e172 = u3_; + u3_ = metal::reverse_bits(_e172); + metal::uint4 _e174 = u4_; + u4_ = metal::reverse_bits(_e174); + return; +} diff --git a/naga/tests/out/spv/wgsl-6772-unpack-expr-accesses.spvasm b/naga/tests/out/spv/wgsl-6772-unpack-expr-accesses.spvasm index 973557789e3..eb3edc3b36f 100644 
--- a/naga/tests/out/spv/wgsl-6772-unpack-expr-accesses.spvasm +++ b/naga/tests/out/spv/wgsl-6772-unpack-expr-accesses.spvasm @@ -1,8 +1,9 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 30 +; Bound: 23 OpCapability Shader +OpCapability Int8 %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint GLCompute %4 "main" @@ -14,27 +15,20 @@ OpExecutionMode %4 LocalSize 1 1 1 %8 = OpTypeInt 32 0 %9 = OpConstant %8 12 %11 = OpTypeVector %6 4 -%13 = OpConstant %8 8 -%19 = OpConstant %8 0 -%20 = OpConstant %8 16 -%21 = OpConstant %8 24 -%23 = OpTypeVector %8 4 +%14 = OpTypeInt 8 1 +%13 = OpTypeVector %14 4 +%17 = OpTypeVector %8 4 +%20 = OpTypeInt 8 0 +%19 = OpTypeVector %20 4 %4 = OpFunction %2 None %5 %3 = OpLabel OpBranch %10 %10 = OpLabel -%14 = OpBitcast %6 %9 -%15 = OpBitFieldSExtract %6 %14 %19 %13 -%16 = OpBitFieldSExtract %6 %14 %13 %13 -%17 = OpBitFieldSExtract %6 %14 %20 %13 -%18 = OpBitFieldSExtract %6 %14 %21 %13 -%12 = OpCompositeConstruct %11 %15 %16 %17 %18 -%22 = OpCompositeExtract %6 %12 2 -%25 = OpBitFieldUExtract %8 %9 %19 %13 -%26 = OpBitFieldUExtract %8 %9 %13 %13 -%27 = OpBitFieldUExtract %8 %9 %20 %13 -%28 = OpBitFieldUExtract %8 %9 %21 %13 -%24 = OpCompositeConstruct %23 %25 %26 %27 %28 -%29 = OpCompositeExtract %8 %24 1 +%15 = OpBitcast %13 %9 +%12 = OpSConvert %11 %15 +%16 = OpCompositeExtract %6 %12 2 +%21 = OpBitcast %19 %9 +%18 = OpUConvert %17 %21 +%22 = OpCompositeExtract %8 %18 1 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/spv/wgsl-bits.spvasm b/naga/tests/out/spv/wgsl-bits.spvasm index 76e221aea16..dec26768b3a 100644 --- a/naga/tests/out/spv/wgsl-bits.spvasm +++ b/naga/tests/out/spv/wgsl-bits.spvasm @@ -1,8 +1,9 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 275 +; Bound: 234 OpCapability Shader +OpCapability Int8 %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint GLCompute %15 "main" @@ -43,13 +44,17 @@ OpExecutionMode %15 LocalSize 1 1 
1 %45 = OpTypePointer Function %10 %47 = OpTypePointer Function %11 %49 = OpTypePointer Function %13 -%63 = OpConstant %7 8 -%70 = OpConstant %7 16 -%74 = OpConstant %7 24 -%90 = OpConstant %3 -128 -%91 = OpConstant %3 127 -%108 = OpConstant %7 255 -%145 = OpConstant %7 32 +%64 = OpTypeInt 8 0 +%63 = OpTypeVector %64 4 +%71 = OpConstant %3 -128 +%72 = OpConstantComposite %6 %71 %71 %71 %71 +%73 = OpConstant %3 127 +%74 = OpConstantComposite %6 %73 %73 %73 %73 +%79 = OpConstant %7 255 +%80 = OpConstantComposite %10 %79 %79 %79 %79 +%96 = OpTypeInt 8 1 +%95 = OpTypeVector %96 4 +%104 = OpConstant %7 32 %15 = OpFunction %2 None %16 %14 = OpLabel %48 = OpVariable %49 Function %27 @@ -80,260 +85,215 @@ OpStore %38 %58 %60 = OpExtInst %7 %1 PackHalf2x16 %59 OpStore %38 %60 %61 = OpLoad %6 %36 -%64 = OpCompositeExtract %3 %61 0 -%65 = OpBitcast %7 %64 -%66 = OpBitFieldInsert %7 %21 %65 %21 %63 -%67 = OpCompositeExtract %3 %61 1 -%68 = OpBitcast %7 %67 -%69 = OpBitFieldInsert %7 %66 %68 %63 %63 -%71 = OpCompositeExtract %3 %61 2 -%72 = OpBitcast %7 %71 -%73 = OpBitFieldInsert %7 %69 %72 %70 %63 -%75 = OpCompositeExtract %3 %61 3 -%76 = OpBitcast %7 %75 -%62 = OpBitFieldInsert %7 %73 %76 %74 %63 +%65 = OpUConvert %63 %61 +%62 = OpBitcast %7 %65 OpStore %38 %62 +%66 = OpLoad %10 %44 +%68 = OpUConvert %63 %66 +%67 = OpBitcast %7 %68 +OpStore %38 %67 +%69 = OpLoad %6 %36 +%75 = OpExtInst %6 %1 SClamp %69 %72 %74 +%76 = OpUConvert %63 %75 +%70 = OpBitcast %7 %76 +OpStore %38 %70 %77 = OpLoad %10 %44 -%79 = OpCompositeExtract %7 %77 0 -%80 = OpBitFieldInsert %7 %21 %79 %21 %63 -%81 = OpCompositeExtract %7 %77 1 -%82 = OpBitFieldInsert %7 %80 %81 %63 %63 -%83 = OpCompositeExtract %7 %77 2 -%84 = OpBitFieldInsert %7 %82 %83 %70 %63 -%85 = OpCompositeExtract %7 %77 3 -%78 = OpBitFieldInsert %7 %84 %85 %74 %63 +%81 = OpExtInst %10 %1 UClamp %77 %24 %80 +%82 = OpUConvert %63 %81 +%78 = OpBitcast %7 %82 OpStore %38 %78 -%86 = OpLoad %6 %36 -%88 = OpCompositeExtract %3 %86 0 -%89 = 
OpBitcast %7 %88 -%92 = OpExtInst %7 %1 SClamp %89 %90 %91 -%93 = OpBitFieldInsert %7 %21 %92 %21 %63 -%94 = OpCompositeExtract %3 %86 1 -%95 = OpBitcast %7 %94 -%96 = OpExtInst %7 %1 SClamp %95 %90 %91 -%97 = OpBitFieldInsert %7 %93 %96 %63 %63 -%98 = OpCompositeExtract %3 %86 2 -%99 = OpBitcast %7 %98 -%100 = OpExtInst %7 %1 SClamp %99 %90 %91 -%101 = OpBitFieldInsert %7 %97 %100 %70 %63 -%102 = OpCompositeExtract %3 %86 3 -%103 = OpBitcast %7 %102 -%104 = OpExtInst %7 %1 SClamp %103 %90 %91 -%87 = OpBitFieldInsert %7 %101 %104 %74 %63 -OpStore %38 %87 -%105 = OpLoad %10 %44 -%107 = OpCompositeExtract %7 %105 0 -%109 = OpExtInst %7 %1 UClamp %107 %21 %108 -%110 = OpBitFieldInsert %7 %21 %109 %21 %63 -%111 = OpCompositeExtract %7 %105 1 -%112 = OpExtInst %7 %1 UClamp %111 %21 %108 -%113 = OpBitFieldInsert %7 %110 %112 %63 %63 -%114 = OpCompositeExtract %7 %105 2 -%115 = OpExtInst %7 %1 UClamp %114 %21 %108 -%116 = OpBitFieldInsert %7 %113 %115 %70 %63 -%117 = OpCompositeExtract %7 %105 3 -%118 = OpExtInst %7 %1 UClamp %117 %21 %108 -%106 = OpBitFieldInsert %7 %116 %118 %74 %63 -OpStore %38 %106 -%119 = OpLoad %7 %38 -%120 = OpExtInst %13 %1 UnpackSnorm4x8 %119 -OpStore %48 %120 -%121 = OpLoad %7 %38 -%122 = OpExtInst %13 %1 UnpackUnorm4x8 %121 -OpStore %48 %122 -%123 = OpLoad %7 %38 -%124 = OpExtInst %11 %1 UnpackSnorm2x16 %123 -OpStore %46 %124 -%125 = OpLoad %7 %38 -%126 = OpExtInst %11 %1 UnpackUnorm2x16 %125 -OpStore %46 %126 +%83 = OpLoad %7 %38 +%84 = OpExtInst %13 %1 UnpackSnorm4x8 %83 +OpStore %48 %84 +%85 = OpLoad %7 %38 +%86 = OpExtInst %13 %1 UnpackUnorm4x8 %85 +OpStore %48 %86 +%87 = OpLoad %7 %38 +%88 = OpExtInst %11 %1 UnpackSnorm2x16 %87 +OpStore %46 %88 +%89 = OpLoad %7 %38 +%90 = OpExtInst %11 %1 UnpackUnorm2x16 %89 +OpStore %46 %90 +%91 = OpLoad %7 %38 +%92 = OpExtInst %11 %1 UnpackHalf2x16 %91 +OpStore %46 %92 +%93 = OpLoad %7 %38 +%97 = OpBitcast %95 %93 +%94 = OpSConvert %6 %97 +OpStore %36 %94 +%98 = OpLoad %7 %38 +%100 = OpBitcast %63 %98 
+%99 = OpUConvert %10 %100 +OpStore %44 %99 +%101 = OpLoad %3 %30 +%102 = OpLoad %3 %30 +%105 = OpExtInst %7 %1 UMin %28 %104 +%106 = OpISub %7 %104 %105 +%107 = OpExtInst %7 %1 UMin %29 %106 +%103 = OpBitFieldInsert %3 %101 %102 %105 %107 +OpStore %30 %103 +%108 = OpLoad %4 %32 +%109 = OpLoad %4 %32 +%111 = OpExtInst %7 %1 UMin %28 %104 +%112 = OpISub %7 %104 %111 +%113 = OpExtInst %7 %1 UMin %29 %112 +%110 = OpBitFieldInsert %4 %108 %109 %111 %113 +OpStore %32 %110 +%114 = OpLoad %5 %34 +%115 = OpLoad %5 %34 +%117 = OpExtInst %7 %1 UMin %28 %104 +%118 = OpISub %7 %104 %117 +%119 = OpExtInst %7 %1 UMin %29 %118 +%116 = OpBitFieldInsert %5 %114 %115 %117 %119 +OpStore %34 %116 +%120 = OpLoad %6 %36 +%121 = OpLoad %6 %36 +%123 = OpExtInst %7 %1 UMin %28 %104 +%124 = OpISub %7 %104 %123 +%125 = OpExtInst %7 %1 UMin %29 %124 +%122 = OpBitFieldInsert %6 %120 %121 %123 %125 +OpStore %36 %122 +%126 = OpLoad %7 %38 %127 = OpLoad %7 %38 -%128 = OpExtInst %11 %1 UnpackHalf2x16 %127 -OpStore %46 %128 -%129 = OpLoad %7 %38 -%131 = OpBitcast %3 %129 -%132 = OpBitFieldSExtract %3 %131 %21 %63 -%133 = OpBitFieldSExtract %3 %131 %63 %63 -%134 = OpBitFieldSExtract %3 %131 %70 %63 -%135 = OpBitFieldSExtract %3 %131 %74 %63 -%130 = OpCompositeConstruct %6 %132 %133 %134 %135 -OpStore %36 %130 -%136 = OpLoad %7 %38 -%138 = OpBitFieldUExtract %7 %136 %21 %63 -%139 = OpBitFieldUExtract %7 %136 %63 %63 -%140 = OpBitFieldUExtract %7 %136 %70 %63 -%141 = OpBitFieldUExtract %7 %136 %74 %63 -%137 = OpCompositeConstruct %10 %138 %139 %140 %141 -OpStore %44 %137 -%142 = OpLoad %3 %30 -%143 = OpLoad %3 %30 -%146 = OpExtInst %7 %1 UMin %28 %145 -%147 = OpISub %7 %145 %146 -%148 = OpExtInst %7 %1 UMin %29 %147 -%144 = OpBitFieldInsert %3 %142 %143 %146 %148 -OpStore %30 %144 -%149 = OpLoad %4 %32 -%150 = OpLoad %4 %32 -%152 = OpExtInst %7 %1 UMin %28 %145 -%153 = OpISub %7 %145 %152 +%129 = OpExtInst %7 %1 UMin %28 %104 +%130 = OpISub %7 %104 %129 +%131 = OpExtInst %7 %1 UMin %29 %130 +%128 = 
OpBitFieldInsert %7 %126 %127 %129 %131 +OpStore %38 %128 +%132 = OpLoad %8 %40 +%133 = OpLoad %8 %40 +%135 = OpExtInst %7 %1 UMin %28 %104 +%136 = OpISub %7 %104 %135 +%137 = OpExtInst %7 %1 UMin %29 %136 +%134 = OpBitFieldInsert %8 %132 %133 %135 %137 +OpStore %40 %134 +%138 = OpLoad %9 %42 +%139 = OpLoad %9 %42 +%141 = OpExtInst %7 %1 UMin %28 %104 +%142 = OpISub %7 %104 %141 +%143 = OpExtInst %7 %1 UMin %29 %142 +%140 = OpBitFieldInsert %9 %138 %139 %141 %143 +OpStore %42 %140 +%144 = OpLoad %10 %44 +%145 = OpLoad %10 %44 +%147 = OpExtInst %7 %1 UMin %28 %104 +%148 = OpISub %7 %104 %147 +%149 = OpExtInst %7 %1 UMin %29 %148 +%146 = OpBitFieldInsert %10 %144 %145 %147 %149 +OpStore %44 %146 +%150 = OpLoad %3 %30 +%152 = OpExtInst %7 %1 UMin %28 %104 +%153 = OpISub %7 %104 %152 %154 = OpExtInst %7 %1 UMin %29 %153 -%151 = OpBitFieldInsert %4 %149 %150 %152 %154 -OpStore %32 %151 -%155 = OpLoad %5 %34 -%156 = OpLoad %5 %34 -%158 = OpExtInst %7 %1 UMin %28 %145 -%159 = OpISub %7 %145 %158 -%160 = OpExtInst %7 %1 UMin %29 %159 -%157 = OpBitFieldInsert %5 %155 %156 %158 %160 -OpStore %34 %157 -%161 = OpLoad %6 %36 -%162 = OpLoad %6 %36 -%164 = OpExtInst %7 %1 UMin %28 %145 -%165 = OpISub %7 %145 %164 -%166 = OpExtInst %7 %1 UMin %29 %165 -%163 = OpBitFieldInsert %6 %161 %162 %164 %166 -OpStore %36 %163 -%167 = OpLoad %7 %38 -%168 = OpLoad %7 %38 -%170 = OpExtInst %7 %1 UMin %28 %145 -%171 = OpISub %7 %145 %170 -%172 = OpExtInst %7 %1 UMin %29 %171 -%169 = OpBitFieldInsert %7 %167 %168 %170 %172 -OpStore %38 %169 -%173 = OpLoad %8 %40 -%174 = OpLoad %8 %40 -%176 = OpExtInst %7 %1 UMin %28 %145 -%177 = OpISub %7 %145 %176 -%178 = OpExtInst %7 %1 UMin %29 %177 -%175 = OpBitFieldInsert %8 %173 %174 %176 %178 -OpStore %40 %175 -%179 = OpLoad %9 %42 +%151 = OpBitFieldSExtract %3 %150 %152 %154 +OpStore %30 %151 +%155 = OpLoad %4 %32 +%157 = OpExtInst %7 %1 UMin %28 %104 +%158 = OpISub %7 %104 %157 +%159 = OpExtInst %7 %1 UMin %29 %158 +%156 = OpBitFieldSExtract %4 %155 
%157 %159 +OpStore %32 %156 +%160 = OpLoad %5 %34 +%162 = OpExtInst %7 %1 UMin %28 %104 +%163 = OpISub %7 %104 %162 +%164 = OpExtInst %7 %1 UMin %29 %163 +%161 = OpBitFieldSExtract %5 %160 %162 %164 +OpStore %34 %161 +%165 = OpLoad %6 %36 +%167 = OpExtInst %7 %1 UMin %28 %104 +%168 = OpISub %7 %104 %167 +%169 = OpExtInst %7 %1 UMin %29 %168 +%166 = OpBitFieldSExtract %6 %165 %167 %169 +OpStore %36 %166 +%170 = OpLoad %7 %38 +%172 = OpExtInst %7 %1 UMin %28 %104 +%173 = OpISub %7 %104 %172 +%174 = OpExtInst %7 %1 UMin %29 %173 +%171 = OpBitFieldUExtract %7 %170 %172 %174 +OpStore %38 %171 +%175 = OpLoad %8 %40 +%177 = OpExtInst %7 %1 UMin %28 %104 +%178 = OpISub %7 %104 %177 +%179 = OpExtInst %7 %1 UMin %29 %178 +%176 = OpBitFieldUExtract %8 %175 %177 %179 +OpStore %40 %176 %180 = OpLoad %9 %42 -%182 = OpExtInst %7 %1 UMin %28 %145 -%183 = OpISub %7 %145 %182 +%182 = OpExtInst %7 %1 UMin %28 %104 +%183 = OpISub %7 %104 %182 %184 = OpExtInst %7 %1 UMin %29 %183 -%181 = OpBitFieldInsert %9 %179 %180 %182 %184 +%181 = OpBitFieldUExtract %9 %180 %182 %184 OpStore %42 %181 %185 = OpLoad %10 %44 -%186 = OpLoad %10 %44 -%188 = OpExtInst %7 %1 UMin %28 %145 -%189 = OpISub %7 %145 %188 -%190 = OpExtInst %7 %1 UMin %29 %189 -%187 = OpBitFieldInsert %10 %185 %186 %188 %190 -OpStore %44 %187 -%191 = OpLoad %3 %30 -%193 = OpExtInst %7 %1 UMin %28 %145 -%194 = OpISub %7 %145 %193 -%195 = OpExtInst %7 %1 UMin %29 %194 -%192 = OpBitFieldSExtract %3 %191 %193 %195 -OpStore %30 %192 -%196 = OpLoad %4 %32 -%198 = OpExtInst %7 %1 UMin %28 %145 -%199 = OpISub %7 %145 %198 -%200 = OpExtInst %7 %1 UMin %29 %199 -%197 = OpBitFieldSExtract %4 %196 %198 %200 -OpStore %32 %197 -%201 = OpLoad %5 %34 -%203 = OpExtInst %7 %1 UMin %28 %145 -%204 = OpISub %7 %145 %203 -%205 = OpExtInst %7 %1 UMin %29 %204 -%202 = OpBitFieldSExtract %5 %201 %203 %205 -OpStore %34 %202 -%206 = OpLoad %6 %36 -%208 = OpExtInst %7 %1 UMin %28 %145 -%209 = OpISub %7 %145 %208 -%210 = OpExtInst %7 %1 UMin %29 %209 -%207 
= OpBitFieldSExtract %6 %206 %208 %210 -OpStore %36 %207 -%211 = OpLoad %7 %38 -%213 = OpExtInst %7 %1 UMin %28 %145 -%214 = OpISub %7 %145 %213 -%215 = OpExtInst %7 %1 UMin %29 %214 -%212 = OpBitFieldUExtract %7 %211 %213 %215 -OpStore %38 %212 -%216 = OpLoad %8 %40 -%218 = OpExtInst %7 %1 UMin %28 %145 -%219 = OpISub %7 %145 %218 -%220 = OpExtInst %7 %1 UMin %29 %219 -%217 = OpBitFieldUExtract %8 %216 %218 %220 -OpStore %40 %217 -%221 = OpLoad %9 %42 -%223 = OpExtInst %7 %1 UMin %28 %145 -%224 = OpISub %7 %145 %223 -%225 = OpExtInst %7 %1 UMin %29 %224 -%222 = OpBitFieldUExtract %9 %221 %223 %225 -OpStore %42 %222 -%226 = OpLoad %10 %44 -%228 = OpExtInst %7 %1 UMin %28 %145 -%229 = OpISub %7 %145 %228 -%230 = OpExtInst %7 %1 UMin %29 %229 -%227 = OpBitFieldUExtract %10 %226 %228 %230 -OpStore %44 %227 -%231 = OpLoad %3 %30 -%232 = OpExtInst %3 %1 FindILsb %231 -OpStore %30 %232 -%233 = OpLoad %8 %40 -%234 = OpExtInst %8 %1 FindILsb %233 -OpStore %40 %234 -%235 = OpLoad %5 %34 -%236 = OpExtInst %5 %1 FindSMsb %235 -OpStore %34 %236 -%237 = OpLoad %9 %42 -%238 = OpExtInst %9 %1 FindUMsb %237 -OpStore %42 %238 -%239 = OpLoad %3 %30 -%240 = OpExtInst %3 %1 FindSMsb %239 -OpStore %30 %240 -%241 = OpLoad %7 %38 -%242 = OpExtInst %7 %1 FindUMsb %241 -OpStore %38 %242 -%243 = OpLoad %3 %30 -%244 = OpBitCount %3 %243 -OpStore %30 %244 -%245 = OpLoad %4 %32 -%246 = OpBitCount %4 %245 -OpStore %32 %246 -%247 = OpLoad %5 %34 -%248 = OpBitCount %5 %247 -OpStore %34 %248 -%249 = OpLoad %6 %36 -%250 = OpBitCount %6 %249 -OpStore %36 %250 -%251 = OpLoad %7 %38 -%252 = OpBitCount %7 %251 -OpStore %38 %252 -%253 = OpLoad %8 %40 -%254 = OpBitCount %8 %253 -OpStore %40 %254 -%255 = OpLoad %9 %42 -%256 = OpBitCount %9 %255 -OpStore %42 %256 -%257 = OpLoad %10 %44 -%258 = OpBitCount %10 %257 -OpStore %44 %258 -%259 = OpLoad %3 %30 -%260 = OpBitReverse %3 %259 -OpStore %30 %260 -%261 = OpLoad %4 %32 -%262 = OpBitReverse %4 %261 -OpStore %32 %262 -%263 = OpLoad %5 %34 -%264 = 
OpBitReverse %5 %263 -OpStore %34 %264 -%265 = OpLoad %6 %36 -%266 = OpBitReverse %6 %265 -OpStore %36 %266 -%267 = OpLoad %7 %38 -%268 = OpBitReverse %7 %267 -OpStore %38 %268 -%269 = OpLoad %8 %40 -%270 = OpBitReverse %8 %269 -OpStore %40 %270 -%271 = OpLoad %9 %42 -%272 = OpBitReverse %9 %271 -OpStore %42 %272 -%273 = OpLoad %10 %44 -%274 = OpBitReverse %10 %273 -OpStore %44 %274 +%187 = OpExtInst %7 %1 UMin %28 %104 +%188 = OpISub %7 %104 %187 +%189 = OpExtInst %7 %1 UMin %29 %188 +%186 = OpBitFieldUExtract %10 %185 %187 %189 +OpStore %44 %186 +%190 = OpLoad %3 %30 +%191 = OpExtInst %3 %1 FindILsb %190 +OpStore %30 %191 +%192 = OpLoad %8 %40 +%193 = OpExtInst %8 %1 FindILsb %192 +OpStore %40 %193 +%194 = OpLoad %5 %34 +%195 = OpExtInst %5 %1 FindSMsb %194 +OpStore %34 %195 +%196 = OpLoad %9 %42 +%197 = OpExtInst %9 %1 FindUMsb %196 +OpStore %42 %197 +%198 = OpLoad %3 %30 +%199 = OpExtInst %3 %1 FindSMsb %198 +OpStore %30 %199 +%200 = OpLoad %7 %38 +%201 = OpExtInst %7 %1 FindUMsb %200 +OpStore %38 %201 +%202 = OpLoad %3 %30 +%203 = OpBitCount %3 %202 +OpStore %30 %203 +%204 = OpLoad %4 %32 +%205 = OpBitCount %4 %204 +OpStore %32 %205 +%206 = OpLoad %5 %34 +%207 = OpBitCount %5 %206 +OpStore %34 %207 +%208 = OpLoad %6 %36 +%209 = OpBitCount %6 %208 +OpStore %36 %209 +%210 = OpLoad %7 %38 +%211 = OpBitCount %7 %210 +OpStore %38 %211 +%212 = OpLoad %8 %40 +%213 = OpBitCount %8 %212 +OpStore %40 %213 +%214 = OpLoad %9 %42 +%215 = OpBitCount %9 %214 +OpStore %42 %215 +%216 = OpLoad %10 %44 +%217 = OpBitCount %10 %216 +OpStore %44 %217 +%218 = OpLoad %3 %30 +%219 = OpBitReverse %3 %218 +OpStore %30 %219 +%220 = OpLoad %4 %32 +%221 = OpBitReverse %4 %220 +OpStore %32 %221 +%222 = OpLoad %5 %34 +%223 = OpBitReverse %5 %222 +OpStore %34 %223 +%224 = OpLoad %6 %36 +%225 = OpBitReverse %6 %224 +OpStore %36 %225 +%226 = OpLoad %7 %38 +%227 = OpBitReverse %7 %226 +OpStore %38 %227 +%228 = OpLoad %8 %40 +%229 = OpBitReverse %8 %228 +OpStore %40 %229 +%230 = OpLoad %9 %42 
+%231 = OpBitReverse %9 %230 +OpStore %42 %231 +%232 = OpLoad %10 %44 +%233 = OpBitReverse %10 %232 +OpStore %44 %233 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index e9ae1e597a2..7aaabde01d1 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -62,13 +62,11 @@ pub struct PhysicalDeviceFeatures { /// Features provided by `VK_EXT_texture_compression_astc_hdr`, promoted to Vulkan 1.3. astc_hdr: Option>, - /// Features provided by `VK_KHR_shader_float16_int8` (promoted to Vulkan - /// 1.2) and `VK_KHR_16bit_storage` (promoted to Vulkan 1.1). We use these - /// features together, or not at all. - shader_float16: Option<( - vk::PhysicalDeviceShaderFloat16Int8Features<'static>, - vk::PhysicalDevice16BitStorageFeatures<'static>, - )>, + /// Features provided by `VK_KHR_shader_float16_int8`, promoted to Vulkan 1.2 + shader_float16_int8: Option>, + + /// Features provided by `VK_KHR_16bit_storage`, promoted to Vulkan 1.1 + _16bit_storage: Option>, /// Features provided by `VK_KHR_acceleration_structure`. 
acceleration_structure: Option>, @@ -154,9 +152,11 @@ impl PhysicalDeviceFeatures { if let Some(ref mut feature) = self.astc_hdr { info = info.push_next(feature); } - if let Some((ref mut f16_i8_feature, ref mut _16bit_feature)) = self.shader_float16 { - info = info.push_next(f16_i8_feature); - info = info.push_next(_16bit_feature); + if let Some(ref mut feature) = self.shader_float16_int8 { + info = info.push_next(feature); + } + if let Some(ref mut feature) = self._16bit_storage { + info = info.push_next(feature); } if let Some(ref mut feature) = self.zero_initialize_workgroup_memory { info = info.push_next(feature); @@ -386,14 +386,21 @@ impl PhysicalDeviceFeatures { } else { None }, - shader_float16: if requested_features.contains(wgt::Features::SHADER_F16) { - Some(( - vk::PhysicalDeviceShaderFloat16Int8Features::default().shader_float16(true), + shader_float16_int8: match requested_features.contains(wgt::Features::SHADER_F16) { + shader_float16 if shader_float16 || private_caps.shader_int8 => Some( + vk::PhysicalDeviceShaderFloat16Int8Features::default() + .shader_float16(shader_float16) + .shader_int8(private_caps.shader_int8), + ), + _ => None, + }, + _16bit_storage: if requested_features.contains(wgt::Features::SHADER_F16) { + Some( vk::PhysicalDevice16BitStorageFeatures::default() .storage_buffer16_bit_access(true) .storage_input_output16(true) .uniform_and_storage_buffer16_bit_access(true), - )) + ) } else { None }, @@ -724,7 +731,8 @@ impl PhysicalDeviceFeatures { ); } - if let Some((ref f16_i8, ref bit16)) = self.shader_float16 { + if let (Some(ref f16_i8), Some(ref bit16)) = (self.shader_float16_int8, self._16bit_storage) + { features.set( F::SHADER_F16, f16_i8.shader_float16 != 0 @@ -976,6 +984,15 @@ impl PhysicalDeviceProperties { if requested_features.contains(wgt::Features::TEXTURE_FORMAT_NV12) { extensions.push(khr::sampler_ycbcr_conversion::NAME); } + + // Require `VK_KHR_16bit_storage` if the feature `SHADER_F16` was requested + if 
requested_features.contains(wgt::Features::SHADER_F16) { + // - Feature `SHADER_F16` also requires `VK_KHR_shader_float16_int8`, but we always + // require that anyway (if it is available) below. + // - `VK_KHR_16bit_storage` requires `VK_KHR_storage_buffer_storage_class`, however + // we require that one already. + extensions.push(khr::_16bit_storage::NAME); + } } if self.device_api_version < vk::API_VERSION_1_2 { @@ -999,13 +1016,13 @@ impl PhysicalDeviceProperties { extensions.push(ext::descriptor_indexing::NAME); } - // Require `VK_KHR_shader_float16_int8` and `VK_KHR_16bit_storage` if the associated feature was requested - if requested_features.contains(wgt::Features::SHADER_F16) { + // Always require `VK_KHR_shader_float16_int8` if available as it enables + // Int8 optimizations. Also require it even if it's not available but + // requested so that we get a corresponding error message. + if requested_features.contains(wgt::Features::SHADER_F16) + || self.supports_extension(khr::shader_float16_int8::NAME) + { extensions.push(khr::shader_float16_int8::NAME); - // `VK_KHR_16bit_storage` requires `VK_KHR_storage_buffer_storage_class`, however we require that one already - if self.device_api_version < vk::API_VERSION_1_1 { - extensions.push(khr::_16bit_storage::NAME); - } } if requested_features.intersects(wgt::Features::EXPERIMENTAL_MESH_SHADER) { @@ -1474,15 +1491,22 @@ impl super::InstanceShared { .insert(vk::PhysicalDeviceTextureCompressionASTCHDRFeaturesEXT::default()); features2 = features2.push_next(next); } - if capabilities.supports_extension(khr::shader_float16_int8::NAME) - && capabilities.supports_extension(khr::_16bit_storage::NAME) + + // `VK_KHR_shader_float16_int8` is promoted to 1.2 + if capabilities.device_api_version >= vk::API_VERSION_1_2 + || capabilities.supports_extension(khr::shader_float16_int8::NAME) { - let next = features.shader_float16.insert(( - vk::PhysicalDeviceShaderFloat16Int8FeaturesKHR::default(), - 
vk::PhysicalDevice16BitStorageFeaturesKHR::default(), - )); - features2 = features2.push_next(&mut next.0); - features2 = features2.push_next(&mut next.1); + let next = features + .shader_float16_int8 + .insert(vk::PhysicalDeviceShaderFloat16Int8FeaturesKHR::default()); + features2 = features2.push_next(next); + } + + if capabilities.supports_extension(khr::_16bit_storage::NAME) { + let next = features + ._16bit_storage + .insert(vk::PhysicalDevice16BitStorageFeaturesKHR::default()); + features2 = features2.push_next(next); } if capabilities.supports_extension(khr::acceleration_structure::NAME) { let next = features @@ -1721,6 +1745,9 @@ impl super::Instance { shader_integer_dot_product: phd_features .shader_integer_dot_product .is_some_and(|ext| ext.shader_integer_dot_product != 0), + shader_int8: phd_features + .shader_float16_int8 + .is_some_and(|features| features.shader_int8 != 0), }; let capabilities = crate::Capabilities { limits: phd_capabilities.to_wgpu_limits(), @@ -2022,6 +2049,10 @@ impl super::Adapter { spv::Capability::DotProductKHR, ]); } + if self.private_caps.shader_int8 { + // See . + capabilities.extend(&[spv::Capability::Int8]); + } spv::Options { lang_version: match self.phd_capabilities.device_api_version { // Use maximum supported SPIR-V version according to diff --git a/wgpu-hal/src/vulkan/mod.rs b/wgpu-hal/src/vulkan/mod.rs index b492f339878..47c91e1d1b6 100644 --- a/wgpu-hal/src/vulkan/mod.rs +++ b/wgpu-hal/src/vulkan/mod.rs @@ -536,6 +536,21 @@ struct PrivateCapabilities { /// /// [`VK_KHR_shader_integer_dot_product`]: https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_shader_integer_dot_product.html shader_integer_dot_product: bool, + + /// True if this adapter supports 8-bit integers provided by the + /// [`VK_KHR_shader_float16_int8`] extension (promoted to Vulkan 1.2). + /// + /// Allows shaders to declare the "Int8" capability. 
Note, however, that this + /// feature alone allows the use of 8-bit integers "only in the `Private`, + /// `Workgroup` (for non-Block variables), and `Function` storage classes" + /// ([see spec]). To use 8-bit integers in the interface storage classes (e.g., + /// `StorageBuffer`), you also need to enable the corresponding feature in + /// `VkPhysicalDevice8BitStorageFeatures` and declare the corresponding SPIR-V + /// capability (e.g., `StorageBuffer8BitAccess`). + /// + /// [`VK_KHR_shader_float16_int8`]: https://registry.khronos.org/vulkan/specs/latest/man/html/VK_KHR_shader_float16_int8.html + /// [see spec]: https://registry.khronos.org/vulkan/specs/latest/man/html/VkPhysicalDeviceShaderFloat16Int8Features.html#extension-features-shaderInt8 + shader_int8: bool, } bitflags::bitflags!(