diff --git a/resources/Materials/TestSuite/_options.mtlx b/resources/Materials/TestSuite/_options.mtlx
index 49b0b7cc39..e9dcc934e5 100644
--- a/resources/Materials/TestSuite/_options.mtlx
+++ b/resources/Materials/TestSuite/_options.mtlx
@@ -84,5 +84,12 @@
Default is false to avoid overhead when not profiling.
-->
+
+
+
diff --git a/source/MaterialXGenHw/HwShaderGenerator.cpp b/source/MaterialXGenHw/HwShaderGenerator.cpp
index 3d4d19e2b7..8bb17811f8 100644
--- a/source/MaterialXGenHw/HwShaderGenerator.cpp
+++ b/source/MaterialXGenHw/HwShaderGenerator.cpp
@@ -10,6 +10,7 @@
#include
#include
#include
+#include
#include
#include
#include
@@ -391,7 +392,9 @@ void HwShaderGenerator::addStageLightingUniforms(GenContext& context, ShaderStag
numActiveLights->setValue(Value::createValue(0));
}
}
-ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const
+ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(
+ const NodeGraph& nodegraph,
+ std::unique_ptr permutation) const
{
vector outputs = nodegraph.getActiveOutputs();
if (outputs.empty())
@@ -404,6 +407,7 @@ ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(const Node
// Use specialized implementations for nodes that output light shaders and materials.
if (outputType == Type::LIGHTSHADER)
{
+ // HwLightCompoundNode doesn't support permutations (light shaders don't have lobe weights)
return HwLightCompoundNode::create();
}
if (outputType == Type::MATERIAL)
@@ -411,8 +415,7 @@ ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(const Node
return HwMaterialCompoundNode::create();
}
- // Use the base implementation for nodes that output other types.
- return CompoundNode::create();
+ return CompoundNode::create(std::move(permutation));
}
void HwShaderGenerator::emitClosureDataArg(const ShaderNode& node, GenContext& /*context*/, ShaderStage& stage) const
diff --git a/source/MaterialXGenHw/HwShaderGenerator.h b/source/MaterialXGenHw/HwShaderGenerator.h
index d42f7dbb1f..6a72aa250b 100644
--- a/source/MaterialXGenHw/HwShaderGenerator.h
+++ b/source/MaterialXGenHw/HwShaderGenerator.h
@@ -64,7 +64,9 @@ class MX_GENHW_API HwShaderGenerator : public ShaderGenerator
virtual string getVertexDataPrefix(const VariableBlock& vertexData) const = 0;
/// Create the shader node implementation for a NodeGraph implementation.
- ShaderNodeImplPtr createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const override;
+ ShaderNodeImplPtr createShaderNodeImplForNodeGraph(
+ const NodeGraph& nodegraph,
+ std::unique_ptr permutation) const override;
// Note : the order must match the order defined in libraries/pbrlib/genglsl/lib/mx_closure_type.glsl
// TODO : investigate build time mechanism for ensuring these stay in sync.
diff --git a/source/MaterialXGenHw/Nodes/HwLightCompoundNode.cpp b/source/MaterialXGenHw/Nodes/HwLightCompoundNode.cpp
index 55194c32fe..d71f4bfe3c 100644
--- a/source/MaterialXGenHw/Nodes/HwLightCompoundNode.cpp
+++ b/source/MaterialXGenHw/Nodes/HwLightCompoundNode.cpp
@@ -11,6 +11,7 @@
MATERIALX_NAMESPACE_BEGIN
HwLightCompoundNode::HwLightCompoundNode() :
+ CompoundNode(nullptr),
_lightUniforms(HW::LIGHT_DATA, EMPTY_STRING)
{
}
diff --git a/source/MaterialXGenMdl/MdlShaderGenerator.cpp b/source/MaterialXGenMdl/MdlShaderGenerator.cpp
index 5e89ac38c1..6ba7c3dd86 100644
--- a/source/MaterialXGenMdl/MdlShaderGenerator.cpp
+++ b/source/MaterialXGenMdl/MdlShaderGenerator.cpp
@@ -339,7 +339,9 @@ ShaderPtr MdlShaderGenerator::generate(const string& name, ElementPtr element, G
return shader;
}
-ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const
+ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForNodeGraph(
+ const NodeGraph& nodegraph,
+ std::unique_ptr permutation) const
{
vector outputs = nodegraph.getActiveOutputs();
if (outputs.empty())
@@ -349,13 +351,12 @@ ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForNodeGraph(const Nod
const TypeDesc outputType = _typeSystem->getType(outputs[0]->getType());
- ShaderNodeImplPtr impl;
- // Use a compound implementation.
+ // Use a compound implementation with permutation support
if (outputType.isClosure())
{
- return ClosureCompoundNodeMdl::create();
+ return ClosureCompoundNodeMdl::create(std::move(permutation));
}
- return CompoundNodeMdl::create();
+ return CompoundNodeMdl::create(std::move(permutation));
}
ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForImplementation(const Implementation& implElement) const
diff --git a/source/MaterialXGenMdl/MdlShaderGenerator.h b/source/MaterialXGenMdl/MdlShaderGenerator.h
index 04415798c3..cd8ff94704 100644
--- a/source/MaterialXGenMdl/MdlShaderGenerator.h
+++ b/source/MaterialXGenMdl/MdlShaderGenerator.h
@@ -75,7 +75,9 @@ class MX_GENMDL_API MdlShaderGenerator : public ShaderGenerator
ShaderPtr generate(const string& name, ElementPtr element, GenContext& context) const override;
/// Create the shader node implementation for a NodeGraph implementation.
- ShaderNodeImplPtr createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const override;
+ ShaderNodeImplPtr createShaderNodeImplForNodeGraph(
+ const NodeGraph& nodegraph,
+ std::unique_ptr permutation) const override;
/// Create the shader node implementation for an mplementation implementation.
ShaderNodeImplPtr createShaderNodeImplForImplementation(const Implementation& implementation) const override;
diff --git a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp
index b5c16456c0..78c1b1ec60 100644
--- a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp
+++ b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp
@@ -13,9 +13,14 @@
MATERIALX_NAMESPACE_BEGIN
-ShaderNodeImplPtr ClosureCompoundNodeMdl::create()
+ShaderNodeImplPtr ClosureCompoundNodeMdl::create(std::unique_ptr permutation)
+{
+ return std::make_shared(std::move(permutation));
+}
+
+ClosureCompoundNodeMdl::ClosureCompoundNodeMdl(std::unique_ptr permutation) :
+ CompoundNodeMdl(std::move(permutation))
{
- return std::make_shared();
}
void ClosureCompoundNodeMdl::addClassification(ShaderNode& node) const
diff --git a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.h b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.h
index 633ed4553e..03a04628d0 100644
--- a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.h
+++ b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.h
@@ -15,11 +15,14 @@ MATERIALX_NAMESPACE_BEGIN
class MX_GENMDL_API ClosureCompoundNodeMdl : public CompoundNodeMdl
{
public:
- static ShaderNodeImplPtr create();
+ /// Create with permutation (may be nullptr).
+ static ShaderNodeImplPtr create(std::unique_ptr permutation);
void addClassification(ShaderNode& node) const override;
void emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override;
void emitFunctionCall(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override;
+
+ explicit ClosureCompoundNodeMdl(std::unique_ptr permutation);
};
MATERIALX_NAMESPACE_END
diff --git a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp
index ea15eca7c2..da2d617cbf 100644
--- a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp
+++ b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp
@@ -16,9 +16,14 @@ MATERIALX_NAMESPACE_BEGIN
const string CompoundNodeMdl::GEN_USER_DATA_RETURN_STRUCT_FIELD_NAME = "returnStructFieldName";
-ShaderNodeImplPtr CompoundNodeMdl::create()
+ShaderNodeImplPtr CompoundNodeMdl::create(std::unique_ptr permutation)
+{
+ return std::make_shared(std::move(permutation));
+}
+
+CompoundNodeMdl::CompoundNodeMdl(std::unique_ptr permutation) :
+ CompoundNode(std::move(permutation))
{
- return std::make_shared();
}
void CompoundNodeMdl::initialize(const InterfaceElement& element, GenContext& context)
diff --git a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h
index 7bfc70fbe6..2923c6f2ee 100644
--- a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h
+++ b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h
@@ -30,7 +30,8 @@ using GenUserDataStringPtr = std::shared_ptr;
class MX_GENMDL_API CompoundNodeMdl : public CompoundNode
{
public:
- static ShaderNodeImplPtr create();
+ /// Create with permutation (may be nullptr).
+ static ShaderNodeImplPtr create(std::unique_ptr permutation);
void initialize(const InterfaceElement& element, GenContext& context) override;
void emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override;
@@ -39,6 +40,8 @@ class MX_GENMDL_API CompoundNodeMdl : public CompoundNode
bool isReturnStruct() const { return !_returnStruct.empty(); }
bool unrollReturnStructMembers() const { return _unrollReturnStructMembers; }
+ explicit CompoundNodeMdl(std::unique_ptr permutation);
+
protected:
void emitFunctionSignature(const ShaderNode& node, GenContext& context, ShaderStage& stage) const;
diff --git a/source/MaterialXGenShader/GenContext.cpp b/source/MaterialXGenShader/GenContext.cpp
index 3fea10c930..083bb52925 100644
--- a/source/MaterialXGenShader/GenContext.cpp
+++ b/source/MaterialXGenShader/GenContext.cpp
@@ -62,6 +62,7 @@ void GenContext::getNodeImplementationNames(StringSet& names)
void GenContext::clearNodeImplementations()
{
_nodeImpls.clear();
+ _nodeGraphTopologyCache.clear();
}
void GenContext::clearUserData()
diff --git a/source/MaterialXGenShader/GenContext.h b/source/MaterialXGenShader/GenContext.h
index 0e5e87804b..801f7ccae7 100644
--- a/source/MaterialXGenShader/GenContext.h
+++ b/source/MaterialXGenShader/GenContext.h
@@ -13,6 +13,7 @@
#include
#include
+#include
#include
#include
@@ -30,7 +31,7 @@ class MX_GENSHADER_API GenContext
{
public:
/// Constructor.
- GenContext(ShaderGeneratorPtr sg);
+ explicit GenContext(ShaderGeneratorPtr sg);
/// Return shader generatior.
ShaderGenerator& getShaderGenerator()
@@ -204,6 +205,12 @@ class MX_GENSHADER_API GenContext
return _applicationVariableHandler;
}
+ /// Return the node graph topology cache for early pruning optimizations.
+ NodeGraphTopologyCache& getNodeGraphTopologyCache()
+ {
+ return _nodeGraphTopologyCache;
+ }
+
protected:
GenContext() = delete;
@@ -219,6 +226,8 @@ class MX_GENSHADER_API GenContext
vector _parentNodes;
+ NodeGraphTopologyCache _nodeGraphTopologyCache;
+
ApplicationVariableHandler _applicationVariableHandler;
};
diff --git a/source/MaterialXGenShader/GenOptions.h b/source/MaterialXGenShader/GenOptions.h
index 034ab782c3..6a87da8776 100644
--- a/source/MaterialXGenShader/GenOptions.h
+++ b/source/MaterialXGenShader/GenOptions.h
@@ -96,6 +96,7 @@ class MX_GENSHADER_API GenOptions
hwWriteAlbedoTable(false),
hwWriteEnvPrefilter(false),
hwImplicitBitangents(true),
+ enableLobePruning(false),
oslImplicitSurfaceShaderConversion(true),
oslConnectCiWrapper(false)
{
@@ -204,6 +205,16 @@ class MX_GENSHADER_API GenOptions
/// inside the bitangent node.
bool hwImplicitBitangents;
+ /// Enable lobe pruning during ShaderGraph construction.
+ /// When enabled, performs topology analysis on NodeGraphs to identify
+ /// "permutation-defining" inputs (e.g., mix weights). If these inputs
+ /// are constant 0 or 1 at a call site, the corresponding dead branches
+ /// are pruned during ShaderGraph creation, avoiding unnecessary node
+ /// instantiation. This produces more compact shaders at the cost of
+ /// more ShaderNodeImpl permutations.
+ /// Defaults to false.
+ bool enableLobePruning;
+
// Enables OSL conversion of surfaceshader struct to closure color.
// Defaults to true.
bool oslImplicitSurfaceShaderConversion;
diff --git a/source/MaterialXGenShader/NodeGraphTopology.cpp b/source/MaterialXGenShader/NodeGraphTopology.cpp
new file mode 100644
index 0000000000..1419ee3e6f
--- /dev/null
+++ b/source/MaterialXGenShader/NodeGraphTopology.cpp
@@ -0,0 +1,423 @@
+//
+// Copyright Contributors to the MaterialX Project
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+
+MATERIALX_NAMESPACE_BEGIN
+
+namespace
+{
+// PBR nodes with a weight parameter -- when weight=0, the node can be pruned.
+// TODO: Could this allowlist be replaced by checking output type == BSDF and the
+// presence of a topological weight input (see isTopologicalInput)?
+const std::unordered_set kWeightedPbrNodes = {
+ "oren_nayar_diffuse_bsdf",
+ "compensating_oren_nayar_diffuse_bsdf",
+ "burley_diffuse_bsdf",
+ "conductor_bsdf",
+ "subsurface_bsdf",
+ "translucent_bsdf",
+ "dielectric_bsdf",
+ "generalized_schlick_bsdf",
+ "sheen_bsdf",
+ "dielectric_tf_bsdf",
+ "generalized_schlick_tf_82_bsdf",
+ "sheen_zeltner_bsdf",
+};
+
+} // anonymous namespace
+
+NodeGraphTopology::NodeGraphTopology(const NodeGraph& nodeGraph)
+{
+ NodeDefPtr nodeDef = nodeGraph.getNodeDef();
+ if (!nodeDef)
+ {
+ throw ExceptionShaderGenError("Can't find nodedef for nodegraph '" + nodeGraph.getName() + "'");
+ }
+
+ // Build per-node info (downstream ref counts and upstream dependencies)
+ buildNodeInfos(nodeGraph);
+
+ // Scan all nodes in the NodeGraph for optimization opportunities
+ for (const NodePtr& node : nodeGraph.getNodes())
+ {
+ const string& category = node->getCategory();
+
+ if (category == "mix")
+ {
+ // Mix nodes: mix input at 0 or 1 can eliminate a branch
+ InputPtr mixInput = node->getActiveInput("mix");
+ if (mixInput && isTopologicalInput(mixInput, nodeDef))
+ {
+ const string& interfaceName = mixInput->getInterfaceName();
+ if (_topologicalInputs.find(interfaceName) == _topologicalInputs.end())
+ {
+ _topologicalInputs.emplace(interfaceName, TopologicalInput(interfaceName, node, mixInput));
+ }
+ }
+ }
+ else if (category == "multiply")
+ {
+ // Multiply nodes: any float input at 0 zeroes the output
+ for (const InputPtr& input : node->getActiveInputs())
+ {
+ if (input && isTopologicalInput(input, nodeDef))
+ {
+ const string& interfaceName = input->getInterfaceName();
+ if (_topologicalInputs.find(interfaceName) == _topologicalInputs.end())
+ {
+ _topologicalInputs.emplace(interfaceName, TopologicalInput(interfaceName, node, input));
+ }
+ }
+ }
+ }
+ else if (kWeightedPbrNodes.count(category))
+ {
+ // PBR nodes: weight at 0 makes the node output dark/transparent
+ InputPtr weightInput = node->getActiveInput("weight");
+ if (weightInput && isTopologicalInput(weightInput, nodeDef))
+ {
+ const string& interfaceName = weightInput->getInterfaceName();
+ if (_topologicalInputs.find(interfaceName) == _topologicalInputs.end())
+ {
+ _topologicalInputs.emplace(interfaceName, TopologicalInput(interfaceName, node, weightInput));
+ }
+ }
+ }
+ }
+}
+
+bool NodeGraphTopology::isTopologicalInput(const InputPtr& input, const NodeDefPtr& nodeDef)
+{
+ // Must be connected to the NodeGraph interface
+ if (!input->hasInterfaceName())
+ {
+ return false;
+ }
+
+ // Must be float type
+ if (input->getType() != "float")
+ {
+ return false;
+ }
+
+ // Get the corresponding NodeDef input
+ const string& interfaceName = input->getInterfaceName();
+ InputPtr ndInput = nodeDef->getActiveInput(interfaceName);
+ if (!ndInput)
+ {
+ return false;
+ }
+ // Check for uimin=0, uimax=1 (indicates a 0-1 weight parameter)
+ if (!ndInput->hasAttribute("uimin") || !ndInput->hasAttribute("uimax"))
+ {
+ return false;
+ }
+
+ try
+ {
+ float minVal = std::stof(ndInput->getAttribute("uimin"));
+ float maxVal = std::stof(ndInput->getAttribute("uimax"));
+ return (minVal == 0.0f && maxVal == 1.0f);
+ }
+ catch (...)
+ {
+ return false;
+ }
+}
+
+TopologicalInput::TopologicalInput(
+ const string& inputName_,
+ const NodePtr& node,
+ const InputPtr& input) :
+ inputName(inputName_)
+{
+ const string& category = node->getCategory();
+
+ if (category == "mix")
+ {
+ // For mix nodes:
+ // - When mix=0, the "fg" (foreground) branch loses a consumer
+ // - When mix=1, the "bg" (background) branch loses a consumer
+ // The mix node itself stays alive; we just decrement the unused input's upstream ref count
+ InputPtr bgInput = node->getActiveInput("bg");
+ InputPtr fgInput = node->getActiveInput("fg");
+
+ if (fgInput && fgInput->hasNodeName())
+ {
+ // mix=0 means fg branch loses this consumer
+ potentiallyPrunableAtValue[0].insert(fgInput->getNodeName());
+ }
+
+ if (bgInput && bgInput->hasNodeName())
+ {
+ // mix=1 means bg branch loses this consumer
+ potentiallyPrunableAtValue[1].insert(bgInput->getNodeName());
+ }
+ }
+ else if (category == "multiply")
+ {
+ // For multiply nodes with input=0:
+ // The multiply node stays alive (outputs 0), but the other inputs lose a consumer
+ for (const InputPtr& otherInput : node->getActiveInputs())
+ {
+ if (otherInput != input && otherInput->hasNodeName())
+ {
+ potentiallyPrunableAtValue[0].insert(otherInput->getNodeName());
+ }
+ }
+ }
+ else if (kWeightedPbrNodes.count(category))
+ {
+ // For PBR nodes with weight=0:
+ // The PBR node itself is unconditionally pruned (replaced with dark/transparent)
+ // Its upstream dependencies will be handled via ref count propagation
+ prunableAtValue[0].insert(node->getName());
+ }
+}
+
+void NodeGraphTopology::buildNodeInfos(const NodeGraph& nodeGraph)
+{
+ for (const NodePtr& node : nodeGraph.getNodes())
+ {
+ const string& nodeName = node->getName();
+ NodeInfo& nodeInfo = _nodeInfos[nodeName];
+
+ for (const InputPtr& input : node->getActiveInputs())
+ {
+ if (input->hasNodeName())
+ {
+ nodeInfo.upstreams.insert(input->getNodeName());
+ }
+ }
+
+ // Increment downstream ref counts once per unique upstream
+ for (const string& upstreamName : nodeInfo.upstreams)
+ {
+ _nodeInfos[upstreamName].downstreamRefCount++;
+ }
+ }
+
+ // NodeGraph outputs are roots -- they also count as downstream consumers
+ for (const OutputPtr& output : nodeGraph.getOutputs())
+ {
+ if (output->hasNodeName())
+ {
+ _nodeInfos[output->getNodeName()].downstreamRefCount++;
+ }
+ }
+}
+
+std::unique_ptr NodeGraphTopology::createPermutation(const Node& node) const
+{
+ if (_topologicalInputs.empty())
+ {
+ return nullptr;
+ }
+
+ string permutationKey;
+ std::unordered_set nodesToPrune;
+ bool hasOptimization = false;
+
+ // Working copy of ref counts for this permutation
+ std::unordered_map downstreamRefCounts;
+ downstreamRefCounts.reserve(_nodeInfos.size());
+ for (const auto& [name, nodeInfo] : _nodeInfos)
+ {
+ downstreamRefCounts[name] = nodeInfo.downstreamRefCount;
+ }
+
+ std::vector worklist;
+
+ // Mark a node as pruned and enqueue for upstream propagation.
+ auto pruneNode = [&nodesToPrune, &worklist](const string& nodeName)
+ {
+ if (nodesToPrune.insert(nodeName).second)
+ {
+ worklist.push_back(nodeName);
+ }
+ };
+
+ // A downstream consumer of `nodeName` has been eliminated. Decrement ref count
+ // and prune `nodeName` if no consumers remain.
+ auto removeDownstream = [&downstreamRefCounts, &pruneNode](const string& nodeName)
+ {
+ auto itDownstreamRefCount = downstreamRefCounts.find(nodeName);
+ if (itDownstreamRefCount != downstreamRefCounts.end()
+ && itDownstreamRefCount->second > 0)
+ {
+ itDownstreamRefCount->second--;
+ if (itDownstreamRefCount->second == 0)
+ {
+ pruneNode(nodeName);
+ }
+ }
+ };
+
+ // First pass: build the key and collect initial dead nodes
+ for (const auto& [inputName, topoInput] : _topologicalInputs)
+ {
+ char flag = 'x'; // 'x' = not optimized (connected or intermediate value)
+
+ auto applyConstantValue = [&pruneNode, &removeDownstream](
+ const std::unordered_set& prunable,
+ const std::unordered_set& potentiallyPrunable)
+ {
+ for (const string& nodeName : prunable)
+ {
+ pruneNode(nodeName);
+ }
+ for (const string& nodeName : potentiallyPrunable)
+ {
+ removeDownstream(nodeName);
+ }
+ };
+
+ // Apply the effects of a topological input being constant (0 or 1).
+ // Takes topoInput by parameter (not capture) because capturing
+ // structured bindings requires C++20.
+ auto applyConstantInput = [&applyConstantValue, &flag](
+ const TopologicalInput& topo, Input& input)
+ {
+ if (!input.hasValue())
+ {
+ return;
+ }
+
+ const float value = input.getValue()->asA();
+ if (value == 0.0f)
+ {
+ flag = '0';
+ applyConstantValue(topo.prunableAtValue[0], topo.potentiallyPrunableAtValue[0]);
+ }
+ else if (value == 1.0f)
+ {
+ flag = '1';
+ applyConstantValue(topo.prunableAtValue[1], topo.potentiallyPrunableAtValue[1]);
+ }
+ };
+
+ // Check if this input is connected on the node instance
+ if (InputPtr nodeInput = node.getInput(inputName))
+ {
+ // If connected to another node, can't optimize
+ if (!( nodeInput->hasNodeName()
+ || nodeInput->hasOutputString()
+ || nodeInput->hasInterfaceName()))
+ {
+ applyConstantInput(topoInput, *nodeInput);
+ }
+ }
+ else if (NodeDefPtr nodeDef = node.getNodeDef()) // Input not set on node instance - check NodeDef default value
+ {
+ if (InputPtr defaultInput = nodeDef->getActiveInput(inputName))
+ {
+ applyConstantInput(topoInput, *defaultInput);
+ }
+ }
+
+ if (!permutationKey.empty())
+ {
+ permutationKey += ",";
+ }
+ permutationKey += inputName + "=" + flag;
+ if (flag != 'x')
+ {
+ hasOptimization = true;
+ }
+ }
+
+ // Worklist-driven DCE: when a node is pruned, decrement ref counts of
+ // its upstream dependencies. Prune any upstream whose count hits 0.
+ while (!worklist.empty())
+ {
+ string nodeName = worklist.back();
+ worklist.pop_back();
+
+ auto itNodeInfo = _nodeInfos.find(nodeName);
+ if (itNodeInfo != _nodeInfos.end())
+ {
+ for (const string& upstream : itNodeInfo->second.upstreams)
+ {
+ removeDownstream(upstream);
+ }
+ }
+ }
+
+ if (!hasOptimization)
+ {
+ return nullptr;
+ }
+
+ return std::make_unique(
+ std::move(permutationKey), std::move(nodesToPrune));
+}
+
+//
+// NodeGraphTopologyCache implementation
+//
+
+NodeGraphTopologyCache::~NodeGraphTopologyCache() = default;
+
+NodeGraphTopologyCache::NodeGraphTopologyCache(const NodeGraphTopologyCache& other)
+{
+ std::lock_guard lock(other._cacheMutex);
+ _cache = other._cache;
+}
+
+NodeGraphTopologyCache& NodeGraphTopologyCache::operator=(const NodeGraphTopologyCache& other)
+{
+ if (this != &other)
+ {
+ std::scoped_lock lock(_cacheMutex, other._cacheMutex);
+ _cache = other._cache;
+ }
+ return *this;
+}
+
+void NodeGraphTopologyCache::clear()
+{
+ std::lock_guard lock(_cacheMutex);
+ _cache.clear();
+}
+
+std::unique_ptr NodeGraphTopologyCache::createPermutation(
+ const NodeGraph& nodeGraph, const Node& node)
+{
+ const NodeGraphTopology& topology = getTopology(nodeGraph);
+ return topology.createPermutation(node);
+}
+
+const NodeGraphTopology& NodeGraphTopologyCache::getTopology(const NodeGraph& nodeGraph)
+{
+ const string& ngName = nodeGraph.getName();
+
+ // Check cache first (with lock)
+ {
+ std::lock_guard lock(_cacheMutex);
+ auto it = _cache.find(ngName);
+ if (it != _cache.end())
+ {
+ return *it->second;
+ }
+ }
+
+ // Cache miss - construct outside lock to allow parallel construction
+ // of different topologies. Safe because emplace() won't overwrite.
+ auto topology = std::make_shared(nodeGraph);
+
+ std::lock_guard lock(_cacheMutex);
+ auto [it, inserted] = _cache.emplace(ngName, std::move(topology));
+ return *it->second;
+}
+
+MATERIALX_NAMESPACE_END
diff --git a/source/MaterialXGenShader/NodeGraphTopology.h b/source/MaterialXGenShader/NodeGraphTopology.h
new file mode 100644
index 0000000000..f28bd3d433
--- /dev/null
+++ b/source/MaterialXGenShader/NodeGraphTopology.h
@@ -0,0 +1,115 @@
+//
+// Copyright Contributors to the MaterialX Project
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#ifndef MATERIALX_NODEGRAPHTOPOLOGY_H
+#define MATERIALX_NODEGRAPHTOPOLOGY_H
+
+#include
+
+#include
+#include
+
+#include