diff --git a/resources/Materials/TestSuite/_options.mtlx b/resources/Materials/TestSuite/_options.mtlx index 49b0b7cc39..e9dcc934e5 100644 --- a/resources/Materials/TestSuite/_options.mtlx +++ b/resources/Materials/TestSuite/_options.mtlx @@ -84,5 +84,12 @@ Default is false to avoid overhead when not profiling. --> + + + diff --git a/source/MaterialXGenHw/HwShaderGenerator.cpp b/source/MaterialXGenHw/HwShaderGenerator.cpp index 3d4d19e2b7..8bb17811f8 100644 --- a/source/MaterialXGenHw/HwShaderGenerator.cpp +++ b/source/MaterialXGenHw/HwShaderGenerator.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -391,7 +392,9 @@ void HwShaderGenerator::addStageLightingUniforms(GenContext& context, ShaderStag numActiveLights->setValue(Value::createValue(0)); } } -ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const +ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph( + const NodeGraph& nodegraph, + std::unique_ptr permutation) const { vector outputs = nodegraph.getActiveOutputs(); if (outputs.empty()) @@ -404,6 +407,7 @@ ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(const Node // Use specialized implementations for nodes that output light shaders and materials. if (outputType == Type::LIGHTSHADER) { + // HwLightCompoundNode doesn't support permutations (light shaders don't have lobe weights) return HwLightCompoundNode::create(); } if (outputType == Type::MATERIAL) @@ -411,8 +415,7 @@ ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(const Node return HwMaterialCompoundNode::create(); } - // Use the base implementation for nodes that output other types. 
- return CompoundNode::create(); + return CompoundNode::create(std::move(permutation)); } void HwShaderGenerator::emitClosureDataArg(const ShaderNode& node, GenContext& /*context*/, ShaderStage& stage) const diff --git a/source/MaterialXGenHw/HwShaderGenerator.h b/source/MaterialXGenHw/HwShaderGenerator.h index d42f7dbb1f..6a72aa250b 100644 --- a/source/MaterialXGenHw/HwShaderGenerator.h +++ b/source/MaterialXGenHw/HwShaderGenerator.h @@ -64,7 +64,9 @@ class MX_GENHW_API HwShaderGenerator : public ShaderGenerator virtual string getVertexDataPrefix(const VariableBlock& vertexData) const = 0; /// Create the shader node implementation for a NodeGraph implementation. - ShaderNodeImplPtr createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const override; + ShaderNodeImplPtr createShaderNodeImplForNodeGraph( + const NodeGraph& nodegraph, + std::unique_ptr permutation) const override; // Note : the order must match the order defined in libraries/pbrlib/genglsl/lib/mx_closure_type.glsl // TODO : investigate build time mechanism for ensuring these stay in sync. 
diff --git a/source/MaterialXGenHw/Nodes/HwLightCompoundNode.cpp b/source/MaterialXGenHw/Nodes/HwLightCompoundNode.cpp index 55194c32fe..d71f4bfe3c 100644 --- a/source/MaterialXGenHw/Nodes/HwLightCompoundNode.cpp +++ b/source/MaterialXGenHw/Nodes/HwLightCompoundNode.cpp @@ -11,6 +11,7 @@ MATERIALX_NAMESPACE_BEGIN HwLightCompoundNode::HwLightCompoundNode() : + CompoundNode(nullptr), _lightUniforms(HW::LIGHT_DATA, EMPTY_STRING) { } diff --git a/source/MaterialXGenMdl/MdlShaderGenerator.cpp b/source/MaterialXGenMdl/MdlShaderGenerator.cpp index 5e89ac38c1..6ba7c3dd86 100644 --- a/source/MaterialXGenMdl/MdlShaderGenerator.cpp +++ b/source/MaterialXGenMdl/MdlShaderGenerator.cpp @@ -339,7 +339,9 @@ ShaderPtr MdlShaderGenerator::generate(const string& name, ElementPtr element, G return shader; } -ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const +ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForNodeGraph( + const NodeGraph& nodegraph, + std::unique_ptr permutation) const { vector outputs = nodegraph.getActiveOutputs(); if (outputs.empty()) @@ -349,13 +351,12 @@ ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForNodeGraph(const Nod const TypeDesc outputType = _typeSystem->getType(outputs[0]->getType()); - ShaderNodeImplPtr impl; - // Use a compound implementation. 
+ // Use a compound implementation with permutation support if (outputType.isClosure()) { - return ClosureCompoundNodeMdl::create(); + return ClosureCompoundNodeMdl::create(std::move(permutation)); } - return CompoundNodeMdl::create(); + return CompoundNodeMdl::create(std::move(permutation)); } ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForImplementation(const Implementation& implElement) const diff --git a/source/MaterialXGenMdl/MdlShaderGenerator.h b/source/MaterialXGenMdl/MdlShaderGenerator.h index 04415798c3..cd8ff94704 100644 --- a/source/MaterialXGenMdl/MdlShaderGenerator.h +++ b/source/MaterialXGenMdl/MdlShaderGenerator.h @@ -75,7 +75,9 @@ class MX_GENMDL_API MdlShaderGenerator : public ShaderGenerator ShaderPtr generate(const string& name, ElementPtr element, GenContext& context) const override; /// Create the shader node implementation for a NodeGraph implementation. - ShaderNodeImplPtr createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const override; + ShaderNodeImplPtr createShaderNodeImplForNodeGraph( + const NodeGraph& nodegraph, + std::unique_ptr permutation) const override; /// Create the shader node implementation for an mplementation implementation. 
ShaderNodeImplPtr createShaderNodeImplForImplementation(const Implementation& implementation) const override; diff --git a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp index b5c16456c0..78c1b1ec60 100644 --- a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp @@ -13,9 +13,14 @@ MATERIALX_NAMESPACE_BEGIN -ShaderNodeImplPtr ClosureCompoundNodeMdl::create() +ShaderNodeImplPtr ClosureCompoundNodeMdl::create(std::unique_ptr permutation) +{ + return std::make_shared(std::move(permutation)); +} + +ClosureCompoundNodeMdl::ClosureCompoundNodeMdl(std::unique_ptr permutation) : + CompoundNodeMdl(std::move(permutation)) { - return std::make_shared(); } void ClosureCompoundNodeMdl::addClassification(ShaderNode& node) const diff --git a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.h b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.h index 633ed4553e..03a04628d0 100644 --- a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.h +++ b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.h @@ -15,11 +15,14 @@ MATERIALX_NAMESPACE_BEGIN class MX_GENMDL_API ClosureCompoundNodeMdl : public CompoundNodeMdl { public: - static ShaderNodeImplPtr create(); + /// Create with permutation (may be nullptr). 
+ static ShaderNodeImplPtr create(std::unique_ptr permutation); void addClassification(ShaderNode& node) const override; void emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override; void emitFunctionCall(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override; + + explicit ClosureCompoundNodeMdl(std::unique_ptr permutation); }; MATERIALX_NAMESPACE_END diff --git a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp index ea15eca7c2..da2d617cbf 100644 --- a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp @@ -16,9 +16,14 @@ MATERIALX_NAMESPACE_BEGIN const string CompoundNodeMdl::GEN_USER_DATA_RETURN_STRUCT_FIELD_NAME = "returnStructFieldName"; -ShaderNodeImplPtr CompoundNodeMdl::create() +ShaderNodeImplPtr CompoundNodeMdl::create(std::unique_ptr permutation) +{ + return std::make_shared(std::move(permutation)); +} + +CompoundNodeMdl::CompoundNodeMdl(std::unique_ptr permutation) : + CompoundNode(std::move(permutation)) { - return std::make_shared(); } void CompoundNodeMdl::initialize(const InterfaceElement& element, GenContext& context) diff --git a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h index 7bfc70fbe6..2923c6f2ee 100644 --- a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h +++ b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h @@ -30,7 +30,8 @@ using GenUserDataStringPtr = std::shared_ptr; class MX_GENMDL_API CompoundNodeMdl : public CompoundNode { public: - static ShaderNodeImplPtr create(); + /// Create with permutation (may be nullptr). 
+ static ShaderNodeImplPtr create(std::unique_ptr permutation); void initialize(const InterfaceElement& element, GenContext& context) override; void emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override; @@ -39,6 +40,8 @@ class MX_GENMDL_API CompoundNodeMdl : public CompoundNode bool isReturnStruct() const { return !_returnStruct.empty(); } bool unrollReturnStructMembers() const { return _unrollReturnStructMembers; } + explicit CompoundNodeMdl(std::unique_ptr permutation); + protected: void emitFunctionSignature(const ShaderNode& node, GenContext& context, ShaderStage& stage) const; diff --git a/source/MaterialXGenShader/GenContext.cpp b/source/MaterialXGenShader/GenContext.cpp index 3fea10c930..083bb52925 100644 --- a/source/MaterialXGenShader/GenContext.cpp +++ b/source/MaterialXGenShader/GenContext.cpp @@ -62,6 +62,7 @@ void GenContext::getNodeImplementationNames(StringSet& names) void GenContext::clearNodeImplementations() { _nodeImpls.clear(); + _nodeGraphTopologyCache.clear(); } void GenContext::clearUserData() diff --git a/source/MaterialXGenShader/GenContext.h b/source/MaterialXGenShader/GenContext.h index 0e5e87804b..801f7ccae7 100644 --- a/source/MaterialXGenShader/GenContext.h +++ b/source/MaterialXGenShader/GenContext.h @@ -13,6 +13,7 @@ #include #include +#include #include #include @@ -30,7 +31,7 @@ class MX_GENSHADER_API GenContext { public: /// Constructor. - GenContext(ShaderGeneratorPtr sg); + explicit GenContext(ShaderGeneratorPtr sg); /// Return shader generatior. ShaderGenerator& getShaderGenerator() @@ -204,6 +205,12 @@ class MX_GENSHADER_API GenContext return _applicationVariableHandler; } + /// Return the node graph topology cache for early pruning optimizations. 
+ NodeGraphTopologyCache& getNodeGraphTopologyCache() + { + return _nodeGraphTopologyCache; + } + protected: GenContext() = delete; @@ -219,6 +226,8 @@ class MX_GENSHADER_API GenContext vector _parentNodes; + NodeGraphTopologyCache _nodeGraphTopologyCache; + ApplicationVariableHandler _applicationVariableHandler; }; diff --git a/source/MaterialXGenShader/GenOptions.h b/source/MaterialXGenShader/GenOptions.h index 034ab782c3..6a87da8776 100644 --- a/source/MaterialXGenShader/GenOptions.h +++ b/source/MaterialXGenShader/GenOptions.h @@ -96,6 +96,7 @@ class MX_GENSHADER_API GenOptions hwWriteAlbedoTable(false), hwWriteEnvPrefilter(false), hwImplicitBitangents(true), + enableLobePruning(false), oslImplicitSurfaceShaderConversion(true), oslConnectCiWrapper(false) { @@ -204,6 +205,16 @@ class MX_GENSHADER_API GenOptions /// inside the bitangent node. bool hwImplicitBitangents; + /// Enable lobe pruning during ShaderGraph construction. + /// When enabled, performs topology analysis on NodeGraphs to identify + /// "permutation-defining" inputs (e.g., mix weights). If these inputs + /// are constant 0 or 1 at a call site, the corresponding dead branches + /// are pruned during ShaderGraph creation, avoiding unnecessary node + /// instantiation. This produces more compact shaders at the cost of + /// more ShaderNodeImpl permutations. + /// Defaults to false. + bool enableLobePruning; + // Enables OSL conversion of surfaceshader struct to closure color. // Defaults to true. 
bool oslImplicitSurfaceShaderConversion; diff --git a/source/MaterialXGenShader/NodeGraphTopology.cpp b/source/MaterialXGenShader/NodeGraphTopology.cpp new file mode 100644 index 0000000000..1419ee3e6f --- /dev/null +++ b/source/MaterialXGenShader/NodeGraphTopology.cpp @@ -0,0 +1,423 @@ +// +// Copyright Contributors to the MaterialX Project +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include +#include +#include + +#include + +MATERIALX_NAMESPACE_BEGIN + +namespace +{ +// PBR nodes with a weight parameter -- when weight=0, the node can be pruned. +// TODO: Could this allowlist be replaced by checking output type == BSDF and the +// presence of a topological weight input (see isTopologicalInput)? +const std::unordered_set kWeightedPbrNodes = { + "oren_nayar_diffuse_bsdf", + "compensating_oren_nayar_diffuse_bsdf", + "burley_diffuse_bsdf", + "conductor_bsdf", + "subsurface_bsdf", + "translucent_bsdf", + "dielectric_bsdf", + "generalized_schlick_bsdf", + "sheen_bsdf", + "dielectric_tf_bsdf", + "generalized_schlick_tf_82_bsdf", + "sheen_zeltner_bsdf", +}; + +} // anonymous namespace + +NodeGraphTopology::NodeGraphTopology(const NodeGraph& nodeGraph) +{ + NodeDefPtr nodeDef = nodeGraph.getNodeDef(); + if (!nodeDef) + { + throw ExceptionShaderGenError("Can't find nodedef for nodegraph '" + nodeGraph.getName() + "'"); + } + + // Build per-node info (downstream ref counts and upstream dependencies) + buildNodeInfos(nodeGraph); + + // Scan all nodes in the NodeGraph for optimization opportunities + for (const NodePtr& node : nodeGraph.getNodes()) + { + const string& category = node->getCategory(); + + if (category == "mix") + { + // Mix nodes: mix input at 0 or 1 can eliminate a branch + InputPtr mixInput = node->getActiveInput("mix"); + if (mixInput && isTopologicalInput(mixInput, nodeDef)) + { + const string& interfaceName = mixInput->getInterfaceName(); + if (_topologicalInputs.find(interfaceName) == _topologicalInputs.end()) + { + 
_topologicalInputs.emplace(interfaceName, TopologicalInput(interfaceName, node, mixInput)); + } + } + } + else if (category == "multiply") + { + // Multiply nodes: any float input at 0 zeroes the output + for (const InputPtr& input : node->getActiveInputs()) + { + if (input && isTopologicalInput(input, nodeDef)) + { + const string& interfaceName = input->getInterfaceName(); + if (_topologicalInputs.find(interfaceName) == _topologicalInputs.end()) + { + _topologicalInputs.emplace(interfaceName, TopologicalInput(interfaceName, node, input)); + } + } + } + } + else if (kWeightedPbrNodes.count(category)) + { + // PBR nodes: weight at 0 makes the node output dark/transparent + InputPtr weightInput = node->getActiveInput("weight"); + if (weightInput && isTopologicalInput(weightInput, nodeDef)) + { + const string& interfaceName = weightInput->getInterfaceName(); + if (_topologicalInputs.find(interfaceName) == _topologicalInputs.end()) + { + _topologicalInputs.emplace(interfaceName, TopologicalInput(interfaceName, node, weightInput)); + } + } + } + } +} + +bool NodeGraphTopology::isTopologicalInput(const InputPtr& input, const NodeDefPtr& nodeDef) +{ + // Must be connected to the NodeGraph interface + if (!input->hasInterfaceName()) + { + return false; + } + + // Must be float type + if (input->getType() != "float") + { + return false; + } + + // Get the corresponding NodeDef input + const string& interfaceName = input->getInterfaceName(); + InputPtr ndInput = nodeDef->getActiveInput(interfaceName); + if (!ndInput) + { + return false; + } + // Check for uimin=0, uimax=1 (indicates a 0-1 weight parameter) + if (!ndInput->hasAttribute("uimin") || !ndInput->hasAttribute("uimax")) + { + return false; + } + + try + { + float minVal = std::stof(ndInput->getAttribute("uimin")); + float maxVal = std::stof(ndInput->getAttribute("uimax")); + return (minVal == 0.0f && maxVal == 1.0f); + } + catch (...) 
+ { + return false; + } +} + +TopologicalInput::TopologicalInput( + const string& inputName_, + const NodePtr& node, + const InputPtr& input) : + inputName(inputName_) +{ + const string& category = node->getCategory(); + + if (category == "mix") + { + // For mix nodes: + // - When mix=0, the "fg" (foreground) branch loses a consumer + // - When mix=1, the "bg" (background) branch loses a consumer + // The mix node itself stays alive; we just decrement the unused input's upstream ref count + InputPtr bgInput = node->getActiveInput("bg"); + InputPtr fgInput = node->getActiveInput("fg"); + + if (fgInput && fgInput->hasNodeName()) + { + // mix=0 means fg branch loses this consumer + potentiallyPrunableAtValue[0].insert(fgInput->getNodeName()); + } + + if (bgInput && bgInput->hasNodeName()) + { + // mix=1 means bg branch loses this consumer + potentiallyPrunableAtValue[1].insert(bgInput->getNodeName()); + } + } + else if (category == "multiply") + { + // For multiply nodes with input=0: + // The multiply node stays alive (outputs 0), but the other inputs lose a consumer + for (const InputPtr& otherInput : node->getActiveInputs()) + { + if (otherInput != input && otherInput->hasNodeName()) + { + potentiallyPrunableAtValue[0].insert(otherInput->getNodeName()); + } + } + } + else if (kWeightedPbrNodes.count(category)) + { + // For PBR nodes with weight=0: + // The PBR node itself is unconditionally pruned (replaced with dark/transparent) + // Its upstream dependencies will be handled via ref count propagation + prunableAtValue[0].insert(node->getName()); + } +} + +void NodeGraphTopology::buildNodeInfos(const NodeGraph& nodeGraph) +{ + for (const NodePtr& node : nodeGraph.getNodes()) + { + const string& nodeName = node->getName(); + NodeInfo& nodeInfo = _nodeInfos[nodeName]; + + for (const InputPtr& input : node->getActiveInputs()) + { + if (input->hasNodeName()) + { + nodeInfo.upstreams.insert(input->getNodeName()); + } + } + + // Increment downstream ref counts once per 
unique upstream + for (const string& upstreamName : nodeInfo.upstreams) + { + _nodeInfos[upstreamName].downstreamRefCount++; + } + } + + // NodeGraph outputs are roots -- they also count as downstream consumers + for (const OutputPtr& output : nodeGraph.getOutputs()) + { + if (output->hasNodeName()) + { + _nodeInfos[output->getNodeName()].downstreamRefCount++; + } + } +} + +std::unique_ptr NodeGraphTopology::createPermutation(const Node& node) const +{ + if (_topologicalInputs.empty()) + { + return nullptr; + } + + string permutationKey; + std::unordered_set nodesToPrune; + bool hasOptimization = false; + + // Working copy of ref counts for this permutation + std::unordered_map downstreamRefCounts; + downstreamRefCounts.reserve(_nodeInfos.size()); + for (const auto& [name, nodeInfo] : _nodeInfos) + { + downstreamRefCounts[name] = nodeInfo.downstreamRefCount; + } + + std::vector worklist; + + // Mark a node as pruned and enqueue for upstream propagation. + auto pruneNode = [&nodesToPrune, &worklist](const string& nodeName) + { + if (nodesToPrune.insert(nodeName).second) + { + worklist.push_back(nodeName); + } + }; + + // A downstream consumer of `nodeName` has been eliminated. Decrement ref count + // and prune `nodeName` if no consumers remain. 
+ auto removeDownstream = [&downstreamRefCounts, &pruneNode](const string& nodeName) + { + auto itDownstreamRefCount = downstreamRefCounts.find(nodeName); + if (itDownstreamRefCount != downstreamRefCounts.end() + && itDownstreamRefCount->second > 0) + { + itDownstreamRefCount->second--; + if (itDownstreamRefCount->second == 0) + { + pruneNode(nodeName); + } + } + }; + + // First pass: build the key and collect initial dead nodes + for (const auto& [inputName, topoInput] : _topologicalInputs) + { + char flag = 'x'; // 'x' = not optimized (connected or intermediate value) + + auto applyConstantValue = [&pruneNode, &removeDownstream]( + const std::unordered_set& prunable, + const std::unordered_set& potentiallyPrunable) + { + for (const string& nodeName : prunable) + { + pruneNode(nodeName); + } + for (const string& nodeName : potentiallyPrunable) + { + removeDownstream(nodeName); + } + }; + + // Apply the effects of a topological input being constant (0 or 1). + // Takes topoInput by parameter (not capture) because capturing + // structured bindings requires C++20. 
+ auto applyConstantInput = [&applyConstantValue, &flag]( + const TopologicalInput& topo, Input& input) + { + if (!input.hasValue()) + { + return; + } + + const float value = input.getValue()->asA(); + if (value == 0.0f) + { + flag = '0'; + applyConstantValue(topo.prunableAtValue[0], topo.potentiallyPrunableAtValue[0]); + } + else if (value == 1.0f) + { + flag = '1'; + applyConstantValue(topo.prunableAtValue[1], topo.potentiallyPrunableAtValue[1]); + } + }; + + // Check if this input is connected on the node instance + if (InputPtr nodeInput = node.getInput(inputName)) + { + // If connected to another node, can't optimize + if (!( nodeInput->hasNodeName() + || nodeInput->hasOutputString() + || nodeInput->hasInterfaceName())) + { + applyConstantInput(topoInput, *nodeInput); + } + } + else if (NodeDefPtr nodeDef = node.getNodeDef()) // Input not set on node instance - check NodeDef default value + { + if (InputPtr defaultInput = nodeDef->getActiveInput(inputName)) + { + applyConstantInput(topoInput, *defaultInput); + } + } + + if (!permutationKey.empty()) + { + permutationKey += ","; + } + permutationKey += inputName + "=" + flag; + if (flag != 'x') + { + hasOptimization = true; + } + } + + // Worklist-driven DCE: when a node is pruned, decrement ref counts of + // its upstream dependencies. Prune any upstream whose count hits 0. 
+ while (!worklist.empty()) + { + string nodeName = worklist.back(); + worklist.pop_back(); + + auto itNodeInfo = _nodeInfos.find(nodeName); + if (itNodeInfo != _nodeInfos.end()) + { + for (const string& upstream : itNodeInfo->second.upstreams) + { + removeDownstream(upstream); + } + } + } + + if (!hasOptimization) + { + return nullptr; + } + + return std::make_unique( + std::move(permutationKey), std::move(nodesToPrune)); +} + +// +// NodeGraphTopologyCache implementation +// + +NodeGraphTopologyCache::~NodeGraphTopologyCache() = default; + +NodeGraphTopologyCache::NodeGraphTopologyCache(const NodeGraphTopologyCache& other) +{ + std::lock_guard lock(other._cacheMutex); + _cache = other._cache; +} + +NodeGraphTopologyCache& NodeGraphTopologyCache::operator=(const NodeGraphTopologyCache& other) +{ + if (this != &other) + { + std::scoped_lock lock(_cacheMutex, other._cacheMutex); + _cache = other._cache; + } + return *this; +} + +void NodeGraphTopologyCache::clear() +{ + std::lock_guard lock(_cacheMutex); + _cache.clear(); +} + +std::unique_ptr NodeGraphTopologyCache::createPermutation( + const NodeGraph& nodeGraph, const Node& node) +{ + const NodeGraphTopology& topology = getTopology(nodeGraph); + return topology.createPermutation(node); +} + +const NodeGraphTopology& NodeGraphTopologyCache::getTopology(const NodeGraph& nodeGraph) +{ + const string& ngName = nodeGraph.getName(); + + // Check cache first (with lock) + { + std::lock_guard lock(_cacheMutex); + auto it = _cache.find(ngName); + if (it != _cache.end()) + { + return *it->second; + } + } + + // Cache miss - construct outside lock to allow parallel construction + // of different topologies. Safe because emplace() won't overwrite. 
+ auto topology = std::make_shared(nodeGraph); + + std::lock_guard lock(_cacheMutex); + auto [it, inserted] = _cache.emplace(ngName, std::move(topology)); + return *it->second; +} + +MATERIALX_NAMESPACE_END diff --git a/source/MaterialXGenShader/NodeGraphTopology.h b/source/MaterialXGenShader/NodeGraphTopology.h new file mode 100644 index 0000000000..f28bd3d433 --- /dev/null +++ b/source/MaterialXGenShader/NodeGraphTopology.h @@ -0,0 +1,115 @@ +// +// Copyright Contributors to the MaterialX Project +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef MATERIALX_NODEGRAPHTOPOLOGY_H +#define MATERIALX_NODEGRAPHTOPOLOGY_H + +#include + +#include +#include + +#include +#include +#include +#include +#include + +MATERIALX_NAMESPACE_BEGIN + +/// @class NodeGraphPermutation +/// Represents a specific permutation of a NodeGraph based on call-site input values. +/// Lightweight object used for cache key computation before ShaderNodeImpl creation. +class MX_GENSHADER_API NodeGraphPermutation +{ + public: + NodeGraphPermutation(string key, std::unordered_set nodesToPrune) : + _key(std::move(key)), _nodesToPrune(std::move(nodesToPrune)) { } + + /// Return the permutation key (e.g., "coat=0,sheen=x"). + const string& getKey() const { return _key; } + + /// Check whether a node should be pruned for this permutation. + bool shouldPrune(const string& nodeName) const { return _nodesToPrune.count(nodeName) != 0; } + + private: + const string _key; + const std::unordered_set _nodesToPrune; +}; + +/// Describes a single topological input and the nodes it can eliminate. +/// Arrays are indexed by the constant value (0 or 1). 
+struct MX_GENSHADER_API TopologicalInput +{ + TopologicalInput(const string& inputName, const NodePtr& node, const InputPtr& input); + + string inputName; + using UnorderedStringSet = std::unordered_set; + + // If the value of this attribute is hardcoded to 0 or 1, + // these nodes are guaranteed to become prunable + UnorderedStringSet prunableAtValue[2]; + + // If the value of this attribute is hardcoded to 0 or 1, + // these nodes may become prunable, depending on their downstream connections + UnorderedStringSet potentiallyPrunableAtValue[2]; +}; + +/// @class NodeGraphTopology +/// Analyzes a NodeGraph to identify "topological" inputs that, when constant +/// (0 or 1), can eliminate entire branches of the graph. +class MX_GENSHADER_API NodeGraphTopology +{ + public: + explicit NodeGraphTopology(const NodeGraph& nodeGraph); + std::unique_ptr createPermutation(const Node& node) const; + + private: + static bool isTopologicalInput(const InputPtr& input, const NodeDefPtr& nodeDef); + void buildNodeInfos(const NodeGraph& nodeGraph); + + struct NodeInfo + { + size_t downstreamRefCount = 0; + StringSet upstreams; + }; + + std::map _topologicalInputs; + std::unordered_map _nodeInfos; +}; + +/// @class NodeGraphTopologyCache +/// Caches topology analyses per NodeGraph definition and creates permutations. +/// Thread-safe; lives on GenContext so clients control its lifetime and +/// invalidation. +class MX_GENSHADER_API NodeGraphTopologyCache +{ + public: + NodeGraphTopologyCache() = default; + ~NodeGraphTopologyCache(); + + /// Copying duplicates the other cache's entries (read under its lock); the mutex itself is never copied. + NodeGraphTopologyCache(const NodeGraphTopologyCache&); + NodeGraphTopologyCache& operator=(const NodeGraphTopologyCache&); + + /// Analyze a NodeGraph's topology (cached) and create a permutation for a + /// specific call-site node instance. + /// @return The permutation, or nullptr if no optimization is possible. 
+ std::unique_ptr createPermutation( + const NodeGraph& nodeGraph, const Node& node); + + /// Discard all cached topology analyses. + void clear(); + + private: + const NodeGraphTopology& getTopology(const NodeGraph& nodeGraph); + + mutable std::mutex _cacheMutex; + std::unordered_map> _cache; +}; + +MATERIALX_NAMESPACE_END + +#endif diff --git a/source/MaterialXGenShader/Nodes/CompoundNode.cpp b/source/MaterialXGenShader/Nodes/CompoundNode.cpp index 78713d821f..69388c65df 100644 --- a/source/MaterialXGenShader/Nodes/CompoundNode.cpp +++ b/source/MaterialXGenShader/Nodes/CompoundNode.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -14,11 +15,18 @@ MATERIALX_NAMESPACE_BEGIN -ShaderNodeImplPtr CompoundNode::create() +ShaderNodeImplPtr CompoundNode::create(std::unique_ptr permutation) { - return std::make_shared(); + return std::make_shared(std::move(permutation)); } +CompoundNode::CompoundNode(std::unique_ptr permutation) : + _permutation(std::move(permutation)) +{ +} + +CompoundNode::~CompoundNode() = default; + void CompoundNode::addClassification(ShaderNode& node) const { // Add classification from the graph implementation. @@ -43,12 +51,12 @@ void CompoundNode::initialize(const InterfaceElement& element, GenContext& conte // so always use the reduced interface for this graph. const ShaderInterfaceType oldShaderInterfaceType = context.getOptions().shaderInterfaceType; context.getOptions().shaderInterfaceType = SHADER_INTERFACE_REDUCED; - _rootGraph = ShaderGraph::create(nullptr, graph, context); + _rootGraph = ShaderGraph::create(nullptr, graph, context, _permutation.get()); context.getOptions().shaderInterfaceType = oldShaderInterfaceType; - // Set hash using the function name. - // TODO: Could be improved to include the full function signature. - _hash = std::hash{}(_functionName); + // Hash includes function name and permutation key (if any) + const string& permKey = _permutation ? 
_permutation->getKey() : EMPTY_STRING; + _hash = std::hash{}(_functionName + permKey); } void CompoundNode::createVariables(const ShaderNode&, GenContext& context, Shader& shader) const diff --git a/source/MaterialXGenShader/Nodes/CompoundNode.h b/source/MaterialXGenShader/Nodes/CompoundNode.h index ff64cf6b55..e151c5d377 100644 --- a/source/MaterialXGenShader/Nodes/CompoundNode.h +++ b/source/MaterialXGenShader/Nodes/CompoundNode.h @@ -11,13 +11,21 @@ #include #include +#include + MATERIALX_NAMESPACE_BEGIN +class NodeGraphPermutation; + /// Compound node implementation class MX_GENSHADER_API CompoundNode : public ShaderNodeImpl { public: - static ShaderNodeImplPtr create(); + /// Create a CompoundNode with a permutation (may be nullptr). + /// @param permutation The permutation for this instance (ownership transferred) + static ShaderNodeImplPtr create(std::unique_ptr permutation); + + ~CompoundNode() override; void initialize(const InterfaceElement& element, GenContext& context) override; @@ -31,9 +39,15 @@ class MX_GENSHADER_API CompoundNode : public ShaderNodeImpl ShaderGraph* getGraph() const override { return _rootGraph.get(); } + /// Return the permutation (if any) for this compound node. 
+ const NodeGraphPermutation* getPermutation() const { return _permutation.get(); } + + explicit CompoundNode(std::unique_ptr permutation); + protected: ShaderGraphPtr _rootGraph; string _functionName; + std::unique_ptr _permutation; }; MATERIALX_NAMESPACE_END diff --git a/source/MaterialXGenShader/ShaderGenerator.cpp b/source/MaterialXGenShader/ShaderGenerator.cpp index 9dc2d5f212..6a08ff6f58 100644 --- a/source/MaterialXGenShader/ShaderGenerator.cpp +++ b/source/MaterialXGenShader/ShaderGenerator.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -285,9 +286,11 @@ bool ShaderGenerator::implementationRegistered(const string& name) const return _implFactory.classRegistered(name); } -ShaderNodeImplPtr ShaderGenerator::createShaderNodeImplForNodeGraph(const NodeGraph& /*nodegraph*/) const +ShaderNodeImplPtr ShaderGenerator::createShaderNodeImplForNodeGraph( + const NodeGraph& /*nodegraph*/, + std::unique_ptr permutation) const { - return CompoundNode::create(); + return CompoundNode::create(std::move(permutation)); } ShaderNodeImplPtr ShaderGenerator::createShaderNodeImplForImplementation(const Implementation& /*implementation*/) const @@ -303,7 +306,31 @@ ShaderNodeImplPtr ShaderGenerator::getImplementation(const NodeDef& nodedef, Gen return nullptr; } - const string& name = implElement->getName(); + string name = implElement->getName(); + + // For NodeGraphs, compute permutation and append it to the cache key + std::unique_ptr permutation; + + if (context.getOptions().enableLobePruning && implElement->isA()) + { + const NodeGraph& graph = *implElement->asA(); + + // The node instance is needed to read call-site input values. + // It's on the parent node stack, pushed by createConnectedNodes(). 
+ const vector& parentNodes = context.getParentNodes(); + if (!parentNodes.empty()) + { + permutation = context.getNodeGraphTopologyCache().createPermutation( + graph, *parentNodes.back()); + + // createPermutation() returns nullptr when nothing can be pruned, so test the pointer before reading its key. + if (permutation && !permutation->getKey().empty()) + { + name += "_"; + name += permutation->getKey(); + } + } + } // Check if it's created and cached already. ShaderNodeImplPtr impl = context.findNodeImplementation(name); @@ -312,20 +339,19 @@ ShaderNodeImplPtr ShaderGenerator::getImplementation(const NodeDef& nodedef, Gen return impl; } + // Cache miss - create the implementation if (implElement->isA()) { - impl = createShaderNodeImplForNodeGraph(*implElement->asA()); + impl = createShaderNodeImplForNodeGraph(*implElement->asA(), std::move(permutation)); } else if (implElement->isA()) { - ImplementationPtr implementationElement = implElement->asA(); if (getColorManagementSystem() && getColorManagementSystem()->hasImplementation(name)) { impl = getColorManagementSystem()->createImplementation(name); } else { - // Try creating a new in the factory. impl = _implFactory.create(name); } if (!impl) diff --git a/source/MaterialXGenShader/ShaderGenerator.h b/source/MaterialXGenShader/ShaderGenerator.h index ce6c3cf996..ac887f7335 100644 --- a/source/MaterialXGenShader/ShaderGenerator.h +++ b/source/MaterialXGenShader/ShaderGenerator.h @@ -19,9 +19,12 @@ #include +#include MATERIALX_NAMESPACE_BEGIN +class NodeGraphPermutation; + /// @class ShaderGenerator /// Base class for shader generators /// All third-party shader generators should derive from this class. @@ -170,7 +173,11 @@ class MX_GENSHADER_API ShaderGenerator bool implementationRegistered(const string& name) const; /// Create the shader node implementation for a NodeGraph implementation. 
- virtual ShaderNodeImplPtr createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const; + /// @param nodegraph The NodeGraph to create an implementation for + /// @param permutation Permutation for this instance, or nullptr (ownership transferred) + virtual ShaderNodeImplPtr createShaderNodeImplForNodeGraph( + const NodeGraph& nodegraph, + std::unique_ptr permutation) const; /// Create the shader node implementation for an Implementation implementation. virtual ShaderNodeImplPtr createShaderNodeImplForImplementation(const Implementation& implementation) const; diff --git a/source/MaterialXGenShader/ShaderGraph.cpp b/source/MaterialXGenShader/ShaderGraph.cpp index c1dee169eb..aaa4f431e3 100644 --- a/source/MaterialXGenShader/ShaderGraph.cpp +++ b/source/MaterialXGenShader/ShaderGraph.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -76,7 +77,8 @@ void ShaderGraph::addOutputSockets(const InterfaceElement& elem, GenContext& con void ShaderGraph::createConnectedNodes(const ElementPtr& downstreamElement, const ElementPtr& upstreamElement, ElementPtr connectingElement, - GenContext& context) + GenContext& context, + const NodeGraphPermutation* permutation) { // Create the node if it doesn't exist. NodePtr upstreamNode = upstreamElement->asA(); @@ -85,6 +87,14 @@ void ShaderGraph::createConnectedNodes(const ElementPtr& downstreamElement, throw ExceptionShaderGenError("Upstream element to connect is not a node '" + upstreamElement->getName() + "'"); } + // Check if this node should be pruned + if (permutation && permutation->shouldPrune(upstreamNode->getName())) + { + // Prune this node entirely. The downstream input will remain + // unconnected and use its default value (e.g., transparent BSDF). 
+ return; + } + ShaderNode* newNode = getNode(upstreamNode->getNamePath()); if (!newNode) { @@ -167,7 +177,8 @@ void ShaderGraph::createConnectedNodes(const ElementPtr& downstreamElement, } } -void ShaderGraph::addUpstreamDependencies(const Element& root, GenContext& context) +void ShaderGraph::addUpstreamDependencies(const Element& root, GenContext& context, + const NodeGraphPermutation* permutation) { std::set processedOutputs; @@ -206,7 +217,8 @@ void ShaderGraph::addUpstreamDependencies(const Element& root, GenContext& conte createConnectedNodes(downstreamElement, upstreamElement, edge.getConnectingElement(), - context); + context, + permutation); } } @@ -428,7 +440,8 @@ void ShaderGraph::addUnitTransformNode(ShaderOutput* output, const UnitTransform } } -ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const NodeGraph& nodeGraph, GenContext& context) +ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const NodeGraph& nodeGraph, + GenContext& context, const NodeGraphPermutation* permutation) { NodeDefPtr nodeDef = nodeGraph.getNodeDef(); if (!nodeDef) @@ -452,7 +465,7 @@ ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const NodeGraph& n // Traverse all outputs and create all internal nodes for (OutputPtr graphOutput : nodeGraph.getActiveOutputs()) { - graph->addUpstreamDependencies(*graphOutput, context); + graph->addUpstreamDependencies(*graphOutput, context, permutation); } // Finalize the graph @@ -637,7 +650,7 @@ ShaderGraphPtr ShaderGraph::create(const ShaderGraph* parent, const string& name // Traverse and create all dependencies upstream if (root && context.getOptions().addUpstreamDependencies) { - graph->addUpstreamDependencies(*root, context); + graph->addUpstreamDependencies(*root, context, nullptr); } graph->finalize(context); diff --git a/source/MaterialXGenShader/ShaderGraph.h b/source/MaterialXGenShader/ShaderGraph.h index f02f4fc7ae..cebf9bbce8 100644 --- a/source/MaterialXGenShader/ShaderGraph.h +++ 
b/source/MaterialXGenShader/ShaderGraph.h @@ -26,6 +26,7 @@ class Syntax; class ShaderGraphEdge; class ShaderGraphEdgeIterator; class GenOptions; +class NodeGraphPermutation; /// An internal input socket in a shader graph, /// used for connecting internal nodes to the outside @@ -56,8 +57,13 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode GenContext& context); /// Create a new shader graph from a nodegraph. + /// @param parent Parent graph, or nullptr for root graph. + /// @param nodeGraph The NodeGraph to create from. + /// @param context Generation context. + /// @param permutation Optional permutation for early pruning (skip nodes). static ShaderGraphPtr create(const ShaderGraph* parent, const NodeGraph& nodeGraph, - GenContext& context); + GenContext& context, + const NodeGraphPermutation* permutation = nullptr); /// Return true if this node is a graph. bool isAGraph() const override { return true; } @@ -129,13 +135,15 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode protected: /// Create node connections corresponding to the connection between a pair of elements. /// @param downstreamElement Element representing the node to connect to. - /// @param upstreamElement Element representing the node to connect from + /// @param upstreamElement Element representing the node to connect from. /// @param connectingElement If non-null, specifies the element on on the downstream node to connect to. /// @param context Context for generation. + /// @param permutation Optional permutation for early pruning. void createConnectedNodes(const ElementPtr& downstreamElement, const ElementPtr& upstreamElement, ElementPtr connectingElement, - GenContext& context); + GenContext& context, + const NodeGraphPermutation* permutation); /// Create a new node in a graph from a node definition. /// The uniqueId argument is used as the node's key in the graph's node map. 
@@ -155,7 +163,11 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode /// Traverse from the given root element and add all dependencies upstream. /// The traversal is done in the context of a material, if given, to include /// bind input elements in the traversal. - void addUpstreamDependencies(const Element& root, GenContext& context); + /// @param root Root element to traverse from. + /// @param context Generation context. + /// @param permutation Optional permutation for early pruning. + void addUpstreamDependencies(const Element& root, GenContext& context, + const NodeGraphPermutation* permutation); /// Add a color transform node and connect to the given input. void addColorTransformNode(ShaderInput* input, const ColorSpaceTransform& transform, GenContext& context); diff --git a/source/MaterialXTest/MaterialXGenShader/GenShaderUtil.cpp b/source/MaterialXTest/MaterialXGenShader/GenShaderUtil.cpp index e2a197e1cc..96ea87ca3a 100644 --- a/source/MaterialXTest/MaterialXGenShader/GenShaderUtil.cpp +++ b/source/MaterialXTest/MaterialXGenShader/GenShaderUtil.cpp @@ -1006,6 +1006,7 @@ void TestSuiteOptions::print(std::ostream& output) const output << "\tEnable Reference Quality: " << enableReferenceQuality << std::endl; output << "\tOutput Directory: " << (outputDirectory.isEmpty() ? 
"(default)" : outputDirectory.asString()) << std::endl; output << "\tEnable Tracing: " << enableTracing << std::endl; + output << "\tenableLobePruning: " << enableLobePruning << std::endl; } bool TestSuiteOptions::readOptions(const std::string& optionFile) @@ -1033,6 +1034,7 @@ bool TestSuiteOptions::readOptions(const std::string& optionFile) const std::string ENABLE_REFERENCE_QUALITY("enableReferenceQuality"); const std::string OUTPUT_DIRECTORY_STRING("outputDirectory"); const std::string ENABLE_TRACING_STRING("enableTracing"); + const std::string ENABLE_LOBE_PRUNING_STRING("enableLobePruning"); overrideFiles.clear(); dumpGeneratedCode = false; @@ -1148,6 +1150,10 @@ bool TestSuiteOptions::readOptions(const std::string& optionFile) { enableTracing = val->asA(); } + else if (name == ENABLE_LOBE_PRUNING_STRING) + { + enableLobePruning = val->asA(); + } } } } diff --git a/source/MaterialXTest/MaterialXGenShader/GenShaderUtil.h b/source/MaterialXTest/MaterialXGenShader/GenShaderUtil.h index e6ae23f19f..741f8e51bf 100644 --- a/source/MaterialXTest/MaterialXGenShader/GenShaderUtil.h +++ b/source/MaterialXTest/MaterialXGenShader/GenShaderUtil.h @@ -126,6 +126,9 @@ class TestSuiteOptions // Default is false to avoid overhead when not profiling. bool enableTracing = false; + // Enable early pruning during ShaderGraph construction. + bool enableLobePruning = false; + // Helper to resolve output path for an artifact. // If outputDirectory is set, returns outputDirectory/filename. // Otherwise returns the original path unchanged. 
diff --git a/source/MaterialXTest/MaterialXRender/RenderUtil.cpp b/source/MaterialXTest/MaterialXRender/RenderUtil.cpp index e79387f23a..f05b74b31c 100644 --- a/source/MaterialXTest/MaterialXRender/RenderUtil.cpp +++ b/source/MaterialXTest/MaterialXRender/RenderUtil.cpp @@ -43,6 +43,7 @@ void ShaderRenderTester::getGenerationOptions(const GenShaderUtil::TestSuiteOpti { mx::GenOptions reducedOption = originalOptions; reducedOption.shaderInterfaceType = mx::SHADER_INTERFACE_REDUCED; + reducedOption.enableLobePruning = testOptions.enableLobePruning; optionsList.push_back(reducedOption); } // Always fallback to complete if no options specified. @@ -50,6 +51,7 @@ void ShaderRenderTester::getGenerationOptions(const GenShaderUtil::TestSuiteOpti { mx::GenOptions completeOption = originalOptions; completeOption.shaderInterfaceType = mx::SHADER_INTERFACE_COMPLETE; + completeOption.enableLobePruning = testOptions.enableLobePruning; optionsList.push_back(completeOption); } }