Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions resources/Materials/TestSuite/_options.mtlx
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,12 @@
Default is false to avoid overhead when not profiling.
-->
<input name="enableTracing" type="boolean" value="true" />

<!-- Enable lobe pruning during ShaderGraph construction.
When a NodeGraph has topological inputs (e.g. mix weights) that are
compile-time constant 0 or 1, skip creating the dead branch nodes.
Default is false.
-->
<input name="enableLobePruning" type="boolean" value="false" />
</nodedef>
</materialx>
9 changes: 6 additions & 3 deletions source/MaterialXGenHw/HwShaderGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <MaterialXGenHw/Nodes/HwLightCompoundNode.h>
#include <MaterialXGenHw/Nodes/HwMaterialCompoundNode.h>
#include <MaterialXGenShader/Exception.h>
#include <MaterialXGenShader/NodeGraphTopology.h>
#include <MaterialXGenShader/Nodes/CompoundNode.h>
#include <MaterialXGenShader/GenContext.h>
#include <MaterialXGenShader/Shader.h>
Expand Down Expand Up @@ -391,7 +392,9 @@ void HwShaderGenerator::addStageLightingUniforms(GenContext& context, ShaderStag
numActiveLights->setValue(Value::createValue<int>(0));
}
}
ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const
ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(
const NodeGraph& nodegraph,
std::unique_ptr<NodeGraphPermutation> permutation) const
{
vector<OutputPtr> outputs = nodegraph.getActiveOutputs();
if (outputs.empty())
Expand All @@ -404,15 +407,15 @@ ShaderNodeImplPtr HwShaderGenerator::createShaderNodeImplForNodeGraph(const Node
// Use specialized implementations for nodes that output light shaders and materials.
if (outputType == Type::LIGHTSHADER)
{
// HwLightCompoundNode doesn't support permutations (light shaders don't have lobe weights)
return HwLightCompoundNode::create();
}
if (outputType == Type::MATERIAL)
{
return HwMaterialCompoundNode::create();
}

// Use the base implementation for nodes that output other types.
return CompoundNode::create();
return CompoundNode::create(std::move(permutation));
}

void HwShaderGenerator::emitClosureDataArg(const ShaderNode& node, GenContext& /*context*/, ShaderStage& stage) const
Expand Down
4 changes: 3 additions & 1 deletion source/MaterialXGenHw/HwShaderGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ class MX_GENHW_API HwShaderGenerator : public ShaderGenerator
virtual string getVertexDataPrefix(const VariableBlock& vertexData) const = 0;

/// Create the shader node implementation for a NodeGraph implementation.
ShaderNodeImplPtr createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const override;
ShaderNodeImplPtr createShaderNodeImplForNodeGraph(
const NodeGraph& nodegraph,
std::unique_ptr<NodeGraphPermutation> permutation) const override;

// Note : the order must match the order defined in libraries/pbrlib/genglsl/lib/mx_closure_type.glsl
// TODO : investigate build time mechanism for ensuring these stay in sync.
Expand Down
1 change: 1 addition & 0 deletions source/MaterialXGenHw/Nodes/HwLightCompoundNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
MATERIALX_NAMESPACE_BEGIN

HwLightCompoundNode::HwLightCompoundNode() :
CompoundNode(nullptr),
_lightUniforms(HW::LIGHT_DATA, EMPTY_STRING)
{
}
Expand Down
11 changes: 6 additions & 5 deletions source/MaterialXGenMdl/MdlShaderGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,9 @@ ShaderPtr MdlShaderGenerator::generate(const string& name, ElementPtr element, G
return shader;
}

ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const
ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForNodeGraph(
const NodeGraph& nodegraph,
std::unique_ptr<NodeGraphPermutation> permutation) const
{
vector<OutputPtr> outputs = nodegraph.getActiveOutputs();
if (outputs.empty())
Expand All @@ -349,13 +351,12 @@ ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForNodeGraph(const Nod

const TypeDesc outputType = _typeSystem->getType(outputs[0]->getType());

ShaderNodeImplPtr impl;
// Use a compound implementation.
// Use a compound implementation with permutation support
if (outputType.isClosure())
{
return ClosureCompoundNodeMdl::create();
return ClosureCompoundNodeMdl::create(std::move(permutation));
}
return CompoundNodeMdl::create();
return CompoundNodeMdl::create(std::move(permutation));
}

ShaderNodeImplPtr MdlShaderGenerator::createShaderNodeImplForImplementation(const Implementation& implElement) const
Expand Down
4 changes: 3 additions & 1 deletion source/MaterialXGenMdl/MdlShaderGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ class MX_GENMDL_API MdlShaderGenerator : public ShaderGenerator
ShaderPtr generate(const string& name, ElementPtr element, GenContext& context) const override;

/// Create the shader node implementation for a NodeGraph implementation.
ShaderNodeImplPtr createShaderNodeImplForNodeGraph(const NodeGraph& nodegraph) const override;
ShaderNodeImplPtr createShaderNodeImplForNodeGraph(
const NodeGraph& nodegraph,
std::unique_ptr<NodeGraphPermutation> permutation) const override;

/// Create the shader node implementation for an mplementation implementation.
ShaderNodeImplPtr createShaderNodeImplForImplementation(const Implementation& implementation) const override;
Expand Down
9 changes: 7 additions & 2 deletions source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@

MATERIALX_NAMESPACE_BEGIN

ShaderNodeImplPtr ClosureCompoundNodeMdl::create()
ShaderNodeImplPtr ClosureCompoundNodeMdl::create(std::unique_ptr<NodeGraphPermutation> permutation)
{
return std::make_shared<ClosureCompoundNodeMdl>(std::move(permutation));
}

ClosureCompoundNodeMdl::ClosureCompoundNodeMdl(std::unique_ptr<NodeGraphPermutation> permutation) :
CompoundNodeMdl(std::move(permutation))
{
return std::make_shared<ClosureCompoundNodeMdl>();
}

void ClosureCompoundNodeMdl::addClassification(ShaderNode& node) const
Expand Down
5 changes: 4 additions & 1 deletion source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ MATERIALX_NAMESPACE_BEGIN
class MX_GENMDL_API ClosureCompoundNodeMdl : public CompoundNodeMdl
{
public:
static ShaderNodeImplPtr create();
/// Create with permutation (may be nullptr).
static ShaderNodeImplPtr create(std::unique_ptr<NodeGraphPermutation> permutation);

void addClassification(ShaderNode& node) const override;
void emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override;
void emitFunctionCall(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override;

explicit ClosureCompoundNodeMdl(std::unique_ptr<NodeGraphPermutation> permutation);
};

MATERIALX_NAMESPACE_END
Expand Down
9 changes: 7 additions & 2 deletions source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@ MATERIALX_NAMESPACE_BEGIN

const string CompoundNodeMdl::GEN_USER_DATA_RETURN_STRUCT_FIELD_NAME = "returnStructFieldName";

ShaderNodeImplPtr CompoundNodeMdl::create()
ShaderNodeImplPtr CompoundNodeMdl::create(std::unique_ptr<NodeGraphPermutation> permutation)
{
return std::make_shared<CompoundNodeMdl>(std::move(permutation));
}

CompoundNodeMdl::CompoundNodeMdl(std::unique_ptr<NodeGraphPermutation> permutation) :
CompoundNode(std::move(permutation))
{
return std::make_shared<CompoundNodeMdl>();
}

void CompoundNodeMdl::initialize(const InterfaceElement& element, GenContext& context)
Expand Down
5 changes: 4 additions & 1 deletion source/MaterialXGenMdl/Nodes/CompoundNodeMdl.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ using GenUserDataStringPtr = std::shared_ptr<GenUserDataString>;
class MX_GENMDL_API CompoundNodeMdl : public CompoundNode
{
public:
static ShaderNodeImplPtr create();
/// Create with permutation (may be nullptr).
static ShaderNodeImplPtr create(std::unique_ptr<NodeGraphPermutation> permutation);

void initialize(const InterfaceElement& element, GenContext& context) override;
void emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override;
Expand All @@ -39,6 +40,8 @@ class MX_GENMDL_API CompoundNodeMdl : public CompoundNode
bool isReturnStruct() const { return !_returnStruct.empty(); }
bool unrollReturnStructMembers() const { return _unrollReturnStructMembers; }

explicit CompoundNodeMdl(std::unique_ptr<NodeGraphPermutation> permutation);

protected:
void emitFunctionSignature(const ShaderNode& node, GenContext& context, ShaderStage& stage) const;

Expand Down
1 change: 1 addition & 0 deletions source/MaterialXGenShader/GenContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ void GenContext::getNodeImplementationNames(StringSet& names)
void GenContext::clearNodeImplementations()
{
_nodeImpls.clear();
_nodeGraphTopologyCache.clear();
}

void GenContext::clearUserData()
Expand Down
11 changes: 10 additions & 1 deletion source/MaterialXGenShader/GenContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <MaterialXGenShader/GenOptions.h>
#include <MaterialXGenShader/GenUserData.h>
#include <MaterialXGenShader/NodeGraphTopology.h>
#include <MaterialXGenShader/ShaderNode.h>
#include <MaterialXGenShader/ShaderGenerator.h>

Expand All @@ -30,7 +31,7 @@ class MX_GENSHADER_API GenContext
{
public:
/// Constructor.
GenContext(ShaderGeneratorPtr sg);
explicit GenContext(ShaderGeneratorPtr sg);

/// Return shader generatior.
ShaderGenerator& getShaderGenerator()
Expand Down Expand Up @@ -204,6 +205,12 @@ class MX_GENSHADER_API GenContext
return _applicationVariableHandler;
}

/// Return the node graph topology cache for early pruning optimizations.
NodeGraphTopologyCache& getNodeGraphTopologyCache()
{
return _nodeGraphTopologyCache;
}

protected:
GenContext() = delete;

Expand All @@ -219,6 +226,8 @@ class MX_GENSHADER_API GenContext

vector<ConstNodePtr> _parentNodes;

NodeGraphTopologyCache _nodeGraphTopologyCache;

ApplicationVariableHandler _applicationVariableHandler;
};

Expand Down
11 changes: 11 additions & 0 deletions source/MaterialXGenShader/GenOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class MX_GENSHADER_API GenOptions
hwWriteAlbedoTable(false),
hwWriteEnvPrefilter(false),
hwImplicitBitangents(true),
enableLobePruning(false),
oslImplicitSurfaceShaderConversion(true),
oslConnectCiWrapper(false)
{
Expand Down Expand Up @@ -204,6 +205,16 @@ class MX_GENSHADER_API GenOptions
/// inside the bitangent node.
bool hwImplicitBitangents;

/// Enable lobe pruning during ShaderGraph construction.
/// When enabled, performs topology analysis on NodeGraphs to identify
/// "permutation-defining" inputs (e.g., mix weights). If these inputs
/// are constant 0 or 1 at a call site, the corresponding dead branches
/// are pruned during ShaderGraph creation, avoiding unnecessary node
/// instantiation. This produces more compact shaders at the cost of
/// more ShaderNodeImpl permutations.
/// Defaults to false.
bool enableLobePruning;

// Enables OSL conversion of surfaceshader struct to closure color.
// Defaults to true.
bool oslImplicitSurfaceShaderConversion;
Expand Down
Loading
Loading