Skip to content

Commit ee2eef6

Browse files
author
kevyuu
committed
Add Validation to ray tracing pipeline creation
Signed-off-by: kevyuu <[email protected]>
1 parent aec8fec commit ee2eef6

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

include/nbl/asset/IRayTracingPipeline.h

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ namespace nbl::asset
5050
struct SCachedCreationParams final
5151
{
5252
SShaderGroupsParams shaderGroups;
53-
uint64_t maxRecursionDepth;
53+
uint32_t maxRecursionDepth;
5454
};
5555
};
5656

@@ -92,27 +92,38 @@ namespace nbl::asset
9292
if (getShaderStage(cached.shaderGroups.raygenGroup.shaderIndex) != ICPUShader::E_SHADER_STAGE::ESS_RAYGEN)
9393
return false;
9494

95+
auto isValidShaderIndex = [this, getShaderStage](size_t index, ICPUShader::E_SHADER_STAGE expectedStage) -> bool
96+
{
97+
if (index == SShaderGroupsParams::ShaderUnused)
98+
return true;
99+
if (index >= shaders.size())
100+
return false;
101+
if (getShaderStage(index) != expectedStage)
102+
return false;
103+
return true;
104+
};
105+
95106
for (const auto& shaderGroup : cached.shaderGroups.hitGroups)
96107
{
97-
if (shaderGroup.anyHitShaderIndex != SShaderGroupsParams::ShaderUnused && getShaderStage(shaderGroup.anyHitShaderIndex) != ICPUShader::E_SHADER_STAGE::ESS_ANY_HIT)
108+
if (!isValidShaderIndex(shaderGroup.anyHitShaderIndex, ICPUShader::E_SHADER_STAGE::ESS_ANY_HIT))
98109
return false;
99110

100-
if (shaderGroup.closestHitShaderIndex != SShaderGroupsParams::ShaderUnused && getShaderStage(shaderGroup.closestHitShaderIndex) != ICPUShader::E_SHADER_STAGE::ESS_CLOSEST_HIT)
111+
if (!isValidShaderIndex(shaderGroup.closestHitShaderIndex, ICPUShader::E_SHADER_STAGE::ESS_CLOSEST_HIT))
101112
return false;
102113

103-
if (shaderGroup.intersectionShaderIndex != SShaderGroupsParams::ShaderUnused && getShaderStage(shaderGroup.intersectionShaderIndex) != ICPUShader::E_SHADER_STAGE::ESS_INTERSECTION)
114+
if (!isValidShaderIndex(shaderGroup.intersectionShaderIndex, ICPUShader::E_SHADER_STAGE::ESS_INTERSECTION))
104115
return false;
105116
}
106117

107118
for (const auto& shaderGroup : cached.shaderGroups.missGroups)
108119
{
109-
if (getShaderStage(shaderGroup.shaderIndex) != ICPUShader::E_SHADER_STAGE::ESS_MISS)
120+
if (!isValidShaderIndex(shaderGroup.shaderIndex, ICPUShader::E_SHADER_STAGE::ESS_MISS))
110121
return false;
111122
}
112123

113124
for (const auto& shaderGroup : cached.shaderGroups.callableShaderGroups)
114125
{
115-
if (getShaderStage(shaderGroup.shaderIndex) != ICPUShader::E_SHADER_STAGE::ESS_CALLABLE)
126+
if (!isValidShaderIndex(shaderGroup.shaderIndex, ICPUShader::E_SHADER_STAGE::ESS_CALLABLE))
116127
return false;
117128
}
118129
return true;

src/nbl/video/ILogicalDevice.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,16 @@ bool ILogicalDevice::createRayTracingPipelines(IGPUPipelineCache* const pipeline
962962
return false;
963963
}
964964

965+
const auto& limits = getPhysicalDeviceLimits();
966+
for (const auto& param : params)
967+
{
968+
if (param.cached.maxRecursionDepth > limits.maxRayRecursionDepth)
969+
{
970+
NBL_LOG_ERROR("Invalid maxRecursionDepth. maxRecursionDepth(%zu) exceed the limits(%zu)", param.cached.maxRecursionDepth, limits.maxRayRecursionDepth);
971+
return false;
972+
}
973+
}
974+
965975
createRayTracingPipelines_impl(pipelineCache,params,output,specConstantValidation);
966976

967977
bool retval = true;

0 commit comments

Comments
 (0)