From 54a3ae0d41da339932078092882172b125f7e971 Mon Sep 17 00:00:00 2001 From: valaphee <32491319+valaphee@users.noreply.github.com> Date: Thu, 9 May 2024 12:19:32 +0200 Subject: [PATCH 01/10] Rudimentary impl of quad ops, impl quad ops for spirv --- naga/src/back/dot/mod.rs | 17 ++++- naga/src/back/glsl/mod.rs | 28 ++++++++- naga/src/back/hlsl/writer.rs | 2 + naga/src/back/msl/writer.rs | 9 ++- naga/src/back/pipeline_constants.rs | 11 +++- naga/src/back/spv/block.rs | 34 ++++++++++ naga/src/back/spv/instructions.rs | 32 ++++++++++ naga/src/back/spv/subgroup.rs | 14 +++-- naga/src/back/wgsl/writer.rs | 7 ++- naga/src/compact/statements.rs | 16 ++++- naga/src/front/spv/mod.rs | 53 +++++++++++++++- naga/src/front/wgsl/lower/mod.rs | 70 ++++++++++++++++++++- naga/src/ir/mod.rs | 22 +++++++ naga/src/proc/terminator.rs | 1 + naga/src/valid/analyzer.rs | 11 +++- naga/src/valid/function.rs | 4 +- naga/src/valid/handles.rs | 12 +++- naga/src/valid/mod.rs | 5 +- naga/tests/in/wgsl/subgroup-operations.wgsl | 5 ++ 19 files changed, 331 insertions(+), 22 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index dd8246f90d1..b3ca50e6b53 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -379,7 +379,8 @@ impl StatementGraph { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { self.dependencies.push((id, index, "index")) } } @@ -392,6 +393,20 @@ impl StatementGraph { crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown", crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp", crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor", + crate::GatherMode::QuadBroadcast(_) => "SubgroupQuadBroadcast" + } + } + S::SubgroupQuadSwap { + direction, + argument, + result + } => { + self.dependencies.push((id, argument, "arg")); + self.emits.push((id, result)); + match direction { + crate::Direction::X => "SubgroupQuadSwapX", + crate::Direction::Y => "SubgroupQuadSwapY", + crate::Direction::Diagonal => "SubgroupQuadSwapDiagonal", } } }; diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 52a47487ea2..f6701f61d12 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2716,6 +2716,9 @@ impl<'a, W: Write> Writer<'a, W> { crate::GatherMode::ShuffleXor(_) => { write!(self.out, "subgroupShuffleXor(")?; } + crate::GatherMode::QuadBroadcast(_) => { + write!(self.out, "subgroupQuadBroadcast(")?; + } } self.write_expr(argument, ctx)?; match mode { @@ -2724,13 +2727,36 @@ impl<'a, W: Write> Writer<'a, W> { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { write!(self.out, ", ")?; self.write_expr(index, ctx)?; } } writeln!(self.out, ");")?; } + Statement::SubgroupQuadSwap { direction, argument, result } => { + write!(self.out, "{level}")?; + let res_name = Baked(result).to_string(); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match direction { + crate::Direction::X => { + write!(self.out, "subgroupQuadSwapHorizontal(")?; + } + crate::Direction::Y => { + write!(self.out, "subgroupQuadSwapVertical(")?; + } + crate::Direction::Diagonal => { + write!(self.out, "subgroupQuadSwapDiagonal(")?; + } + } + self.write_expr(argument, ctx)?; + writeln!(self.out, ");")?; + } } Ok(()) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index bb90db78593..2d634b010b8 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2635,10 +2635,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, "WaveGetLaneIndex() ^ ")?; self.write_expr(module, index, func_ctx)?; } + crate::GatherMode::QuadBroadcast(_) => unreachable!() } } writeln!(self.out, ");")?; } + Statement::SubgroupQuadSwap { direction, argument, result } => {} } Ok(()) diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index f05e5c233aa..5d4b6c2c962 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3931,6 +3931,9 @@ impl Writer { crate::GatherMode::ShuffleXor(_) => { write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?; } + crate::GatherMode::QuadBroadcast(_) => { + write!(self.out, "{NAMESPACE}::quad_broadcast(")?; + } } self.put_expression(argument, &context.expression, true)?; match mode { @@ -3939,13 +3942,15 @@ impl Writer { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { write!(self.out, ", ")?; self.put_expression(index, &context.expression, true)?; } } writeln!(self.out, ");")?; - } + }, + crate::Statement::SubgroupQuadSwap { direction, argument, result } => {} } } diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 1cf1c805249..86390a4fed9 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -759,13 +759,22 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S | crate::GatherMode::Shuffle(ref mut index) | crate::GatherMode::ShuffleDown(ref mut index) | crate::GatherMode::ShuffleUp(ref mut index) - | crate::GatherMode::ShuffleXor(ref mut index) => { + | crate::GatherMode::ShuffleXor(ref mut index) + | crate::GatherMode::QuadBroadcast(ref mut index) => { adjust(index); } } adjust(argument); adjust(result) } + Statement::SubgroupQuadSwap { + ref mut argument, + ref mut result, + .. + } => { + adjust(argument); + adjust(result); + } Statement::Call { ref mut arguments, ref mut result, diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 92edbcb05c4..537e40f4f67 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3467,6 +3467,40 @@ impl BlockContext<'_> { } => { self.write_subgroup_gather(mode, argument, result, &mut block)?; } + Statement::SubgroupQuadSwap { + ref direction, + argument, + result + } => { + self.writer.require_any( + "GroupNonUniformQuad", + &[spirv::Capability::GroupNonUniformQuad], + )?; + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + let arg_id = self.cached[argument]; + + let direction = self.get_index_constant(match direction { + crate::Direction::X => 0, + crate::Direction::Y => 1, + crate::Direction::Diagonal => 2, + }); + + block + .body + .push(Instruction::group_non_uniform_quad_swap( + result_type_id, + id, + exec_scope_id, + arg_id, + direction + )); + } } } diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index a34790614e7..f6f25693cee 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1203,6 +1203,38 @@ impl super::Instruction { } instruction.add_operand(value); + instruction + } + pub(super) fn group_non_uniform_quad_broadcast( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + index: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformQuadBroadcast); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + instruction.add_operand(index); + + instruction + } + pub(super) fn group_non_uniform_quad_swap( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + direction: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformQuadSwap); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + instruction.add_operand(direction); + instruction } } diff --git a/naga/src/back/spv/subgroup.rs b/naga/src/back/spv/subgroup.rs index 6f4dc9aba6f..dba7072532e 100644 --- a/naga/src/back/spv/subgroup.rs +++ b/naga/src/back/spv/subgroup.rs @@ -125,10 +125,6 @@ impl BlockContext<'_> { result: Handle, block: &mut Block, ) -> Result<(), Error> { - self.writer.require_any( - "GroupNonUniformBallot", - &[spirv::Capability::GroupNonUniformBallot], - )?; match *mode { crate::GatherMode::BroadcastFirst => { self.writer.require_any( @@ -150,6 +146,12 @@ impl BlockContext<'_> { &[spirv::Capability::GroupNonUniformShuffleRelative], )?; } + crate::GatherMode::QuadBroadcast(_) => { + self.writer.require_any( + "GroupNonUniformQuad", + &[spirv::Capability::GroupNonUniformQuad], + )?; + } } let id = self.gen_id(); @@ -174,7 +176,8 @@ impl BlockContext<'_> { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { let index_id = self.cached[index]; let op = match *mode { crate::GatherMode::BroadcastFirst => unreachable!(), @@ -187,6 +190,7 @@ impl BlockContext<'_> { crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown, crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp, crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor, + crate::GatherMode::QuadBroadcast(_) => spirv::Op::GroupNonUniformQuadBroadcast, }; block.body.push(Instruction::group_non_uniform_gather( op, diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index a63217098ed..2401d2807f6 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -945,6 +945,9 @@ impl Writer { crate::GatherMode::ShuffleXor(_) => { write!(self.out, "subgroupShuffleXor(")?; } + crate::GatherMode::QuadBroadcast(_) => { + write!(self.out, "quadBroadcast(")?; + } } self.write_expr(module, argument, func_ctx)?; match mode { @@ -953,13 +956,15 @@ impl Writer { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { write!(self.out, ", ")?; self.write_expr(module, index, func_ctx)?; } } writeln!(self.out, ");")?; } + Statement::SubgroupQuadSwap { direction, argument, result } => {} } Ok(()) diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 32a8130fd32..fe201039e96 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -141,13 +141,18 @@ impl FunctionTracer<'_> { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { self.expressions_used.insert(index); } } self.expressions_used.insert(argument); self.expressions_used.insert(result); - } + }, + St::SubgroupQuadSwap { direction, argument, result } => { + self.expressions_used.insert(argument); + self.expressions_used.insert(result); + }, // Trivial statements. St::Break @@ -350,11 +355,16 @@ impl FunctionMap { | crate::GatherMode::Shuffle(ref mut index) | crate::GatherMode::ShuffleDown(ref mut index) | crate::GatherMode::ShuffleUp(ref mut index) - | crate::GatherMode::ShuffleXor(ref mut index) => adjust(index), + | crate::GatherMode::ShuffleXor(ref mut index) + | crate::GatherMode::QuadBroadcast(ref mut index) => adjust(index), } adjust(argument); adjust(result); } + St::SubgroupQuadSwap { direction: _, ref mut argument, ref mut result } => { + adjust(argument); + adjust(result); + } // Trivial statements. St::Break diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 3b1b8fe5cac..f1a8acc98ad 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -4064,7 +4064,8 @@ impl> Frontend { | Op::GroupNonUniformShuffle | Op::GroupNonUniformShuffleDown | Op::GroupNonUniformShuffleUp - | Op::GroupNonUniformShuffleXor => { + | Op::GroupNonUniformShuffleXor + | Op::GroupNonUniformQuadBroadcast => { inst.expect(if matches!(inst.op, Op::GroupNonUniformBroadcastFirst) { 5 } else { @@ -4104,6 +4105,9 @@ impl> Frontend { Op::GroupNonUniformShuffleXor => { crate::GatherMode::ShuffleXor(index_handle) } + Op::GroupNonUniformQuadBroadcast => { + crate::GatherMode::QuadBroadcast(index_handle) + } _ => unreachable!(), } }; @@ -4135,6 +4139,50 @@ impl> Frontend { ); emitter.start(ctx.expressions); } + Op::GroupNonUniformQuadSwap => { + inst.expect(6)?; + block.extend(emitter.finish(ctx.expressions)); + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let argument_id = self.next()?; + let direction = self.next()?; + + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + + let result_type = self.lookup_type.lookup(result_type_id)?; + + let result_handle = ctx.expressions.append( + crate::Expression::SubgroupOperationResult { + ty: result_type.handle, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + block.push( + crate::Statement::SubgroupQuadSwap { + direction: crate::Direction::X, + result: result_handle, + argument: argument_handle, + }, + span, + ); + emitter.start(ctx.expressions); + } Op::AtomicLoad => { inst.expect(6)?; let start = self.data_offset; @@ -4516,7 +4564,8 @@ impl> Frontend { | S::RayQuery { .. } | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } - | S::SubgroupGather { .. } => {} + | S::SubgroupGather { .. } + | S::SubgroupQuadSwap { .. } => {} S::Call { function: ref mut callee, ref arguments, diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 93ccb7143ca..96369d13935 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1078,6 +1078,7 @@ enum SubgroupGather { ShuffleDown, ShuffleUp, ShuffleXor, + QuadBroadcast, } impl SubgroupGather { @@ -1089,6 +1090,7 @@ impl SubgroupGather { "subgroupShuffleDown" => Self::ShuffleDown, "subgroupShuffleUp" => Self::ShuffleUp, "subgroupShuffleXor" => Self::ShuffleXor, + "quadBroadcast" => Self::QuadBroadcast, _ => return None, }) } @@ -2940,6 +2942,71 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(ir::Statement::SubgroupBallot { result, predicate }, span); return Ok(Some(result)); } + "quadSwapX" => { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupQuadSwap { + direction: crate::Direction::X, + argument, + result, + }, + span, + ); + return Ok(Some(result)) + } + + "quadSwapY" => { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupQuadSwap { + direction: crate::Direction::Y, + argument, + result, + }, + span, + ); + return Ok(Some(result)) + } + + "quadSwapDiagonal" => { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupQuadSwap { + direction: crate::Direction::Diagonal, + argument, + result, + }, + span, + ); + return Ok(Some(result)) + } _ => { return Err(Box::new(Error::UnknownIdent(function.span, function.name))) } @@ -3471,12 +3538,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } else { let index = self.expression(args.next()?, ctx)?; match mode { + Sg::BroadcastFirst => unreachable!(), Sg::Broadcast => ir::GatherMode::Broadcast(index), Sg::Shuffle => ir::GatherMode::Shuffle(index), Sg::ShuffleDown => ir::GatherMode::ShuffleDown(index), Sg::ShuffleUp => ir::GatherMode::ShuffleUp(index), Sg::ShuffleXor => ir::GatherMode::ShuffleXor(index), - Sg::BroadcastFirst => unreachable!(), + Sg::QuadBroadcast => ir::GatherMode::QuadBroadcast(index), } }; diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 5f0b19b7dc2..24185c0b507 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -1279,6 +1279,8 @@ pub enum GatherMode { ShuffleUp(Handle), /// Each gathers from their lane xored with the given by the expression ShuffleXor(Handle), + /// All gather from the same lane at the index given by the expression + QuadBroadcast(Handle), } #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] @@ -2099,6 +2101,26 @@ pub enum Statement { /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult result: Handle, }, + SubgroupQuadSwap { + /// In which direction to swap + direction: Direction, + /// The value to swap over + argument: Handle, + /// The [`SubgroupOperationResult`] expression representing this load's result. + /// + /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult + result: Handle, + } +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum Direction { + X = 0, + Y = 1, + Diagonal = 2, } /// A function argument. diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index f22e61e6a6d..06120df835c 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -42,6 +42,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } | S::SubgroupGather { .. } + | S::SubgroupQuadSwap { .. } | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index cf7fec2060e..ba8b7b80507 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1142,12 +1142,21 @@ impl FunctionInfo { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { let _ = self.add_ref(index); } } FunctionUniformity::new() } + S::SubgroupQuadSwap { + direction: _, + argument, + result: _, + } => { + let _ = self.add_ref(argument); + FunctionUniformity::new() + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 7865f1fc42e..f1513732683 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -705,7 +705,8 @@ impl super::Validator { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { let index_ty = context.resolve_type_inner(index, &self.valid_expression_set)?; match *index_ty { crate::TypeInner::Scalar(crate::Scalar::U32) => {} @@ -1616,6 +1617,7 @@ impl super::Validator { } self.validate_subgroup_gather(mode, argument, result, context)?; } + S::SubgroupQuadSwap { direction, argument, result } => {} } } Ok(BlockInfo { stages, finished }) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 86285c2818b..a84b6a3d86b 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -740,11 +740,21 @@ impl super::Validator { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => validate_expr(index)?, + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => validate_expr(index)?, } validate_expr(result)?; Ok(()) } + crate::Statement::SubgroupQuadSwap { + direction: _, + argument, + result, + } => { + validate_expr(argument)?; + validate_expr(result)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index c8a02db1afa..9870e9b3d53 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -195,8 +195,8 @@ bitflags::bitflags! { // We don't support these operations yet // /// Clustered // const CLUSTERED = 1 << 6; - // /// Quad supported - // const QUAD_FRAGMENT_COMPUTE = 1 << 7; + /// Quad supported + const QUAD_FRAGMENT_COMPUTE = 1 << 7; // /// Quad supported in all stages // const QUAD_ALL_STAGES = 1 << 8; } @@ -221,6 +221,7 @@ impl super::GatherMode { Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT, Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE, Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE, + Self::QuadBroadcast(_) => S::QUAD_FRAGMENT_COMPUTE, } } } diff --git a/naga/tests/in/wgsl/subgroup-operations.wgsl b/naga/tests/in/wgsl/subgroup-operations.wgsl index bb6eb47fb51..26b3d98e84a 100644 --- a/naga/tests/in/wgsl/subgroup-operations.wgsl +++ b/naga/tests/in/wgsl/subgroup-operations.wgsl @@ -34,4 +34,9 @@ fn main( subgroupShuffleDown(subgroup_invocation_id, 1u); subgroupShuffleUp(subgroup_invocation_id, 1u); subgroupShuffleXor(subgroup_invocation_id, sizes.subgroup_size - 1u); + + quadBroadcast(subgroup_invocation_id, 4u); + quadSwapX(subgroup_invocation_id); + quadSwapY(subgroup_invocation_id); + quadSwapDiagonal(subgroup_invocation_id); } From 8690ad473b1021c9abb4019b83cd06993c5666d7 Mon Sep 17 00:00:00 2001 From: valaphee <32491319+valaphee@users.noreply.github.com> Date: Thu, 16 May 2024 15:20:55 +0200 Subject: [PATCH 02/10] Impl quad swap for hlsl, msl and wgsl, finish spv front --- naga/src/back/hlsl/writer.rs | 28 +++++++++++++++++++++++++++- naga/src/back/msl/writer.rs | 22 +++++++++++++++++++++- naga/src/back/wgsl/writer.rs | 21 ++++++++++++++++++++- naga/src/front/spv/mod.rs | 14 ++++++++++++-- 4 files changed, 80 insertions(+), 5 deletions(-) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 2d634b010b8..88e87f067ea 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2640,7 +2640,33 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } writeln!(self.out, ");")?; } - Statement::SubgroupQuadSwap { direction, argument, result } => {} + Statement::SubgroupQuadSwap { direction, argument, result } => { + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + match direction { + crate::Direction::X => { + write!(self.out, "QuadReadAcrossX(")?; + }, + crate::Direction::Y => { + write!(self.out, "QuadReadAcrossY(")?; + }, + crate::Direction::Diagonal => { + write!(self.out, "QuadReadAcrossDiagonal(")?; + }, + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; + } } Ok(()) diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 5d4b6c2c962..cd529b7b742 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3950,7 +3950,27 @@ impl Writer { } writeln!(self.out, ");")?; }, - crate::Statement::SubgroupQuadSwap { direction, argument, result } => {} + crate::Statement::SubgroupQuadSwap { direction, argument, result } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?; + self.put_expression(argument, &context.expression, true)?; + write!(self.out, ", ")?; + match direction { + crate::Direction::X => { + write!(self.out, "0x01")?; + }, + crate::Direction::Y => { + write!(self.out, "0x10")?; + }, + crate::Direction::Diagonal => { + write!(self.out, "0x11")?; + }, + } + writeln!(self.out, ");")?; + } } } diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 2401d2807f6..e3e32032618 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -964,7 +964,26 @@ impl Writer { } writeln!(self.out, ");")?; } - Statement::SubgroupQuadSwap { direction, argument, result } => {} + Statement::SubgroupQuadSwap { direction, argument, result } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match direction { + crate::Direction::X => { + write!(self.out, "quadSwapX(")?; + }, + crate::Direction::Y => { + write!(self.out, "quadSwapY(")?; + }, + crate::Direction::Diagonal => { + write!(self.out, "quadSwapDiagonal(")?; + }, + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; + } } Ok(()) diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index f1a8acc98ad..f109332a060 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -4146,7 +4146,7 @@ impl> Frontend { let result_id = self.next()?; let exec_scope_id = self.next()?; let argument_id = self.next()?; - let direction = self.next()?; + let direction_id = self.next()?; let argument_lookup = self.lookup_expression.lookup(argument_id)?; let argument_handle = get_expr_handle!(argument_id, argument_lookup); @@ -4156,6 +4156,16 @@ impl> Frontend { .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + let direction_const = self.lookup_constant.lookup(direction_id)?; + let direction_const = resolve_constant(ctx.gctx(), &direction_const.inner) + .ok_or(Error::InvalidOperand)?; + let direction = match direction_const { + 0 => crate::Direction::X, + 1 => crate::Direction::Y, + 2 => crate::Direction::Diagonal, + _ => unreachable!() + }; + let result_type = self.lookup_type.lookup(result_type_id)?; let result_handle = ctx.expressions.append( @@ -4175,7 +4185,7 @@ impl> Frontend { block.push( crate::Statement::SubgroupQuadSwap { - direction: crate::Direction::X, + direction, result: result_handle, argument: argument_handle, }, From e1e92da646b0c3c6b32e57232b3e7d088617c81e Mon Sep 17 00:00:00 2001 From: valaphee <32491319+valaphee@users.noreply.github.com> Date: Thu, 16 May 2024 15:35:40 +0200 Subject: [PATCH 03/10] Cargo clippy & cargo fmt, impl valid for quad ops --- naga/src/back/dot/mod.rs | 6 ++-- naga/src/back/glsl/mod.rs | 6 +++- naga/src/back/hlsl/writer.rs | 14 +++++--- naga/src/back/msl/writer.rs | 14 +++++--- naga/src/back/spv/block.rs | 20 ++++++------ naga/src/back/wgsl/writer.rs | 14 +++++--- naga/src/compact/statements.rs | 16 +++++++--- naga/src/front/spv/mod.rs | 2 +- naga/src/front/wgsl/lower/mod.rs | 24 ++++++++------ naga/src/valid/function.rs | 55 +++++++++++++++++++++++++++++++- 10 files changed, 126 insertions(+), 45 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index b3ca50e6b53..df2470c259b 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -379,7 +379,7 @@ impl StatementGraph { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::ShuffleXor(index) | crate::GatherMode::QuadBroadcast(index) => { self.dependencies.push((id, index, "index")) } @@ -393,13 +393,13 @@ impl StatementGraph { crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown", crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp", crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor", - crate::GatherMode::QuadBroadcast(_) => "SubgroupQuadBroadcast" + crate::GatherMode::QuadBroadcast(_) => "SubgroupQuadBroadcast", } } S::SubgroupQuadSwap { direction, argument, - result + result, } => { self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index f6701f61d12..3836d2e206e 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2735,7 +2735,11 @@ impl<'a, W: Write> Writer<'a, W> { } writeln!(self.out, ");")?; } - Statement::SubgroupQuadSwap { direction, argument, result } => { + Statement::SubgroupQuadSwap { + direction, + argument, + result, + } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); let res_ty = ctx.info[result].ty.inner_with(&self.module.types); diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 88e87f067ea..9c02e38d267 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2640,10 +2640,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } writeln!(self.out, ");")?; } - Statement::SubgroupQuadSwap { direction, argument, result } => { + Statement::SubgroupQuadSwap { + direction, + argument, + result, + } => { write!(self.out, "{level}")?; write!(self.out, "const ")?; - let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let name = Baked(result).to_string(); match func_ctx.info[result].ty { proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, proc::TypeResolution::Value(ref value) => { @@ -2656,13 +2660,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { match direction { crate::Direction::X => { write!(self.out, "QuadReadAcrossX(")?; - }, + } crate::Direction::Y => { write!(self.out, "QuadReadAcrossY(")?; - }, + } crate::Direction::Diagonal => { write!(self.out, "QuadReadAcrossDiagonal(")?; - }, + } } self.write_expr(module, argument, func_ctx)?; writeln!(self.out, ");")?; diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index cd529b7b742..1e474f0faf5 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3949,8 +3949,12 @@ impl Writer { } } writeln!(self.out, ");")?; - }, - crate::Statement::SubgroupQuadSwap { direction, argument, result } => { + } + crate::Statement::SubgroupQuadSwap { + direction, + argument, + result, + } => { write!(self.out, "{level}")?; let name = self.namer.call(""); self.start_baking_expression(result, &context.expression, &name)?; @@ -3961,13 +3965,13 @@ impl Writer { match direction { crate::Direction::X => { write!(self.out, "0x01")?; - }, + } crate::Direction::Y => { write!(self.out, "0x10")?; - }, + } crate::Direction::Diagonal => { write!(self.out, "0x11")?; - }, + } } writeln!(self.out, ");")?; } diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 537e40f4f67..81fb85e89e4 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3470,7 +3470,7 @@ impl BlockContext<'_> { Statement::SubgroupQuadSwap { ref direction, argument, - result + result, } => { self.writer.require_any( "GroupNonUniformQuad", @@ -3485,21 +3485,19 @@ impl BlockContext<'_> { let arg_id = self.cached[argument]; - let direction = self.get_index_constant(match direction { + let direction = self.get_index_constant(match *direction { crate::Direction::X => 0, crate::Direction::Y => 1, crate::Direction::Diagonal => 2, }); - block - .body - .push(Instruction::group_non_uniform_quad_swap( - result_type_id, - id, - exec_scope_id, - arg_id, - direction - )); + block.body.push(Instruction::group_non_uniform_quad_swap( + result_type_id, + id, + exec_scope_id, + arg_id, + direction, + )); } } } diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index e3e32032618..7b54e6b092c 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -964,22 +964,26 @@ impl Writer { } writeln!(self.out, ");")?; } - Statement::SubgroupQuadSwap { direction, argument, result } => { + Statement::SubgroupQuadSwap { + direction, + argument, + result, + } => { write!(self.out, "{level}")?; - let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_name = Baked(result).to_string(); self.start_named_expr(module, result, func_ctx, &res_name)?; self.named_expressions.insert(result, res_name); match direction { crate::Direction::X => { write!(self.out, "quadSwapX(")?; - }, + } crate::Direction::Y => { write!(self.out, "quadSwapY(")?; - }, + } crate::Direction::Diagonal => { write!(self.out, "quadSwapDiagonal(")?; - }, + } } self.write_expr(module, argument, func_ctx)?; writeln!(self.out, ");")?; diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index fe201039e96..cf39621927b 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -148,11 +148,15 @@ impl FunctionTracer<'_> { } self.expressions_used.insert(argument); self.expressions_used.insert(result); - }, - St::SubgroupQuadSwap { direction, argument, result } => { + } + St::SubgroupQuadSwap { + direction: _, + argument, + result, + } => { self.expressions_used.insert(argument); self.expressions_used.insert(result); - }, + } // Trivial statements. St::Break @@ -361,7 +365,11 @@ impl FunctionMap { adjust(argument); adjust(result); } - St::SubgroupQuadSwap { direction: _, ref mut argument, ref mut result } => { + St::SubgroupQuadSwap { + direction: _, + ref mut argument, + ref mut result, + } => { adjust(argument); adjust(result); } diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index f109332a060..a25947b3796 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -4163,7 +4163,7 @@ impl> Frontend { 0 => crate::Direction::X, 1 => crate::Direction::Y, 2 => crate::Direction::Diagonal, - _ => unreachable!() + _ => unreachable!(), }; let result_type = self.lookup_type.lookup(result_type_id)?; diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 96369d13935..9027610b679 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -2950,8 +2950,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let ty = ctx.register_type(argument)?; - let result = - ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let result = ctx.interrupt_emitter( + crate::Expression::SubgroupOperationResult { ty }, + span, + )?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( crate::Statement::SubgroupQuadSwap { @@ -2961,7 +2963,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - return Ok(Some(result)) + return Ok(Some(result)); } "quadSwapY" => { @@ -2972,8 +2974,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let ty = ctx.register_type(argument)?; - let result = - ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let result = ctx.interrupt_emitter( + crate::Expression::SubgroupOperationResult { ty }, + span, + )?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( crate::Statement::SubgroupQuadSwap { @@ -2983,7 +2987,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - return Ok(Some(result)) + return Ok(Some(result)); } "quadSwapDiagonal" => { @@ -2994,8 +2998,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let ty = ctx.register_type(argument)?; - let result = - ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let result = ctx.interrupt_emitter( + crate::Expression::SubgroupOperationResult { ty }, + span, + )?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( crate::Statement::SubgroupQuadSwap { @@ -3005,7 +3011,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - return Ok(Some(result)) + return Ok(Some(result)); } _ => { return Err(Box::new(Error::UnknownIdent(function.span, function.name))) diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index f1513732683..e03bb940d6d 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -745,6 +745,35 @@ impl super::Validator { } Ok(()) } + fn validate_subgroup_quad_swap( + &mut self, + argument: Handle, + result: Handle, + context: &BlockContext, + ) -> Result<(), WithSpan> { + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + if !matches!(*argument_inner, + crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. } + if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float) + ) { + log::error!("Subgroup quad swap operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } fn validate_block_impl( &mut self, @@ -1617,7 +1646,31 @@ impl super::Validator { } self.validate_subgroup_gather(mode, argument, result, context)?; } - S::SubgroupQuadSwap { direction, argument, result } => {} + S::SubgroupQuadSwap { + direction: _, + argument, + result, + } => { + stages &= self.subgroup_stages; + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::QUAD_FRAGMENT_COMPUTE) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::QUAD_FRAGMENT_COMPUTE, + ), + ) + .with_span_static(span, "support for this operation is not present")); + } + self.validate_subgroup_quad_swap(argument, result, context)?; + } } } Ok(BlockInfo { stages, finished }) From f1492f015c5f82dbd3f938659cec177adef9872c Mon Sep 17 00:00:00 2001 From: valaphee <32491319+valaphee@users.noreply.github.com> Date: Thu, 16 May 2024 15:43:45 +0200 Subject: [PATCH 04/10] Enable quad feature --- naga/src/back/hlsl/writer.rs | 59 +++++++++++-------- naga/src/back/msl/writer.rs | 6 +- naga/src/valid/mod.rs | 8 ++- ...wgsl-subgroup-operations.main.Compute.glsl | 4 ++ .../out/hlsl/wgsl-subgroup-operations.hlsl | 4 ++ .../out/msl/wgsl-subgroup-operations.msl | 4 ++ .../out/spv/wgsl-subgroup-operations.spvasm | 7 ++- .../out/wgsl/wgsl-subgroup-operations.wgsl | 4 ++ wgpu-hal/src/vulkan/adapter.rs | 3 +- 9 files changed, 68 insertions(+), 31 deletions(-) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 9c02e38d267..6e1ddaf71c5 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2610,32 +2610,41 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { }; write!(self.out, " {name} = ")?; self.named_expressions.insert(result, name); - - if matches!(mode, crate::GatherMode::BroadcastFirst) { - write!(self.out, "WaveReadLaneFirst(")?; - self.write_expr(module, argument, func_ctx)?; - } else { - write!(self.out, "WaveReadLaneAt(")?; - self.write_expr(module, argument, func_ctx)?; - write!(self.out, ", ")?; - match mode { - crate::GatherMode::BroadcastFirst => unreachable!(), - crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => { - self.write_expr(module, index, func_ctx)?; - } - crate::GatherMode::ShuffleDown(index) => { - write!(self.out, "WaveGetLaneIndex() + ")?; - self.write_expr(module, index, func_ctx)?; - } - crate::GatherMode::ShuffleUp(index) => { - write!(self.out, "WaveGetLaneIndex() - ")?; - self.write_expr(module, index, func_ctx)?; - } - crate::GatherMode::ShuffleXor(index) => { - write!(self.out, "WaveGetLaneIndex() ^ ")?; - self.write_expr(module, index, func_ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "WaveReadLaneFirst(")?; + self.write_expr(module, argument, func_ctx)?; + } + crate::GatherMode::QuadBroadcast(index) => { + write!(self.out, "QuadReadLaneAt(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + } + _ => { + write!(self.out, "WaveReadLaneAt(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ", ")?; + match mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) => { + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleDown(index) => { + write!(self.out, "WaveGetLaneIndex() + ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleUp(index) => { + write!(self.out, "WaveGetLaneIndex() - ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleXor(index) => { + write!(self.out, "WaveGetLaneIndex() ^ ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::QuadBroadcast(_) => unreachable!(), } - crate::GatherMode::QuadBroadcast(_) => unreachable!() } } writeln!(self.out, ");")?; diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 1e474f0faf5..15ca3dee1ee 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3964,13 +3964,13 @@ impl Writer { write!(self.out, ", ")?; match direction { crate::Direction::X => { - write!(self.out, "0x01")?; + write!(self.out, "1u")?; } crate::Direction::Y => { - write!(self.out, "0x10")?; + write!(self.out, "2u")?; } crate::Direction::Diagonal => { - write!(self.out, "0x11")?; + write!(self.out, "3u")?; } } writeln!(self.out, ");")?; diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 9870e9b3d53..460980e4fe0 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -458,7 +458,13 @@ impl Validator { pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self { let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) { use SubgroupOperationSet as S; - S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE + S::BASIC + | S::VOTE + | S::ARITHMETIC + | S::BALLOT + | S::SHUFFLE + | S::SHUFFLE_RELATIVE + | S::QUAD_FRAGMENT_COMPUTE } else { SubgroupOperationSet::empty() }; diff --git a/naga/tests/out/glsl/wgsl-subgroup-operations.main.Compute.glsl b/naga/tests/out/glsl/wgsl-subgroup-operations.main.Compute.glsl index 05ab403565e..ef1f0bd7821 100644 --- a/naga/tests/out/glsl/wgsl-subgroup-operations.main.Compute.glsl +++ b/naga/tests/out/glsl/wgsl-subgroup-operations.main.Compute.glsl @@ -40,6 +40,10 @@ void main() { uint _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u); uint _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u); uint _e41 = subgroupShuffleXor(subgroup_invocation_id, (sizes.subgroup_size - 1u)); + uint _e43 = subgroupQuadBroadcast(subgroup_invocation_id, 4u); + uint _e44 = subgroupQuadSwapHorizontal(subgroup_invocation_id); + uint _e45 = subgroupQuadSwapVertical(subgroup_invocation_id); + uint _e46 = subgroupQuadSwapDiagonal(subgroup_invocation_id); return; } diff --git a/naga/tests/out/hlsl/wgsl-subgroup-operations.hlsl b/naga/tests/out/hlsl/wgsl-subgroup-operations.hlsl index 839b1fa6b29..ef9a65b02e7 100644 --- a/naga/tests/out/hlsl/wgsl-subgroup-operations.hlsl +++ b/naga/tests/out/hlsl/wgsl-subgroup-operations.hlsl @@ -34,5 +34,9 @@ void main(ComputeInput_main computeinput_main) const uint _e35 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() + 1u); const uint _e37 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() - 1u); const uint _e41 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() ^ (sizes.subgroup_size - 1u)); + const uint _e43 = QuadReadLaneAt(subgroup_invocation_id, 4u); + const uint _e44 = QuadReadAcrossX(subgroup_invocation_id); + const uint _e45 = QuadReadAcrossY(subgroup_invocation_id); + const uint _e46 = QuadReadAcrossDiagonal(subgroup_invocation_id); return; } diff --git a/naga/tests/out/msl/wgsl-subgroup-operations.msl b/naga/tests/out/msl/wgsl-subgroup-operations.msl index 20a550a715a..4245c20b000 100644 --- a/naga/tests/out/msl/wgsl-subgroup-operations.msl +++ b/naga/tests/out/msl/wgsl-subgroup-operations.msl @@ -40,5 +40,9 @@ kernel void main_( uint unnamed_18 = metal::simd_shuffle_down(subgroup_invocation_id, 1u); uint unnamed_19 = metal::simd_shuffle_up(subgroup_invocation_id, 1u); uint unnamed_20 = metal::simd_shuffle_xor(subgroup_invocation_id, sizes.subgroup_size - 1u); + uint unnamed_21 = metal::quad_broadcast(subgroup_invocation_id, 4u); + uint unnamed_22 = metal::quad_shuffle_xor(subgroup_invocation_id, 1u); + uint unnamed_23 = metal::quad_shuffle_xor(subgroup_invocation_id, 2u); + uint unnamed_24 = metal::quad_shuffle_xor(subgroup_invocation_id, 3u); return; } diff --git a/naga/tests/out/spv/wgsl-subgroup-operations.spvasm b/naga/tests/out/spv/wgsl-subgroup-operations.spvasm index fb60aae5bcf..f4f257ad0ec 100644 --- a/naga/tests/out/spv/wgsl-subgroup-operations.spvasm +++ b/naga/tests/out/spv/wgsl-subgroup-operations.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.3 ; Generator: rspirv -; Bound: 58 +; Bound: 62 OpCapability Shader OpCapability GroupNonUniform OpCapability GroupNonUniformBallot @@ -9,6 +9,7 @@ OpCapability GroupNonUniformVote OpCapability GroupNonUniformArithmetic OpCapability GroupNonUniformShuffle OpCapability GroupNonUniformShuffleRelative +OpCapability GroupNonUniformQuad %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint GLCompute %17 "main" %8 %11 %13 %15 @@ -77,5 +78,9 @@ OpControlBarrier %23 %24 %25 %55 = OpCompositeExtract %3 %7 1 %56 = OpISub %3 %55 %19 %57 = OpGroupNonUniformShuffleXor %3 %23 %16 %56 +%58 = OpGroupNonUniformQuadBroadcast %3 %23 %16 %21 +%59 = OpGroupNonUniformQuadSwap %3 %23 %16 %20 +%60 = OpGroupNonUniformQuadSwap %3 %23 %16 %19 +%61 = OpGroupNonUniformQuadSwap %3 %23 %16 %24 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/wgsl-subgroup-operations.wgsl b/naga/tests/out/wgsl/wgsl-subgroup-operations.wgsl index 25f713b3578..d7bfbc1d735 100644 --- a/naga/tests/out/wgsl/wgsl-subgroup-operations.wgsl +++ b/naga/tests/out/wgsl/wgsl-subgroup-operations.wgsl @@ -27,5 +27,9 @@ fn main(sizes: Structure, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgr let _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u); let _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u); let _e41 = subgroupShuffleXor(subgroup_invocation_id, (sizes.subgroup_size - 1u)); + let _e43 = quadBroadcast(subgroup_invocation_id, 4u); + let _e44 = quadSwapX(subgroup_invocation_id); + let _e45 = quadSwapY(subgroup_invocation_id); + let _e46 = quadSwapDiagonal(subgroup_invocation_id); return; } diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index e9ae1e597a2..750985b4ca3 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -743,7 +743,8 @@ impl PhysicalDeviceFeatures { | vk::SubgroupFeatureFlags::ARITHMETIC | vk::SubgroupFeatureFlags::BALLOT | vk::SubgroupFeatureFlags::SHUFFLE - | vk::SubgroupFeatureFlags::SHUFFLE_RELATIVE, + | vk::SubgroupFeatureFlags::SHUFFLE_RELATIVE + | vk::SubgroupFeatureFlags::QUAD, ) { features.set( From b525a606e22f076d19aa69822a3149a0588f270c Mon Sep 17 00:00:00 2001 From: valaphee <32491319+valaphee@users.noreply.github.com> Date: Thu, 16 May 2024 15:50:32 +0200 Subject: [PATCH 05/10] Add missing feature to glsl --- naga/src/back/glsl/features.rs | 1 + naga/tests/out/glsl/spv-subgroup-operations-s.main.Compute.glsl | 1 + naga/tests/out/glsl/wgsl-subgroup-operations.main.Compute.glsl | 1 + 3 files changed, 3 insertions(+) diff --git a/naga/src/back/glsl/features.rs b/naga/src/back/glsl/features.rs index 0a083e85985..e67e09c6ca5 100644 --- a/naga/src/back/glsl/features.rs +++ b/naga/src/back/glsl/features.rs @@ -280,6 +280,7 @@ impl FeaturesManager { out, "#extension GL_KHR_shader_subgroup_shuffle_relative : require" )?; + writeln!(out, "#extension GL_KHR_shader_subgroup_quad : require")?; } if self.0.contains(Features::TEXTURE_ATOMICS) { diff --git a/naga/tests/out/glsl/spv-subgroup-operations-s.main.Compute.glsl b/naga/tests/out/glsl/spv-subgroup-operations-s.main.Compute.glsl index 067112c2d01..67389282e65 100644 --- a/naga/tests/out/glsl/spv-subgroup-operations-s.main.Compute.glsl +++ b/naga/tests/out/glsl/spv-subgroup-operations-s.main.Compute.glsl @@ -6,6 +6,7 @@ #extension GL_KHR_shader_subgroup_ballot : require #extension GL_KHR_shader_subgroup_shuffle : require #extension GL_KHR_shader_subgroup_shuffle_relative : require +#extension GL_KHR_shader_subgroup_quad : require layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; uint global = 0u; diff --git a/naga/tests/out/glsl/wgsl-subgroup-operations.main.Compute.glsl b/naga/tests/out/glsl/wgsl-subgroup-operations.main.Compute.glsl index ef1f0bd7821..6eba8cf01a3 100644 --- a/naga/tests/out/glsl/wgsl-subgroup-operations.main.Compute.glsl +++ b/naga/tests/out/glsl/wgsl-subgroup-operations.main.Compute.glsl @@ -6,6 +6,7 @@ #extension GL_KHR_shader_subgroup_ballot : require #extension GL_KHR_shader_subgroup_shuffle : require #extension GL_KHR_shader_subgroup_shuffle_relative : require +#extension GL_KHR_shader_subgroup_quad : require layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; struct Structure { From 63ee3cac3f18fe7c6e4394b867a3681e2305c7b9 Mon Sep 17 00:00:00 2001 From: Dmitry Zamkov Date: Sun, 11 May 2025 00:37:17 -0500 Subject: [PATCH 06/10] Simplifying code by making `SubgroupQuadSwap` an instance of `SubgroupGather` --- naga/src/back/dot/mod.rs | 19 ++++------ naga/src/back/glsl/mod.rs | 38 +++++++------------- naga/src/back/hlsl/writer.rs | 46 ++++++++---------------- naga/src/back/msl/writer.rs | 40 +++++++++------------ naga/src/back/pipeline_constants.rs | 9 +---- naga/src/back/spv/block.rs | 32 ----------------- naga/src/back/spv/instructions.rs | 16 --------- naga/src/back/spv/subgroup.rs | 17 ++++++++- naga/src/back/wgsl/writer.rs | 36 +++++++------------ naga/src/compact/statements.rs | 18 ++-------- naga/src/front/spv/mod.rs | 7 ++-- naga/src/front/wgsl/lower/mod.rs | 12 +++---- naga/src/ir/mod.rs | 34 +++++++----------- naga/src/proc/terminator.rs | 1 - naga/src/valid/analyzer.rs | 9 +---- naga/src/valid/function.rs | 55 +---------------------------- naga/src/valid/handles.rs | 10 +----- naga/src/valid/mod.rs | 2 +- 18 files changed, 106 insertions(+), 295 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index df2470c259b..13d0857f742 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -383,6 +383,7 @@ impl StatementGraph { | crate::GatherMode::QuadBroadcast(index) => { self.dependencies.push((id, index, "index")) } + crate::GatherMode::QuadSwap(_) => {} } self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); @@ -394,19 +395,11 @@ impl StatementGraph { crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp", crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor", crate::GatherMode::QuadBroadcast(_) => "SubgroupQuadBroadcast", - } - } - S::SubgroupQuadSwap { - direction, - argument, - result, - } => { - self.dependencies.push((id, argument, "arg")); - self.emits.push((id, result)); - match direction { - crate::Direction::X => "SubgroupQuadSwapX", - crate::Direction::Y => "SubgroupQuadSwapY", - crate::Direction::Diagonal => "SubgroupQuadSwapDiagonal", + crate::GatherMode::QuadSwap(direction) => match direction { + crate::Direction::X => "SubgroupQuadSwapX", + crate::Direction::Y => "SubgroupQuadSwapY", + crate::Direction::Diagonal => "SubgroupQuadSwapDiagonal", + }, } } }; diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 3836d2e206e..112af70edab 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2719,6 +2719,17 @@ impl<'a, W: Write> Writer<'a, W> { crate::GatherMode::QuadBroadcast(_) => { write!(self.out, "subgroupQuadBroadcast(")?; } + crate::GatherMode::QuadSwap(direction) => match direction { + crate::Direction::X => { + write!(self.out, "subgroupQuadSwapHorizontal(")?; + } + crate::Direction::Y => { + write!(self.out, "subgroupQuadSwapVertical(")?; + } + crate::Direction::Diagonal => { + write!(self.out, "subgroupQuadSwapDiagonal(")?; + } + }, } self.write_expr(argument, ctx)?; match mode { @@ -2732,35 +2743,10 @@ impl<'a, W: Write> Writer<'a, W> { write!(self.out, ", ")?; self.write_expr(index, ctx)?; } + crate::GatherMode::QuadSwap(_) => {} } writeln!(self.out, ");")?; } - Statement::SubgroupQuadSwap { - direction, - argument, - result, - } => { - write!(self.out, "{level}")?; - let res_name = Baked(result).to_string(); - let res_ty = ctx.info[result].ty.inner_with(&self.module.types); - self.write_value_type(res_ty)?; - write!(self.out, " {res_name} = ")?; - self.named_expressions.insert(result, res_name); - - match direction { - crate::Direction::X => { - write!(self.out, "subgroupQuadSwapHorizontal(")?; - } - crate::Direction::Y => { - write!(self.out, "subgroupQuadSwapVertical(")?; - } - crate::Direction::Diagonal => { - write!(self.out, "subgroupQuadSwapDiagonal(")?; - } - } - self.write_expr(argument, ctx)?; - writeln!(self.out, ");")?; - } } Ok(()) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 6e1ddaf71c5..cb0dccd5441 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2621,6 +2621,20 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ", ")?; self.write_expr(module, index, func_ctx)?; } + crate::GatherMode::QuadSwap(direction) => { + match direction { + crate::Direction::X => { + write!(self.out, "QuadReadAcrossX(")?; + } + crate::Direction::Y => { + write!(self.out, "QuadReadAcrossY(")?; + } + crate::Direction::Diagonal => { + write!(self.out, "QuadReadAcrossDiagonal(")?; + } + } + self.write_expr(module, argument, func_ctx)?; + } _ => { write!(self.out, "WaveReadLaneAt(")?; self.write_expr(module, argument, func_ctx)?; @@ -2644,42 +2658,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_expr(module, index, func_ctx)?; } crate::GatherMode::QuadBroadcast(_) => unreachable!(), + crate::GatherMode::QuadSwap(_) => unreachable!(), } } } writeln!(self.out, ");")?; } - Statement::SubgroupQuadSwap { - direction, - argument, - result, - } => { - write!(self.out, "{level}")?; - write!(self.out, "const ")?; - let name = Baked(result).to_string(); - match func_ctx.info[result].ty { - proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, - proc::TypeResolution::Value(ref value) => { - self.write_value_type(module, value)? - } - }; - write!(self.out, " {name} = ")?; - self.named_expressions.insert(result, name); - - match direction { - crate::Direction::X => { - write!(self.out, "QuadReadAcrossX(")?; - } - crate::Direction::Y => { - write!(self.out, "QuadReadAcrossY(")?; - } - crate::Direction::Diagonal => { - write!(self.out, "QuadReadAcrossDiagonal(")?; - } - } - self.write_expr(module, argument, func_ctx)?; - writeln!(self.out, ");")?; - } } Ok(()) diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 15ca3dee1ee..36ab38073cf 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3934,6 +3934,9 @@ impl Writer { crate::GatherMode::QuadBroadcast(_) => { write!(self.out, "{NAMESPACE}::quad_broadcast(")?; } + crate::GatherMode::QuadSwap(_) => { + write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?; + } } self.put_expression(argument, &context.expression, true)?; match mode { @@ -3947,30 +3950,19 @@ impl Writer { write!(self.out, ", ")?; self.put_expression(index, &context.expression, true)?; } - } - writeln!(self.out, ");")?; - } - crate::Statement::SubgroupQuadSwap { - direction, - argument, - result, - } => { - write!(self.out, "{level}")?; - let name = self.namer.call(""); - self.start_baking_expression(result, &context.expression, &name)?; - self.named_expressions.insert(result, name); - write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?; - self.put_expression(argument, &context.expression, true)?; - write!(self.out, ", ")?; - match direction { - crate::Direction::X => { - write!(self.out, "1u")?; - } - crate::Direction::Y => { - write!(self.out, "2u")?; - } - crate::Direction::Diagonal => { - write!(self.out, "3u")?; + crate::GatherMode::QuadSwap(direction) => { + write!(self.out, ", ")?; + match direction { + crate::Direction::X => { + write!(self.out, "1u")?; + } + crate::Direction::Y => { + write!(self.out, "2u")?; + } + crate::Direction::Diagonal => { + write!(self.out, "3u")?; + } + } } } writeln!(self.out, ");")?; diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 86390a4fed9..0088c6eac3f 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -763,18 +763,11 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S | crate::GatherMode::QuadBroadcast(ref mut index) => { adjust(index); } + crate::GatherMode::QuadSwap(_) => {} } adjust(argument); adjust(result) } - Statement::SubgroupQuadSwap { - ref mut argument, - ref mut result, - .. - } => { - adjust(argument); - adjust(result); - } Statement::Call { ref mut arguments, ref mut result, diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 81fb85e89e4..92edbcb05c4 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3467,38 +3467,6 @@ impl BlockContext<'_> { } => { self.write_subgroup_gather(mode, argument, result, &mut block)?; } - Statement::SubgroupQuadSwap { - ref direction, - argument, - result, - } => { - self.writer.require_any( - "GroupNonUniformQuad", - &[spirv::Capability::GroupNonUniformQuad], - )?; - - let id = self.gen_id(); - let result_ty = &self.fun_info[result].ty; - let result_type_id = self.get_expression_type_id(result_ty); - - let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); - - let arg_id = self.cached[argument]; - - let direction = self.get_index_constant(match *direction { - crate::Direction::X => 0, - crate::Direction::Y => 1, - crate::Direction::Diagonal => 2, - }); - - block.body.push(Instruction::group_non_uniform_quad_swap( - result_type_id, - id, - exec_scope_id, - arg_id, - direction, - )); - } } } diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index f6f25693cee..97cf54587c6 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1205,22 +1205,6 @@ impl super::Instruction { instruction } - pub(super) fn group_non_uniform_quad_broadcast( - result_type_id: Word, - id: Word, - exec_scope_id: Word, - value: Word, - index: Word, - ) -> Self { - let mut instruction = Self::new(Op::GroupNonUniformQuadBroadcast); - instruction.set_type(result_type_id); - instruction.set_result(id); - instruction.add_operand(exec_scope_id); - instruction.add_operand(value); - instruction.add_operand(index); - - instruction - } pub(super) fn group_non_uniform_quad_swap( result_type_id: Word, id: Word, diff --git a/naga/src/back/spv/subgroup.rs b/naga/src/back/spv/subgroup.rs index dba7072532e..a0c35f1d648 100644 --- a/naga/src/back/spv/subgroup.rs +++ b/naga/src/back/spv/subgroup.rs @@ -146,7 +146,7 @@ impl BlockContext<'_> { &[spirv::Capability::GroupNonUniformShuffleRelative], )?; } - crate::GatherMode::QuadBroadcast(_) => { + crate::GatherMode::QuadBroadcast(_) | crate::GatherMode::QuadSwap(_) => { self.writer.require_any( "GroupNonUniformQuad", &[spirv::Capability::GroupNonUniformQuad], @@ -191,6 +191,7 @@ impl BlockContext<'_> { crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp, crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor, crate::GatherMode::QuadBroadcast(_) => spirv::Op::GroupNonUniformQuadBroadcast, + crate::GatherMode::QuadSwap(_) => unreachable!(), }; block.body.push(Instruction::group_non_uniform_gather( op, @@ -201,6 +202,20 @@ impl BlockContext<'_> { index_id, )); } + crate::GatherMode::QuadSwap(direction) => { + let direction = self.get_index_constant(match direction { + crate::Direction::X => 0, + crate::Direction::Y => 1, + crate::Direction::Diagonal => 2, + }); + block.body.push(Instruction::group_non_uniform_quad_swap( + result_type_id, + id, + exec_scope_id, + arg_id, + direction, + )); + } } self.cached[result] = id; Ok(()) diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 7b54e6b092c..3ae12b3ecf1 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -948,6 +948,17 @@ impl Writer { crate::GatherMode::QuadBroadcast(_) => { write!(self.out, "quadBroadcast(")?; } + crate::GatherMode::QuadSwap(direction) => match direction { + crate::Direction::X => { + write!(self.out, "quadSwapX(")?; + } + crate::Direction::Y => { + write!(self.out, "quadSwapY(")?; + } + crate::Direction::Diagonal => { + write!(self.out, "quadSwapDiagonal(")?; + } + }, } self.write_expr(module, argument, func_ctx)?; match mode { @@ -961,33 +972,10 @@ impl Writer { write!(self.out, ", ")?; self.write_expr(module, index, func_ctx)?; } + crate::GatherMode::QuadSwap(_) => {} } writeln!(self.out, ");")?; } - Statement::SubgroupQuadSwap { - direction, - argument, - result, - } => { - write!(self.out, "{level}")?; - let res_name = Baked(result).to_string(); - self.start_named_expr(module, result, func_ctx, &res_name)?; - self.named_expressions.insert(result, res_name); - - match direction { - crate::Direction::X => { - write!(self.out, "quadSwapX(")?; - } - crate::Direction::Y => { - write!(self.out, "quadSwapY(")?; - } - crate::Direction::Diagonal => { - write!(self.out, "quadSwapDiagonal(")?; - } - } - self.write_expr(module, argument, func_ctx)?; - writeln!(self.out, ");")?; - } } Ok(()) diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index cf39621927b..bf8d6ec7c8f 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -145,18 +145,11 @@ impl FunctionTracer<'_> { | crate::GatherMode::QuadBroadcast(index) => { self.expressions_used.insert(index); } + crate::GatherMode::QuadSwap(_) => {} } self.expressions_used.insert(argument); self.expressions_used.insert(result); } - St::SubgroupQuadSwap { - direction: _, - argument, - result, - } => { - self.expressions_used.insert(argument); - self.expressions_used.insert(result); - } // Trivial statements. St::Break @@ -361,18 +354,11 @@ impl FunctionMap { | crate::GatherMode::ShuffleUp(ref mut index) | crate::GatherMode::ShuffleXor(ref mut index) | crate::GatherMode::QuadBroadcast(ref mut index) => adjust(index), + crate::GatherMode::QuadSwap(_) => {} } adjust(argument); adjust(result); } - St::SubgroupQuadSwap { - direction: _, - ref mut argument, - ref mut result, - } => { - adjust(argument); - adjust(result); - } // Trivial statements. St::Break diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index a25947b3796..ddfaa0243f5 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -4184,8 +4184,8 @@ impl> Frontend { ); block.push( - crate::Statement::SubgroupQuadSwap { - direction, + crate::Statement::SubgroupGather { + mode: crate::GatherMode::QuadSwap(direction), result: result_handle, argument: argument_handle, }, @@ -4574,8 +4574,7 @@ impl> Frontend { | S::RayQuery { .. } | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } - | S::SubgroupGather { .. } - | S::SubgroupQuadSwap { .. } => {} + | S::SubgroupGather { .. } => {} S::Call { function: ref mut callee, ref arguments, diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 9027610b679..c08f2200a95 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -2956,8 +2956,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( - crate::Statement::SubgroupQuadSwap { - direction: crate::Direction::X, + crate::Statement::SubgroupGather { + mode: crate::GatherMode::QuadSwap(crate::Direction::X), argument, result, }, @@ -2980,8 +2980,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( - crate::Statement::SubgroupQuadSwap { - direction: crate::Direction::Y, + crate::Statement::SubgroupGather { + mode: crate::GatherMode::QuadSwap(crate::Direction::Y), argument, result, }, @@ -3004,8 +3004,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( - crate::Statement::SubgroupQuadSwap { - direction: crate::Direction::Diagonal, + crate::Statement::SubgroupGather { + mode: crate::GatherMode::QuadSwap(crate::Direction::Diagonal), argument, result, }, diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 24185c0b507..34258bccd21 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -1279,8 +1279,20 @@ pub enum GatherMode { ShuffleUp(Handle), /// Each gathers from their lane xored with the given by the expression ShuffleXor(Handle), - /// All gather from the same lane at the index given by the expression + /// All gather from the same quad lane at the index given by the expression QuadBroadcast(Handle), + /// Each gathers from the opposite quad lane along the given direction + QuadSwap(Direction), +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum Direction { + X = 0, + Y = 1, + Diagonal = 2, } #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] @@ -2101,26 +2113,6 @@ pub enum Statement { /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult result: Handle, }, - SubgroupQuadSwap { - /// In which direction to swap - direction: Direction, - /// The value to swap over - argument: Handle, - /// The [`SubgroupOperationResult`] expression representing this load's result. - /// - /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult - result: Handle, - } -} - -#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] -#[cfg_attr(feature = "serialize", derive(Serialize))] -#[cfg_attr(feature = "deserialize", derive(Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum Direction { - X = 0, - Y = 1, - Diagonal = 2, } /// A function argument. diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index 06120df835c..f22e61e6a6d 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -42,7 +42,6 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } | S::SubgroupGather { .. } - | S::SubgroupQuadSwap { .. } | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index ba8b7b80507..435e6b9fd57 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1146,17 +1146,10 @@ impl FunctionInfo { | crate::GatherMode::QuadBroadcast(index) => { let _ = self.add_ref(index); } + crate::GatherMode::QuadSwap(_) => {} } FunctionUniformity::new() } - S::SubgroupQuadSwap { - direction: _, - argument, - result: _, - } => { - let _ = self.add_ref(argument); - FunctionUniformity::new() - } }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index e03bb940d6d..d38d821810a 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -721,6 +721,7 @@ impl super::Validator { } } } + crate::GatherMode::QuadSwap(_) => {} } let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?; if !matches!(*argument_inner, @@ -745,35 +746,6 @@ impl super::Validator { } Ok(()) } - fn validate_subgroup_quad_swap( - &mut self, - argument: Handle, - result: Handle, - context: &BlockContext, - ) -> Result<(), WithSpan> { - let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; - if !matches!(*argument_inner, - crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. } - if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float) - ) { - log::error!("Subgroup quad swap operand type {:?}", argument_inner); - return Err(SubgroupError::InvalidOperand(argument) - .with_span_handle(argument, context.expressions) - .into_other()); - } - - self.emit_expression(result, context)?; - match context.expressions[result] { - crate::Expression::SubgroupOperationResult { ty } - if { &context.types[ty].inner == argument_inner } => {} - _ => { - return Err(SubgroupError::ResultTypeMismatch(result) - .with_span_handle(result, context.expressions) - .into_other()) - } - } - Ok(()) - } fn validate_block_impl( &mut self, @@ -1646,31 +1618,6 @@ impl super::Validator { } self.validate_subgroup_gather(mode, argument, result, context)?; } - S::SubgroupQuadSwap { - direction: _, - argument, - result, - } => { - stages &= self.subgroup_stages; - if !self.capabilities.contains(super::Capabilities::SUBGROUP) { - return Err(FunctionError::MissingCapability( - super::Capabilities::SUBGROUP, - ) - .with_span_static(span, "missing capability for this operation")); - } - if !self - .subgroup_operations - .contains(super::SubgroupOperationSet::QUAD_FRAGMENT_COMPUTE) - { - return Err(FunctionError::InvalidSubgroup( - SubgroupError::UnsupportedOperation( - super::SubgroupOperationSet::QUAD_FRAGMENT_COMPUTE, - ), - ) - .with_span_static(span, "support for this operation is not present")); - } - self.validate_subgroup_quad_swap(argument, result, context)?; - } } } Ok(BlockInfo { stages, finished }) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index a84b6a3d86b..79d183ddd7d 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -742,19 +742,11 @@ impl super::Validator { | crate::GatherMode::ShuffleUp(index) | crate::GatherMode::ShuffleXor(index) | crate::GatherMode::QuadBroadcast(index) => validate_expr(index)?, + crate::GatherMode::QuadSwap(_) => {} } validate_expr(result)?; Ok(()) } - crate::Statement::SubgroupQuadSwap { - direction: _, - argument, - result, - } => { - validate_expr(argument)?; - validate_expr(result)?; - Ok(()) - } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 460980e4fe0..680ad3a18d1 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -221,7 +221,7 @@ impl super::GatherMode { Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT, Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE, Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE, - Self::QuadBroadcast(_) => S::QUAD_FRAGMENT_COMPUTE, + Self::QuadBroadcast(_) | Self::QuadSwap(_) => S::QUAD_FRAGMENT_COMPUTE, } } } From 409aba6c83a857bc40c41d8c082b94918efbc60d Mon Sep 17 00:00:00 2001 From: Dmitry Zamkov Date: Sun, 11 May 2025 12:43:05 -0500 Subject: [PATCH 07/10] Add `GroupNonUniformQuad` spv capability to Vulkan --- wgpu-hal/src/vulkan/adapter.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index 750985b4ca3..2541150259c 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -1951,6 +1951,7 @@ impl super::Adapter { capabilities.push(spv::Capability::GroupNonUniformBallot); capabilities.push(spv::Capability::GroupNonUniformShuffle); capabilities.push(spv::Capability::GroupNonUniformShuffleRelative); + capabilities.push(spv::Capability::GroupNonUniformQuad); } if features.intersects( From 89ba81151267901f1c656fd976a5958f2d87769b Mon Sep 17 00:00:00 2001 From: Dmitry Zamkov Date: Mon, 19 May 2025 17:56:26 -0500 Subject: [PATCH 08/10] Adding GPU tests for quad operations --- tests/tests/wgpu-gpu/subgroup_operations/mod.rs | 12 ++++++------ .../wgpu-gpu/subgroup_operations/shader.wgsl | 16 +++++++++++----- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/tests/wgpu-gpu/subgroup_operations/mod.rs b/tests/tests/wgpu-gpu/subgroup_operations/mod.rs index 25fddf120db..7e50ea5051b 100644 --- a/tests/tests/wgpu-gpu/subgroup_operations/mod.rs +++ b/tests/tests/wgpu-gpu/subgroup_operations/mod.rs @@ -3,7 +3,7 @@ use std::num::NonZeroU64; use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters}; const THREAD_COUNT: u64 = 128; -const TEST_COUNT: u32 = 32; +const TEST_COUNT: u32 = 37; #[gpu_test] static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() @@ -35,7 +35,7 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor { label: None, - size: THREAD_COUNT * size_of::() as u64, + size: THREAD_COUNT * size_of::() as u64, usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::COPY_SRC, @@ -50,7 +50,7 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: false }, has_dynamic_offset: false, - min_binding_size: NonZeroU64::new(THREAD_COUNT * size_of::() as u64), + min_binding_size: NonZeroU64::new(THREAD_COUNT * size_of::() as u64), }, count: None, }], @@ -101,10 +101,10 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() &storage_buffer.slice(..), |mapping_buffer_view| { let mapping_buffer_view = mapping_buffer_view.unwrap(); - let result: &[u32; THREAD_COUNT as usize] = + let result: &[u64; THREAD_COUNT as usize] = bytemuck::from_bytes(&mapping_buffer_view); let expected_mask = (1u64 << (TEST_COUNT)) - 1; // generate full mask - let expected_array = [expected_mask as u32; THREAD_COUNT as usize]; + let expected_array = [expected_mask; THREAD_COUNT as usize]; if result != &expected_array { use std::fmt::Write; let mut msg = String::new(); @@ -122,7 +122,7 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() { write!(&mut msg, "thread {thread} failed tests:").unwrap(); let difference = result ^ expected; - for i in (0..u32::BITS).filter(|i| (difference & (1 << i)) != 0) { + for i in (0..u64::BITS).filter(|i| (difference & (1 << i)) != 0) { write!(&mut msg, " {i},").unwrap(); } writeln!(&mut msg).unwrap(); diff --git a/tests/tests/wgpu-gpu/subgroup_operations/shader.wgsl b/tests/tests/wgpu-gpu/subgroup_operations/shader.wgsl index 77cb81ce750..454f35ea988 100644 --- a/tests/tests/wgpu-gpu/subgroup_operations/shader.wgsl +++ b/tests/tests/wgpu-gpu/subgroup_operations/shader.wgsl @@ -1,11 +1,11 @@ @group(0) @binding(0) -var storage_buffer: array; +var storage_buffer: array>; var workgroup_buffer: u32; -fn add_result_to_mask(mask: ptr, index: u32, value: bool) { - (*mask) |= u32(value) << index; +fn add_result_to_mask(mask: ptr>, index: u32, value: bool) { + (*mask)[index / 32u] |= u32(value) << (index % 32u); } @compute @@ -17,7 +17,7 @@ fn main( @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, ) { - var passed = 0u; + var passed = vec2(0u); var expected: u32; add_result_to_mask(&passed, 0u, num_subgroups == 128u / subgroup_size); @@ -152,8 +152,14 @@ fn main( workgroupBarrier(); add_result_to_mask(&passed, 30u, workgroup_buffer == subgroup_size); + add_result_to_mask(&passed, 31u, quadBroadcast(quadBroadcast(subgroup_invocation_id, 0u) ^ subgroup_invocation_id, 0u) == 0u); + add_result_to_mask(&passed, 32u, quadBroadcast(quadBroadcast(subgroup_invocation_id, 1u) ^ quadSwapX(subgroup_invocation_id), 0u) == 0u); + add_result_to_mask(&passed, 33u, quadBroadcast(quadBroadcast(subgroup_invocation_id, 2u) ^ quadSwapY(subgroup_invocation_id), 0u) == 0u); + add_result_to_mask(&passed, 34u, quadBroadcast(quadBroadcast(subgroup_invocation_id, 3u) ^ quadSwapDiagonal(subgroup_invocation_id), 0u) == 0u); + add_result_to_mask(&passed, 35u, quadSwapX(quadSwapY(subgroup_invocation_id)) == quadSwapDiagonal(subgroup_invocation_id)); + // Keep this test last, verify we are still convergent after running other tests - add_result_to_mask(&passed, 31u, subgroupAdd(1u) == subgroup_size); + add_result_to_mask(&passed, 36u, subgroupAdd(1u) == subgroup_size); // Increment TEST_COUNT in subgroup_operations/mod.rs if adding more tests From 3ad6074140304178ed987e9efbecbddb87457862 Mon Sep 17 00:00:00 2001 From: Dmitry Zamkov Date: Fri, 23 May 2025 19:47:11 -0500 Subject: [PATCH 09/10] Validate that broadcast operations use const invocation ids --- naga/src/valid/function.rs | 17 ++++++++++++++++- naga/src/valid/mod.rs | 2 +- naga/tests/naga/wgsl_errors.rs | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index d38d821810a..9bb7f825c9a 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -73,6 +73,8 @@ pub enum SubgroupError { UnsupportedOperation(super::SubgroupOperationSet), #[error("Unknown operation")] UnknownOperation, + #[error("Invocation ID must be a const-expression")] + InvalidInvocationIdExprType(Handle), } #[derive(Clone, Debug, thiserror::Error)] @@ -248,6 +250,7 @@ struct BlockContext<'a> { special_types: &'a crate::SpecialTypes, prev_infos: &'a [FunctionInfo], return_type: Option>, + local_expr_kind: &'a crate::proc::ExpressionKindTracker, } impl<'a> BlockContext<'a> { @@ -256,6 +259,7 @@ impl<'a> BlockContext<'a> { module: &'a crate::Module, info: &'a FunctionInfo, prev_infos: &'a [FunctionInfo], + local_expr_kind: &'a crate::proc::ExpressionKindTracker, ) -> Self { Self { abilities: ControlFlowAbility::RETURN, @@ -268,6 +272,7 @@ impl<'a> BlockContext<'a> { special_types: &module.special_types, prev_infos, return_type: fun.result.as_ref().map(|fr| fr.ty), + local_expr_kind, } } @@ -723,6 +728,16 @@ impl super::Validator { } crate::GatherMode::QuadSwap(_) => {} } + match *mode { + crate::GatherMode::Broadcast(index) | crate::GatherMode::QuadBroadcast(index) => { + if !context.local_expr_kind.is_const(index) { + return Err(SubgroupError::InvalidInvocationIdExprType(index) + .with_span_handle(index, context.expressions) + .into_other()); + } + } + _ => {} + } let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?; if !matches!(*argument_inner, crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. } @@ -1774,7 +1789,7 @@ impl super::Validator { let stages = self .validate_block( &fun.body, - &BlockContext::new(fun, module, &info, &mod_info.functions), + &BlockContext::new(fun, module, &info, &mod_info.functions, &local_expr_kind), )? .stages; info.available_stages &= stages; diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 680ad3a18d1..aef6a241646 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -29,7 +29,7 @@ pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, Uniformi pub use compose::ComposeError; pub use expression::{check_literal_value, LiteralError}; pub use expression::{ConstExpressionError, ExpressionError}; -pub use function::{CallError, FunctionError, LocalVariableError}; +pub use function::{CallError, FunctionError, LocalVariableError, SubgroupError}; pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; pub use r#type::{Disalignment, PushConstantError, TypeError, TypeFlags, WidthError}; diff --git a/naga/tests/naga/wgsl_errors.rs b/naga/tests/naga/wgsl_errors.rs index 71e5b871715..c4db425a9e6 100644 --- a/naga/tests/naga/wgsl_errors.rs +++ b/naga/tests/naga/wgsl_errors.rs @@ -3596,3 +3596,35 @@ fn const_eval_value_errors() { assert!(variant("f32(abs(-9223372036854775807))").is_ok()); assert!(variant("f32(abs(-9223372036854775807 - 1))").is_ok()); } + +#[test] +fn subgroup_invalid_broadcast() { + check_validation! { + r#" + fn main(id: u32) { + subgroupBroadcast(123, id); + } + "#: + Err(naga::valid::ValidationError::Function { + source: naga::valid::FunctionError::InvalidSubgroup( + naga::valid::SubgroupError::InvalidInvocationIdExprType(_), + ), + .. + }), + naga::valid::Capabilities::SUBGROUP + } + check_validation! { + r#" + fn main(id: u32) { + quadBroadcast(123, id); + } + "#: + Err(naga::valid::ValidationError::Function { + source: naga::valid::FunctionError::InvalidSubgroup( + naga::valid::SubgroupError::InvalidInvocationIdExprType(_), + ), + .. + }), + naga::valid::Capabilities::SUBGROUP + } +} From 71c4d1a5abe0c18019ae97221e458d7e4687687c Mon Sep 17 00:00:00 2001 From: Dmitry Zamkov Date: Fri, 23 May 2025 19:48:20 -0500 Subject: [PATCH 10/10] Added changelog entry --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 951238009ae..c6bed86d80e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,7 @@ Bottom level categories: #### Naga - When emitting GLSL, Uniform and Storage Buffer memory layouts are now emitted even if no explicit binding is given. By @cloone8 in [#7579](https://github.com/gfx-rs/wgpu/pull/7579). +- Add support for [quad operations](https://www.w3.org/TR/WGSL/#quad-builtin-functions) (requires `SUBGROUP` feature to be enabled). By @dzamkov and @valaphee in [#7683](https://github.com/gfx-rs/wgpu/pull/7683). ### Bug Fixes