Skip to content

[naga] Write only the current entrypoint #7626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ Naga now infers the correct binding layout when a resource appears only in an as
#### Naga

- Mark `readonly_and_readwrite_storage_textures` & `packed_4x8_integer_dot_product` language extensions as implemented. By @teoxoy in [#7543](https://github.com/gfx-rs/wgpu/pull/7543)
- `naga::back::hlsl::Writer::new` has a new `pipeline_options` argument. `hlsl::PipelineOptions::default()` can be passed as a default. The `shader_stage` and `entry_point` members of `pipeline_options` can be used to write only a single entry point when using the HLSL and MSL backends (GLSL and SPIR-V already had this functionality). The Metal and DX12 HALs now write only a single entry point when loading shaders. By @andyleiserson in [#7626](https://github.com/gfx-rs/wgpu/pull/7626).

#### D3D12

Expand Down
4 changes: 3 additions & 1 deletion benches/benches/wgpu-benchmark/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,9 @@ fn backends(c: &mut Criterion) {
let options = naga::back::hlsl::Options::default();
let mut string = String::new();
for input in &inputs.inner {
let mut writer = naga::back::hlsl::Writer::new(&mut string, &options);
let pipeline_options = Default::default();
let mut writer =
naga::back::hlsl::Writer::new(&mut string, &options, &pipeline_options);
let _ = writer.write(
input.module.as_ref().unwrap(),
input.module_info.as_ref().unwrap(),
Expand Down
3 changes: 2 additions & 1 deletion naga-cli/src/bin/naga.rs
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,8 @@ fn write_output(
.unwrap_pretty();

let mut buffer = String::new();
let mut writer = hlsl::Writer::new(&mut buffer, &params.hlsl);
let pipeline_options = Default::default();
let mut writer = hlsl::Writer::new(&mut buffer, &params.hlsl, &pipeline_options);
writer.write(&module, &info, None).unwrap_pretty();
fs::write(output_path, buffer)?;
}
Expand Down
3 changes: 2 additions & 1 deletion naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,8 @@ pub struct PipelineOptions {
pub shader_stage: ShaderStage,
/// The name of the entry point.
///
/// If no entry point that matches is found while creating a [`Writer`], a error will be thrown.
/// If no entry point that matches is found while creating a [`Writer`], an
/// error will be thrown.
pub entry_point: String,
/// How many views to render to, if doing multiview rendering.
pub multiview: Option<core::num::NonZeroU32>,
Expand Down
24 changes: 22 additions & 2 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ use core::fmt::Error as FmtError;

use thiserror::Error;

use crate::{back, proc};
use crate::{back, ir, proc};

#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
Expand Down Expand Up @@ -434,6 +434,22 @@ pub struct ReflectionInfo {
pub entry_point_names: Vec<Result<String, EntryPointError>>,
}

/// A subset of options that are meant to be changed per pipeline.
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(feature = "deserialize", serde(default))]
pub struct PipelineOptions {
/// The entry point to write.
///
/// Entry points are identified by a shader stage specification,
/// and a name.
///
/// If `None`, all entry points will be written. If `Some` and the entry
/// point is not found, an error will be thrown while writing.
pub entry_point: Option<(ir::ShaderStage, String)>,
}

#[derive(Error, Debug)]
pub enum Error {
#[error(transparent)]
Expand All @@ -448,6 +464,8 @@ pub enum Error {
Override,
#[error(transparent)]
ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
#[error("entry point with stage {0:?} and name '{1}' not found")]
EntryPointNotFound(ir::ShaderStage, String),
}

#[derive(PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -519,8 +537,10 @@ pub struct Writer<'a, W> {
namer: proc::Namer,
/// HLSL backend options
options: &'a Options,
/// Per-stage backend options
pipeline_options: &'a PipelineOptions,
/// Information about entry point arguments and result types.
entry_point_io: Vec<writer::EntryPointInterface>,
entry_point_io: crate::FastHashMap<usize, writer::EntryPointInterface>,
/// Set of expressions that have associated temporary variables
named_expressions: crate::NamedExpressions,
wrapped: Wrapped,
Expand Down
62 changes: 44 additions & 18 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ use super::{
WrappedZeroValue,
},
storage::StoreValue,
BackendResult, Error, FragmentEntryPoint, Options, ShaderModel,
BackendResult, Error, FragmentEntryPoint, Options, PipelineOptions, ShaderModel,
};
use crate::{
back::{self, Baked},
back::{self, get_entry_points, Baked},
common,
proc::{self, index, NameKey},
valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
Expand Down Expand Up @@ -123,13 +123,14 @@ struct BindingArraySamplerInfo {
}

impl<'a, W: fmt::Write> super::Writer<'a, W> {
pub fn new(out: W, options: &'a Options) -> Self {
pub fn new(out: W, options: &'a Options, pipeline_options: &'a PipelineOptions) -> Self {
Self {
out,
names: crate::FastHashMap::default(),
namer: proc::Namer::default(),
options,
entry_point_io: Vec::new(),
pipeline_options,
entry_point_io: crate::FastHashMap::default(),
named_expressions: crate::NamedExpressions::default(),
wrapped: super::Wrapped::default(),
written_committed_intersection: false,
Expand Down Expand Up @@ -387,8 +388,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out)?;
}

let ep_range = get_entry_points(module, self.pipeline_options.entry_point.as_ref())
.map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

// Write all entry points wrapped structs
for (index, ep) in module.entry_points.iter().enumerate() {
for index in ep_range.clone() {
let ep = &module.entry_points[index];
let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
let ep_io = self.write_ep_interface(
module,
Expand All @@ -397,7 +402,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
&ep_name,
fragment_entry_point,
)?;
self.entry_point_io.push(ep_io);
self.entry_point_io.insert(index, ep_io);
}

// Write all regular functions
Expand Down Expand Up @@ -442,10 +447,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out)?;
}

let mut entry_point_names = Vec::with_capacity(module.entry_points.len());
let mut translated_ep_names = Vec::with_capacity(ep_range.len());

// Write all entry points
for (index, ep) in module.entry_points.iter().enumerate() {
for index in ep_range {
let ep = &module.entry_points[index];
let info = module_info.get_entry_point(index);

if !self.options.fake_missing_bindings {
Expand All @@ -462,7 +468,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
if let Some(err) = ep_error {
entry_point_names.push(Err(err));
translated_ep_names.push(Err(err));
continue;
}
}
Expand Down Expand Up @@ -493,10 +499,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out)?;
}

entry_point_names.push(Ok(name));
translated_ep_names.push(Ok(name));
}

Ok(super::ReflectionInfo { entry_point_names })
Ok(super::ReflectionInfo {
entry_point_names: translated_ep_names,
})
}

fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
Expand Down Expand Up @@ -816,7 +824,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
ep_index: u16,
) -> BackendResult {
let ep = &module.entry_points[ep_index as usize];
let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
let ep_input = match self
.entry_point_io
.get_mut(&(ep_index as usize))
.unwrap()
.input
.take()
{
Some(ep_input) => ep_input,
None => return Ok(()),
};
Expand Down Expand Up @@ -1432,7 +1446,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
back::FunctionType::EntryPoint(index) => {
if let Some(ref ep_output) = self.entry_point_io[index as usize].output {
if let Some(ref ep_output) =
self.entry_point_io.get(&(index as usize)).unwrap().output
{
write!(self.out, "{}", ep_output.ty_name)?;
} else {
self.write_type(module, result.ty)?;
Expand Down Expand Up @@ -1479,7 +1495,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
back::FunctionType::EntryPoint(ep_index) => {
if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
if let Some(ref ep_input) =
self.entry_point_io.get(&(ep_index as usize)).unwrap().input
{
write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
} else {
let stage = module.entry_points[ep_index as usize].stage;
Expand All @@ -1501,7 +1519,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
if need_workgroup_variables_initialization {
if self.entry_point_io[ep_index as usize].input.is_some()
if self
.entry_point_io
.get(&(ep_index as usize))
.unwrap()
.input
.is_some()
|| !func.arguments.is_empty()
{
write!(self.out, ", ")?;
Expand Down Expand Up @@ -1870,9 +1893,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// for entry point returns, we may need to reshuffle the outputs into a different struct
let ep_output = match func_ctx.ty {
back::FunctionType::Function(_) => None,
back::FunctionType::EntryPoint(index) => {
self.entry_point_io[index as usize].output.as_ref()
}
back::FunctionType::EntryPoint(index) => self
.entry_point_io
.get(&(index as usize))
.unwrap()
.output
.as_ref(),
};
let final_name = match ep_output {
Some(ep_output) => {
Expand Down
27 changes: 27 additions & 0 deletions naga/src/back/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,33 @@ impl core::fmt::Display for Level {
}
}

/// Locate the entry point(s) to write.
///
/// If `entry_point` is given, and the specified entry point exists, returns a
/// length-1 `Range` containing the index of that entry point. If no
/// `entry_point` is given, returns the complete range of entry point indices.
/// If `entry_point` is given but does not exist, returns an error.
#[cfg(any(hlsl_out, msl_out))]
fn get_entry_points(
module: &crate::ir::Module,
entry_point: Option<&(crate::ir::ShaderStage, String)>,
) -> Result<core::ops::Range<usize>, (crate::ir::ShaderStage, String)> {
use alloc::borrow::ToOwned;

if let Some(&(stage, ref name)) = entry_point {
let Some(ep_index) = module
.entry_points
.iter()
.position(|ep| ep.stage == stage && ep.name == *name)
else {
return Err((stage, name.to_owned()));
};
Ok(ep_index..ep_index + 1)
} else {
Ok(0..module.entry_points.len())
}
}

/// Whether we're generating an entry point or a regular function.
///
/// Backend languages often require different code for a [`Function`]
Expand Down
17 changes: 14 additions & 3 deletions naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use alloc::{
};
use core::fmt::{Error as FmtError, Write};

use crate::{arena::Handle, proc::index, valid::ModuleInfo};
use crate::{arena::Handle, ir, proc::index, valid::ModuleInfo};

mod keywords;
pub mod sampler;
Expand Down Expand Up @@ -184,7 +184,7 @@ pub enum Error {
#[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")]
UnsupportedWriteableStorageBuffer,
#[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")]
UnsupportedWriteableStorageTexture(crate::ShaderStage),
UnsupportedWriteableStorageTexture(ir::ShaderStage),
#[error("can not use read-write storage textures prior to MSL 1.2")]
UnsupportedRWStorageTexture,
#[error("array of '{0}' is not supported for target MSL version")]
Expand All @@ -199,6 +199,8 @@ pub enum Error {
UnsupportedBitCast(crate::TypeInner),
#[error(transparent)]
ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError),
#[error("entry point with stage {0:?} and name '{1}' not found")]
EntryPointNotFound(ir::ShaderStage, String),
}

#[derive(Clone, Debug, PartialEq, thiserror::Error)]
Expand Down Expand Up @@ -420,6 +422,15 @@ pub struct VertexBufferMapping {
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(feature = "deserialize", serde(default))]
pub struct PipelineOptions {
/// The entry point to write.
///
/// Entry points are identified by a shader stage specification,
/// and a name.
///
/// If `None`, all entry points will be written. If `Some` and the entry
/// point is not found, an error will be thrown while writing.
pub entry_point: Option<(ir::ShaderStage, String)>,

/// Allow `BuiltIn::PointSize` and inject it if doesn't exist.
///
/// Metal doesn't like this for non-point primitive topologies and requires it for
Expand Down Expand Up @@ -737,5 +748,5 @@ pub fn write_string(

#[test]
fn test_error_size() {
assert_eq!(size_of::<Error>(), 32);
assert_eq!(size_of::<Error>(), 40);
}
15 changes: 10 additions & 5 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use half::f16;
use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
use crate::{
arena::{Handle, HandleSet},
back::{self, Baked},
back::{self, get_entry_points, Baked},
common,
proc::{
self,
Expand Down Expand Up @@ -5872,10 +5872,15 @@ template <typename A>
self.named_expressions.clear();
}

let ep_range = get_entry_points(module, pipeline_options.entry_point.as_ref())
.map_err(|(stage, name)| Error::EntryPointNotFound(stage, name))?;

let mut info = TranslationInfo {
entry_point_names: Vec::with_capacity(module.entry_points.len()),
entry_point_names: Vec::with_capacity(ep_range.len()),
};
for (ep_index, ep) in module.entry_points.iter().enumerate() {

for ep_index in ep_range {
let ep = &module.entry_points[ep_index];
let fun = &ep.function;
let fun_info = mod_info.get_entry_point(ep_index);
let mut ep_error = None;
Expand Down Expand Up @@ -7076,8 +7081,8 @@ fn test_stack_size() {
}
let stack_size = addresses_end - addresses_start;
// check the size (in debug only)
// last observed macOS value: 20528 (CI)
if !(11000..=25000).contains(&stack_size) {
// last observed macOS value: 25904 (CI), 2025-04-29
if !(11000..=27000).contains(&stack_size) {
panic!("`put_expression` stack size {stack_size} has changed!");
}
}
Expand Down
3 changes: 2 additions & 1 deletion naga/tests/naga/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,8 @@ fn write_output_hlsl(
.expect("override evaluation failed");

let mut buffer = String::new();
let mut writer = hlsl::Writer::new(&mut buffer, options);
let pipeline_options = Default::default();
let mut writer = hlsl::Writer::new(&mut buffer, options, &pipeline_options);
let reflection_info = writer
.write(&module, &info, frag_ep.as_ref())
.expect("HLSL write failed");
Expand Down
Loading