diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bab4faf421..95c000c9b45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -79,6 +79,7 @@ Naga now infers the correct binding layout when a resource appears only in an as #### Naga - Mark `readonly_and_readwrite_storage_textures` & `packed_4x8_integer_dot_product` language extensions as implemented. By @teoxoy in [#7543](https://github.com/gfx-rs/wgpu/pull/7543) +- `naga::back::hlsl::Writer::new` has a new `pipeline_options` argument. `hlsl::PipelineOptions::default()` can be passed as a default. The `shader_stage` and `entry_point` members of `pipeline_options` can be used to write only a single entry point when using the HLSL and MSL backends (GLSL and SPIR-V already had this functionality). The Metal and DX12 HALs now write only a single entry point when loading shaders. By @andyleiserson in [#7626](https://github.com/gfx-rs/wgpu/pull/7626). #### D3D12 diff --git a/benches/benches/wgpu-benchmark/shader.rs b/benches/benches/wgpu-benchmark/shader.rs index 71552d8be6e..0abcd1cd322 100644 --- a/benches/benches/wgpu-benchmark/shader.rs +++ b/benches/benches/wgpu-benchmark/shader.rs @@ -349,7 +349,9 @@ fn backends(c: &mut Criterion) { let options = naga::back::hlsl::Options::default(); let mut string = String::new(); for input in &inputs.inner { - let mut writer = naga::back::hlsl::Writer::new(&mut string, &options); + let pipeline_options = Default::default(); + let mut writer = + naga::back::hlsl::Writer::new(&mut string, &options, &pipeline_options); let _ = writer.write( input.module.as_ref().unwrap(), input.module_info.as_ref().unwrap(), diff --git a/naga-cli/src/bin/naga.rs b/naga-cli/src/bin/naga.rs index a39640dc3e0..6f95e429f68 100644 --- a/naga-cli/src/bin/naga.rs +++ b/naga-cli/src/bin/naga.rs @@ -824,7 +824,8 @@ fn write_output( .unwrap_pretty(); let mut buffer = String::new(); - let mut writer = hlsl::Writer::new(&mut buffer, ¶ms.hlsl); + let pipeline_options = Default::default(); + let mut writer = hlsl::Writer::new(&mut buffer, ¶ms.hlsl, &pipeline_options); writer.write(&module, &info, None).unwrap_pretty(); fs::write(output_path, buffer)?; } diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index cfb8260b984..52a47487ea2 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -349,7 +349,8 @@ pub struct PipelineOptions { pub shader_stage: ShaderStage, /// The name of the entry point. /// - /// If no entry point that matches is found while creating a [`Writer`], a error will be thrown. + /// If no entry point that matches is found while creating a [`Writer`], an + /// error will be thrown. pub entry_point: String, /// How many views to render to, if doing multiview rendering. pub multiview: Option, diff --git a/naga/src/back/hlsl/mod.rs b/naga/src/back/hlsl/mod.rs index 9e041ff73f8..ec6b3a25c07 100644 --- a/naga/src/back/hlsl/mod.rs +++ b/naga/src/back/hlsl/mod.rs @@ -119,7 +119,7 @@ use core::fmt::Error as FmtError; use thiserror::Error; -use crate::{back, proc}; +use crate::{back, ir, proc}; #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -434,6 +434,22 @@ pub struct ReflectionInfo { pub entry_point_names: Vec>, } +/// A subset of options that are meant to be changed per pipeline. +#[derive(Debug, Default, Clone)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr(feature = "deserialize", serde(default))] +pub struct PipelineOptions { + /// The entry point to write. + /// + /// Entry points are identified by a shader stage specification, + /// and a name. + /// + /// If `None`, all entry points will be written. If `Some` and the entry + /// point is not found, an error will be thrown while writing. + pub entry_point: Option<(ir::ShaderStage, String)>, +} + #[derive(Error, Debug)] pub enum Error { #[error(transparent)] @@ -448,6 +464,8 @@ pub enum Error { Override, #[error(transparent)] ResolveArraySizeError(#[from] proc::ResolveArraySizeError), + #[error("entry point with stage {0:?} and name '{1}' not found")] + EntryPointNotFound(ir::ShaderStage, String), } #[derive(PartialEq, Eq, Hash)] @@ -519,8 +537,10 @@ pub struct Writer<'a, W> { namer: proc::Namer, /// HLSL backend options options: &'a Options, + /// Per-stage backend options + pipeline_options: &'a PipelineOptions, /// Information about entry point arguments and result types. - entry_point_io: Vec, + entry_point_io: crate::FastHashMap, /// Set of expressions that have associated temporary variables named_expressions: crate::NamedExpressions, wrapped: Wrapped, diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 59725df3db3..bb90db78593 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -12,10 +12,10 @@ use super::{ WrappedZeroValue, }, storage::StoreValue, - BackendResult, Error, FragmentEntryPoint, Options, ShaderModel, + BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel, }; use crate::{ - back::{self, Baked}, + back::{self, get_entry_points, Baked}, common, proc::{self, index, NameKey}, valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner, @@ -123,13 +123,14 @@ struct BindingArraySamplerInfo { } impl<'a, W: fmt::Write> super::Writer<'a, W> { - pub fn new(out: W, options: &'a Options) -> Self { + pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self { Self { out, names: crate::FastHashMap::default(), namer: proc::Namer::default(), options, - entry_point_io: Vec::new(), + pipeline_options, + entry_point_io: crate::FastHashMap::default(), named_expressions: crate::NamedExpressions::default(), wrapped: super::Wrapped::default(), written_committed_intersection: false, @@ -387,8 +388,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out)?; } + let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref()) + .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?; + // Write all entry points wrapped structs - for (index, ep) in module.entry_points.iter().enumerate() { + for index in ep_range.clone() { + let ep = &module.entry_points[index]; let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone(); let ep_io = self.write_ep_interface( module, @@ -397,7 +402,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { &ep_name, fragment_entry_point, )?; - self.entry_point_io.push(ep_io); + self.entry_point_io.insert(index, ep_io); } // Write all regular functions @@ -442,10 +447,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out)?; } - let mut entry_point_names = Vec::with_capacity(module.entry_points.len()); + let mut translated_ep_names = Vec::with_capacity(ep_range.len()); // Write all entry points - for (index, ep) in module.entry_points.iter().enumerate() { + for index in ep_range { + let ep = &module.entry_points[index]; let info = module_info.get_entry_point(index); if !self.options.fake_missing_bindings { @@ -462,7 +468,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } if let Some(err) = ep_error { - entry_point_names.push(Err(err)); + translated_ep_names.push(Err(err)); continue; } } @@ -493,10 +499,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out)?; } - entry_point_names.push(Ok(name)); + translated_ep_names.push(Ok(name)); } - Ok(super::ReflectionInfo { entry_point_names }) + Ok(super::ReflectionInfo { + entry_point_names: translated_ep_names, + }) } fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult { @@ -816,7 +824,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { ep_index: u16, ) -> BackendResult { let ep = &module.entry_points[ep_index as usize]; - let ep_input = match self.entry_point_io[ep_index as usize].input.take() { + let ep_input = match self + .entry_point_io + .get_mut(&(ep_index as usize)) + .unwrap() + .input + .take() + { Some(ep_input) => ep_input, None => return Ok(()), }; @@ -1432,7 +1446,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } back::FunctionType::EntryPoint(index) => { - if let Some(ref ep_output) = self.entry_point_io[index as usize].output { + if let Some(ref ep_output) = + self.entry_point_io.get(&(index as usize)).unwrap().output + { write!(self.out, "{}", ep_output.ty_name)?; } else { self.write_type(module, result.ty)?; @@ -1479,7 +1495,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } back::FunctionType::EntryPoint(ep_index) => { - if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input { + if let Some(ref ep_input) = + self.entry_point_io.get(&(ep_index as usize)).unwrap().input + { write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?; } else { let stage = module.entry_points[ep_index as usize].stage; @@ -1501,7 +1519,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } if need_workgroup_variables_initialization { - if self.entry_point_io[ep_index as usize].input.is_some() + if self + .entry_point_io + .get(&(ep_index as usize)) + .unwrap() + .input + .is_some() || !func.arguments.is_empty() { write!(self.out, ", ")?; @@ -1870,9 +1893,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { // for entry point returns, we may need to reshuffle the outputs into a different struct let ep_output = match func_ctx.ty { back::FunctionType::Function(_) => None, - back::FunctionType::EntryPoint(index) => { - self.entry_point_io[index as usize].output.as_ref() - } + back::FunctionType::EntryPoint(index) => self + .entry_point_io + .get(&(index as usize)) + .unwrap() + .output + .as_ref(), }; let final_name = match ep_output { Some(ep_output) => { diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index 8eee9b6ff69..175c5481b33 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -79,6 +79,33 @@ impl core::fmt::Display for Level { } } +/// Locate the entry point(s) to write. +/// +/// If `entry_point` is given, and the specified entry point exists, returns a +/// length-1 `Range` containing the index of that entry point. If no +/// `entry_point` is given, returns the complete range of entry point indices. +/// If `entry_point` is given but does not exist, returns an error. +#[cfg(any(hlsl_out, msl_out))] +fn get_entry_points( + module: &crate::ir::Module, + entry_point: Option<&(crate::ir::ShaderStage, String)>, +) -> Result, (crate::ir::ShaderStage, String)> { + use alloc::borrow::ToOwned; + + if let Some(&(stage, ref name)) = entry_point { + let Some(ep_index) = module + .entry_points + .iter() + .position(|ep| ep.stage == stage && ep.name == *name) + else { + return Err((stage, name.to_owned())); + }; + Ok(ep_index..ep_index + 1) + } else { + Ok(0..module.entry_points.len()) + } +} + /// Whether we're generating an entry point or a regular function. /// /// Backend languages often require different code for a [`Function`] diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 01ac1ac419b..7bc8289b9b8 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -52,7 +52,7 @@ use alloc::{ }; use core::fmt::{Error as FmtError, Write}; -use crate::{arena::Handle, proc::index, valid::ModuleInfo}; +use crate::{arena::Handle, ir, proc::index, valid::ModuleInfo}; mod keywords; pub mod sampler; @@ -184,7 +184,7 @@ pub enum Error { #[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")] UnsupportedWriteableStorageBuffer, #[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")] - UnsupportedWriteableStorageTexture(crate::ShaderStage), + UnsupportedWriteableStorageTexture(ir::ShaderStage), #[error("can not use read-write storage textures prior to MSL 1.2")] UnsupportedRWStorageTexture, #[error("array of '{0}' is not supported for target MSL version")] @@ -199,6 +199,8 @@ pub enum Error { UnsupportedBitCast(crate::TypeInner), #[error(transparent)] ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError), + #[error("entry point with stage {0:?} and name '{1}' not found")] + EntryPointNotFound(ir::ShaderStage, String), } #[derive(Clone, Debug, PartialEq, thiserror::Error)] @@ -420,6 +422,15 @@ pub struct VertexBufferMapping { #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr(feature = "deserialize", serde(default))] pub struct PipelineOptions { + /// The entry point to write. + /// + /// Entry points are identified by a shader stage specification, + /// and a name. + /// + /// If `None`, all entry points will be written. If `Some` and the entry + /// point is not found, an error will be thrown while writing. + pub entry_point: Option<(ir::ShaderStage, String)>, + /// Allow `BuiltIn::PointSize` and inject it if doesn't exist. /// /// Metal doesn't like this for non-point primitive topologies and requires it for @@ -737,5 +748,5 @@ pub fn write_string( #[test] fn test_error_size() { - assert_eq!(size_of::(), 32); + assert_eq!(size_of::(), 40); } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index f1cef93db8f..f05e5c233aa 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -16,7 +16,7 @@ use half::f16; use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo}; use crate::{ arena::{Handle, HandleSet}, - back::{self, Baked}, + back::{self, get_entry_points, Baked}, common, proc::{ self, @@ -5872,10 +5872,15 @@ template self.named_expressions.clear(); } + let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref()) + .map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?; + let mut info = TranslationInfo { - entry_point_names: Vec::with_capacity(module.entry_points.len()), + entry_point_names: Vec::with_capacity(ep_range.len()), }; - for (ep_index, ep) in module.entry_points.iter().enumerate() { + + for ep_index in ep_range { + let ep = &module.entry_points[ep_index]; let fun = &ep.function; let fun_info = mod_info.get_entry_point(ep_index); let mut ep_error = None; @@ -7076,8 +7081,8 @@ fn test_stack_size() { } let stack_size = addresses_end - addresses_start; // check the size (in debug only) - // last observed macOS value: 20528 (CI) - if !(11000..=25000).contains(&stack_size) { + // last observed macOS value: 25904 (CI), 2025-04-29 + if !(11000..=27000).contains(&stack_size) { panic!("`put_expression` stack size {stack_size} has changed!"); } } diff --git a/naga/tests/naga/snapshots.rs b/naga/tests/naga/snapshots.rs index e5c066c9e94..931136ed8d2 100644 --- a/naga/tests/naga/snapshots.rs +++ b/naga/tests/naga/snapshots.rs @@ -741,7 +741,8 @@ fn write_output_hlsl( .expect("override evaluation failed"); let mut buffer = String::new(); - let mut writer = hlsl::Writer::new(&mut buffer, options); + let pipeline_options = Default::default(); + let mut writer = hlsl::Writer::new(&mut buffer, options, &pipeline_options); let reflection_info = writer .write(&module, &info, frag_ep.as_ref()) .expect("HLSL write failed"); diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 213f64f1d5a..3d8100b9c0b 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -299,9 +299,13 @@ impl super::Device { &layout.naga_options }; + let pipeline_options = hlsl::PipelineOptions { + entry_point: Some((naga_stage, stage.entry_point.to_string())), + }; + //TODO: reuse the writer let mut source = String::new(); - let mut writer = hlsl::Writer::new(&mut source, naga_options); + let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options); let reflection_info = { profiling::scope!("naga::back::hlsl::write"); writer @@ -315,13 +319,7 @@ impl super::Device { naga_options.shader_model.to_str() ); - let ep_index = module - .entry_points - .iter() - .position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point) - .ok_or(crate::PipelineError::EntryPoint(naga_stage))?; - - let raw_ep = reflection_info.entry_point_names[ep_index] + let raw_ep = reflection_info.entry_point_names[0] .as_ref() .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?; diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 3041862c652..6ab22b0c3e3 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -181,6 +181,7 @@ impl super::Device { }; let pipeline_options = naga::back::msl::PipelineOptions { + entry_point: Some((naga_stage, stage.entry_point.to_owned())), allow_and_force_point_size: match primitive_class { MTLPrimitiveTopologyClass::Point => true, _ => false, @@ -223,7 +224,7 @@ impl super::Device { .position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point) .ok_or(crate::PipelineError::EntryPoint(naga_stage))?; let ep = &module.entry_points[ep_index]; - let ep_name = info.entry_point_names[ep_index] + let translated_ep_name = info.entry_point_names[0] .as_ref() .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{}", e)))?; @@ -233,10 +234,12 @@ impl super::Device { depth: ep.workgroup_size[2] as _, }; - let function = library.get_function(ep_name, None).map_err(|e| { - log::error!("get_function: {:?}", e); - crate::PipelineError::EntryPoint(naga_stage) - })?; + let function = library + .get_function(translated_ep_name, None) + .map_err(|e| { + log::error!("get_function: {:?}", e); + crate::PipelineError::EntryPoint(naga_stage) + })?; // collect sizes indices, immutable buffers, and work group memory sizes let ep_info = &module_info.get_entry_point(ep_index);