diff --git a/examples_tests b/examples_tests
index 8c76367c1c..f717206024 160000
--- a/examples_tests
+++ b/examples_tests
@@ -1 +1 @@
-Subproject commit 8c76367c1c226cce3d66f1c60f540e29a501a1cb
+Subproject commit f717206024887dbabd028653edf5f823dd64b5bc
diff --git a/include/nbl/builtin/hlsl/bxdf/fresnel.hlsl b/include/nbl/builtin/hlsl/bxdf/fresnel.hlsl
index d3b3543a28..e09910aaff 100644
--- a/include/nbl/builtin/hlsl/bxdf/fresnel.hlsl
+++ b/include/nbl/builtin/hlsl/bxdf/fresnel.hlsl
@@ -33,6 +33,20 @@ struct orientedEtas
         rcpOrientedEta = backside ? eta : rcpEta;
         return backside;
     }
+
+    static T diffuseFresnelCorrectionFactor(T n, T n2)
+    {
+        // assert(n*n==n2);
+        vector<bool, vector_traits<T>::Dimension> TIR = n < (T)1.0;
+        T invdenum = nbl::hlsl::mix(hlsl::promote<T>(1.0), hlsl::promote<T>(1.0) / (n2 * n2 * (hlsl::promote<T>(554.33) - 380.7 * n)), TIR);
+        T num = n * nbl::hlsl::mix(hlsl::promote<T>(0.1921156102251088), n * 298.25 - 261.38 * n2 + 138.43, TIR);
+        num += nbl::hlsl::mix(hlsl::promote<T>(0.8078843897748912), hlsl::promote<T>(-1.67), TIR);
+        return num * invdenum;
+    }
+
+    T value;
+    T rcp;
+    bool backside;
 };
 
 template<>
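The `diffuseFresnelCorrectionFactor` added above is a pair of fitted polynomials selected per component by the TIR predicate `n < 1`. For illustration only, here is a minimal scalar C++ sketch of the same arithmetic (not part of the patch; the constants are copied from the hunk, the function name is hypothetical):

    // Scalar C++ mirror of diffuseFresnelCorrectionFactor above.
    // The TIR branch is taken when n < 1, i.e. light exits into a denser medium.
    #include <cstdio>

    static float diffuseFresnelCorrection(float n)
    {
        const float n2 = n * n;
        const bool TIR = n < 1.f;
        // mix(x, y, cond) picks y when cond is true, hence the ternaries below
        const float invdenum = TIR ? 1.f / (n2 * n2 * (554.33f - 380.7f * n)) : 1.f;
        float num = n * (TIR ? (n * 298.25f - 261.38f * n2 + 138.43f) : 0.1921156102251088f);
        num += TIR ? -1.67f : 0.8078843897748912f;
        return num * invdenum;
    }

    int main()
    {
        for (float n = 0.5f; n <= 2.f; n += 0.25f)
            std::printf("n=%.2f -> %f\n", n, diffuseFresnelCorrection(n));
    }
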
@@ -140,6 +154,184 @@ struct refract
     scalar_type rcpOrientedEta2;
 };
 
+template<typename T NBL_PRIMARY_REQUIRES(concepts::Vectorial<T> && vector_traits<T>::Dimension == 3)
+struct ReflectRefract
+{
+    using this_t = ReflectRefract<T>;
+    using vector_type = T;
+    using scalar_type = typename vector_traits<T>::scalar_type;
+
+    static this_t create(bool refract, NBL_CONST_REF_ARG(vector_type) I, NBL_CONST_REF_ARG(vector_type) N, scalar_type NdotI, scalar_type NdotTorR, scalar_type rcpOrientedEta)
+    {
+        this_t retval;
+        retval.refract = refract;
+        retval.I = I;
+        retval.N = N;
+        retval.NdotI = NdotI;
+        retval.NdotTorR = NdotTorR;
+        retval.rcpOrientedEta = rcpOrientedEta;
+        return retval;
+    }
+
+    static this_t create(bool r, NBL_CONST_REF_ARG(Refract<T>) refract)
+    {
+        this_t retval;
+        retval.refract = r;
+        retval.I = refract.I;
+        retval.N = refract.N;
+        retval.NdotI = refract.NdotI;
+        retval.NdotTorR = r ? Refract<T>::computeNdotT(refract.backside, refract.NdotI2, refract.rcpOrientedEta2) : refract.NdotI;
+        retval.rcpOrientedEta = refract.rcpOrientedEta;
+        return retval;
+    }
+
+    vector_type operator()()
+    {
+        return N * (NdotI * (hlsl::mix(1.0f, rcpOrientedEta, refract)) + NdotTorR) - I * (hlsl::mix(1.0f, rcpOrientedEta, refract));
+    }
+
+    bool refract;
+    vector_type I;
+    vector_type N;
+    scalar_type NdotI;
+    scalar_type NdotTorR;
+    scalar_type rcpOrientedEta;
+};
+
+
+namespace fresnel
+{
+
+template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T> || is_vector_v<T>)
+struct Schlick
+{
+    using scalar_type = typename vector_traits<T>::scalar_type;
+
+    static Schlick<T> create(NBL_CONST_REF_ARG(T) F0, scalar_type VdotH)
+    {
+        Schlick<T> retval;
+        retval.F0 = F0;
+        retval.VdotH = VdotH;
+        return retval;
+    }
+
+    T operator()()
+    {
+        T x = 1.0 - VdotH;
+        return F0 + (1.0 - F0) * x*x*x*x*x;
+    }
+
+    T F0;
+    scalar_type VdotH;
+};
+
+template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T> || is_vector_v<T>)
+struct Conductor
+{
+    using scalar_type = typename vector_traits<T>::scalar_type;
+
+    static Conductor<T> create(NBL_CONST_REF_ARG(T) eta, NBL_CONST_REF_ARG(T) etak, scalar_type cosTheta)
+    {
+        Conductor<T> retval;
+        retval.eta = eta;
+        retval.etak = etak;
+        retval.cosTheta = cosTheta;
+        return retval;
+    }
+
+    T operator()()
+    {
+        const scalar_type cosTheta2 = cosTheta * cosTheta;
+        //const float sinTheta2 = 1.0 - cosTheta2;
+
+        const T etaLen2 = eta * eta + etak * etak;
+        const T etaCosTwice = eta * cosTheta * 2.0f;
+
+        const T rs_common = etaLen2 + (T)(cosTheta2);
+        const T rs2 = (rs_common - etaCosTwice) / (rs_common + etaCosTwice);
+
+        const T rp_common = etaLen2 * cosTheta2 + (T)(1.0);
+        const T rp2 = (rp_common - etaCosTwice) / (rp_common + etaCosTwice);
+
+        return (rs2 + rp2) * 0.5f;
+    }
+
+    T eta;
+    T etak;
+    scalar_type cosTheta;
+};
+
+template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T> || is_vector_v<T>)
+struct Dielectric
+{
+    using scalar_type = typename vector_traits<T>::scalar_type;
+
+    static Dielectric<T> create(NBL_CONST_REF_ARG(T) eta, scalar_type cosTheta)
+    {
+        Dielectric<T> retval;
+        OrientedEtas<T> orientedEta = OrientedEtas<T>::create(cosTheta, eta);
+        retval.eta2 = orientedEta.value * orientedEta.value;
+        retval.cosTheta = cosTheta;
+        return retval;
+    }
+
+    static T __call(NBL_CONST_REF_ARG(T) orientedEta2, scalar_type absCosTheta)
+    {
+        const scalar_type sinTheta2 = 1.0 - absCosTheta * absCosTheta;
+
+        // the max() clamping can handle TIR when orientedEta2<1.0
+        const T t0 = hlsl::sqrt(hlsl::max(orientedEta2 - sinTheta2, hlsl::promote<T>(0.0)));
+        const T rs = (hlsl::promote<T>(absCosTheta) - t0) / (hlsl::promote<T>(absCosTheta) + t0);
+
+        const T t2 = orientedEta2 * absCosTheta;
+        const T rp = (t0 - t2) / (t0 + t2);
+
+        return (rs * rs + rp * rp) * 0.5f;
+    }
+
+    T operator()()
+    {
+        return __call(eta2, cosTheta);
+    }
+
+    T eta2;
+    scalar_type cosTheta;
+};
+
+template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T> || is_vector_v<T>)
+struct DielectricFrontFaceOnly
+{
+    using scalar_type = typename vector_traits<T>::scalar_type;
+
+    static DielectricFrontFaceOnly<T> create(NBL_CONST_REF_ARG(T) orientedEta2, scalar_type absCosTheta)
+    {
+        DielectricFrontFaceOnly<T> retval;
+        retval.orientedEta2 = orientedEta2;
+        retval.absCosTheta = hlsl::abs(absCosTheta);
+        return retval;
+    }
+
+    T operator()()
+    {
+        return Dielectric<T>::__call(orientedEta2, absCosTheta);
+    }
+
+    T orientedEta2;
+    scalar_type absCosTheta;
+};
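As a sanity check on `Dielectric::__call` above: at normal incidence the exact dielectric Fresnel must reduce to the `F0 = ((eta-1)/(eta+1))^2` that `Schlick` starts from. A small standalone C++ sketch of the same formula (illustrative, not the PR's API):

    // C++ mirror of fresnel::Dielectric::__call above, scalar case.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    static float fresnelDielectric(float orientedEta2, float absCosTheta)
    {
        const float sinTheta2 = 1.f - absCosTheta * absCosTheta;
        // clamping to 0 handles TIR when orientedEta2 < 1
        const float t0 = std::sqrt(std::max(orientedEta2 - sinTheta2, 0.f));
        const float rs = (absCosTheta - t0) / (absCosTheta + t0);
        const float t2 = orientedEta2 * absCosTheta;
        const float rp = (t0 - t2) / (t0 + t2);
        return (rs * rs + rp * rp) * 0.5f;
    }

    int main()
    {
        const float eta = 1.5f;
        const float F0 = ((eta - 1.f) / (eta + 1.f)) * ((eta - 1.f) / (eta + 1.f));
        // both print ~0.04 for eta = 1.5
        std::printf("exact at 0 deg: %f, Schlick F0: %f\n", fresnelDielectric(eta * eta, 1.f), F0);
    }
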
+
+// gets the sum of all R, T R T, T R^3 T, T R^5 T, ... paths
+template<typename T>
+struct ThinDielectricInfiniteScatter
+{
+    T operator()(T singleInterfaceReflectance)
+    {
+        const T doubleInterfaceReflectance = singleInterfaceReflectance * singleInterfaceReflectance;
+        // fall back to 1.0 when the denominator 1-R^2 would vanish
+        return hlsl::mix((singleInterfaceReflectance - doubleInterfaceReflectance) / (hlsl::promote<T>(1.0) - doubleInterfaceReflectance) * 2.0f, hlsl::promote<T>(1.0), doubleInterfaceReflectance > hlsl::promote<T>(0.9999));
+    }
+};
+
+}
 
 }
 }
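The closed form used by `ThinDielectricInfiniteScatter` is the geometric series named in its comment: with single-interface reflectance R and transmittance T = 1-R, the paths sum to R + T·R·T·(1 + R^2 + R^4 + ...) = R + (1-R)^2·R/(1-R^2) = 2R/(1+R), which is exactly (R - R^2)/(1 - R^2)·2. A quick numeric check in C++ (illustrative only):

    // Verifies the thin-dielectric series against the closed form from the hunk above.
    #include <cstdio>

    int main()
    {
        const float R = 0.04f, T = 1.f - R;
        float sum = R;
        float bounce = T * R * T;   // first T R T path
        for (int i = 0; i < 64; i++)
        {
            sum += bounce;
            bounce *= R * R;        // each extra round trip adds two interface reflections
        }
        const float R2 = R * R;
        std::printf("series: %f  closed form: %f\n", sum, (R - R2) / (1.f - R2) * 2.f);
    }
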
diff --git a/include/nbl/builtin/hlsl/bxdf/geom_smith.hlsl b/include/nbl/builtin/hlsl/bxdf/geom_smith.hlsl
new file mode 100644
index 0000000000..a39a90b72f
--- /dev/null
+++ b/include/nbl/builtin/hlsl/bxdf/geom_smith.hlsl
@@ -0,0 +1,291 @@
+// Copyright (C) 2018-2023 - DevSH Graphics Programming Sp. z O.O.
+// This file is part of the "Nabla Engine".
+// For conditions of distribution and use, see copyright notice in nabla.h
+#ifndef _NBL_BUILTIN_HLSL_BXDF_GEOM_INCLUDED_
+#define _NBL_BUILTIN_HLSL_BXDF_GEOM_INCLUDED_
+
+#include "nbl/builtin/hlsl/bxdf/ndf.hlsl"
+
+namespace nbl
+{
+namespace hlsl
+{
+namespace bxdf
+{
+namespace smith
+{
+
+template<class NDF>
+typename NDF::scalar_type VNDF_pdf_wo_clamps(typename NDF::scalar_type ndf, typename NDF::scalar_type lambda_V, typename NDF::scalar_type maxNdotV, NBL_REF_ARG(typename NDF::scalar_type) onePlusLambda_V)
+{
+    onePlusLambda_V = 1.0 + lambda_V;
+    ndf::microfacet_to_light_measure_transform transform = ndf::microfacet_to_light_measure_transform::create(ndf / onePlusLambda_V, maxNdotV);
+    return transform();
+}
+
+template<class NDF>
+typename NDF::scalar_type VNDF_pdf_wo_clamps(typename NDF::scalar_type ndf, typename NDF::scalar_type lambda_V, typename NDF::scalar_type absNdotV, bool transmitted, typename NDF::scalar_type VdotH, typename NDF::scalar_type LdotH, typename NDF::scalar_type VdotHLdotH, typename NDF::scalar_type orientedEta, typename NDF::scalar_type reflectance, NBL_REF_ARG(typename NDF::scalar_type) onePlusLambda_V)
+{
+    onePlusLambda_V = 1.0 + lambda_V;
+    ndf::microfacet_to_light_measure_transform transform
+        = ndf::microfacet_to_light_measure_transform::create((transmitted ? (1.0 - reflectance) : reflectance) * ndf / onePlusLambda_V, absNdotV, transmitted, VdotH, LdotH, VdotHLdotH, orientedEta);
+    return transform();
+}
+
+template<typename T NBL_FUNC_REQUIRES(is_scalar_v<T>)
+T VNDF_pdf_wo_clamps(T ndf, T G1_over_2NdotV)
+{
+    return ndf * 0.5 * G1_over_2NdotV;
+}
+
+template<typename T NBL_FUNC_REQUIRES(is_scalar_v<T>)
+T FVNDF_pdf_wo_clamps(T fresnel_ndf, T G1_over_2NdotV, T absNdotV, bool transmitted, T VdotH, T LdotH, T VdotHLdotH, T orientedEta)
+{
+    T FNG = fresnel_ndf * G1_over_2NdotV;
+    T factor = 0.5;
+    if (transmitted)
+    {
+        const T VdotH_etaLdotH = (VdotH + orientedEta * LdotH);
+        // VdotHLdotH is negative under transmission, so this factor is negative
+        factor *= -2.0 * VdotHLdotH / (VdotH_etaLdotH * VdotH_etaLdotH);
+    }
+    return FNG * factor;
+}
+
+template<typename T NBL_FUNC_REQUIRES(is_scalar_v<T>)
+T VNDF_pdf_wo_clamps(T ndf, T G1_over_2NdotV, T absNdotV, bool transmitted, T VdotH, T LdotH, T VdotHLdotH, T orientedEta, T reflectance)
+{
+    T FN = (transmitted ? (1.0 - reflectance) : reflectance) * ndf;
+    return FVNDF_pdf_wo_clamps<T>(FN, G1_over_2NdotV, absNdotV, transmitted, VdotH, LdotH, VdotHLdotH, orientedEta);
+}
+
+
+template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T>)
+struct SIsotropicParams
+{
+    using this_t = SIsotropicParams<T>;
+
+    static this_t create(T a2, T NdotV2, T NdotL2, T lambdaV_plus_one) // beckmann
+    {
+        this_t retval;
+        retval.a2 = a2;
+        retval.NdotV2 = NdotV2;
+        retval.NdotL2 = NdotL2;
+        retval.lambdaV_plus_one = lambdaV_plus_one;
+        return retval;
+    }
+
+    static this_t create(T a2, T NdotV, T NdotV2, T NdotL, T NdotL2) // ggx
+    {
+        this_t retval;
+        retval.a2 = a2;
+        retval.NdotV = NdotV;
+        retval.NdotV2 = NdotV2;
+        retval.NdotL = NdotL;
+        retval.NdotL2 = NdotL2;
+        retval.one_minus_a2 = 1.0 - a2;
+        return retval;
+    }
+
+    T a2;
+    T NdotV;
+    T NdotL;
+    T NdotV2;
+    T NdotL2;
+    T lambdaV_plus_one;
+    T one_minus_a2;
+};
+
+template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T>)
+struct SAnisotropicParams
+{
+    using this_t = SAnisotropicParams<T>;
+
+    static this_t create(T ax2, T ay2, T TdotV2, T BdotV2, T NdotV2, T TdotL2, T BdotL2, T NdotL2, T lambdaV_plus_one) // beckmann
+    {
+        this_t retval;
+        retval.ax2 = ax2;
+        retval.ay2 = ay2;
+        retval.TdotV2 = TdotV2;
+        retval.BdotV2 = BdotV2;
+        retval.NdotV2 = NdotV2;
+        retval.TdotL2 = TdotL2;
+        retval.BdotL2 = BdotL2;
+        retval.NdotL2 = NdotL2;
+        retval.lambdaV_plus_one = lambdaV_plus_one;
+        return retval;
+    }
+
+    static this_t create(T ax2, T ay2, T NdotV, T TdotV2, T BdotV2, T NdotV2, T NdotL, T TdotL2, T BdotL2, T NdotL2) // ggx
+    {
+        this_t retval;
+        retval.ax2 = ax2;
+        retval.ay2 = ay2;
+        retval.NdotL = NdotL;
+        retval.NdotV = NdotV;
+        retval.TdotV2 = TdotV2;
+        retval.BdotV2 = BdotV2;
+        retval.NdotV2 = NdotV2;
+        retval.TdotL2 = TdotL2;
+        retval.BdotL2 = BdotL2;
+        retval.NdotL2 = NdotL2;
+        return retval;
+    }
+
+    T ax2;
+    T ay2;
+    T NdotV;
+    T NdotL;
+    T TdotV2;
+    T BdotV2;
+    T NdotV2;
+    T TdotL2;
+    T BdotL2;
+    T NdotL2;
+    T lambdaV_plus_one;
+};
+
+
+// beckmann
+template<typename T NBL_PRIMARY_REQUIRES(is_scalar_v<T>)
+struct Beckmann
+{
+    using scalar_type = T;
+
+    scalar_type G1(scalar_type lambda)
+    {
+        return 1.0 / (1.0 + lambda);
+    }
+
+    scalar_type C2(scalar_type NdotX2, scalar_type a2)
+    {
+        return NdotX2 / (a2 * (1.0 - NdotX2));
+    }
+
+    scalar_type C2(scalar_type TdotX2, scalar_type BdotX2, scalar_type NdotX2, scalar_type ax2, scalar_type ay2)
+    {
+        return NdotX2 / (TdotX2 * ax2 + BdotX2 * ay2);
+    }
+
+    scalar_type Lambda(scalar_type c2)
+    {
+        scalar_type c = sqrt(c2);
+        scalar_type nom = 1.0 - 1.259 * c + 0.396 * c2;
+        scalar_type denom = 2.181 * c2 + 3.535 * c;
+        return hlsl::mix(0.0, nom / denom, c < 1.6);
+    }
+
+    scalar_type Lambda(scalar_type NdotX2, scalar_type a2)
+    {
+        return Lambda(C2(NdotX2, a2));
+    }
+
+    scalar_type Lambda(scalar_type TdotX2, scalar_type BdotX2, scalar_type NdotX2, scalar_type ax2, scalar_type ay2)
+    {
+        return Lambda(C2(TdotX2, BdotX2, NdotX2, ax2, ay2));
+    }
+
+    scalar_type correlated(SIsotropicParams<T> params)
+    {
+        scalar_type c2 = C2(params.NdotV2, params.a2);
+        scalar_type L_v = Lambda(c2);
+        c2 = C2(params.NdotL2, params.a2);
+        scalar_type L_l = Lambda(c2);
+        return G1(L_v + L_l);
+    }
+
+    scalar_type correlated(SAnisotropicParams<T> params)
+    {
+        scalar_type c2 = C2(params.TdotV2, params.BdotV2, params.NdotV2, params.ax2, params.ay2);
+        scalar_type L_v = Lambda(c2);
+        c2 = C2(params.TdotL2, params.BdotL2, params.NdotL2, params.ax2, params.ay2);
+        scalar_type L_l = Lambda(c2);
+        return G1(L_v + L_l);
+    }
+
+    scalar_type G2_over_G1(SIsotropicParams<T> params)
+    {
+        scalar_type lambdaL = Lambda(params.NdotL2, params.a2);
+        return params.lambdaV_plus_one / (params.lambdaV_plus_one + lambdaL);
+    }
+
+    scalar_type G2_over_G1(SAnisotropicParams<T> params)
+    {
+        scalar_type c2 = C2(params.TdotL2, params.BdotL2, params.NdotL2, params.ax2, params.ay2);
+        scalar_type lambdaL = Lambda(c2);
+        return params.lambdaV_plus_one / (params.lambdaV_plus_one + lambdaL);
+    }
+};
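The Beckmann path above chains `C2 -> Lambda -> G1`, with the rational fit for Lambda only applied below c = 1.6. A small C++ sketch of that chain under assumed isotropic inputs (illustrative, not the PR's API):

    // C++ mirror of the Beckmann Smith terms above (scalar, isotropic).
    #include <cmath>
    #include <cstdio>

    static float lambdaBeckmann(float c2)
    {
        const float c = std::sqrt(c2);
        if (c >= 1.6f) return 0.f; // the fit is only valid/used below c = 1.6
        return (1.f - 1.259f * c + 0.396f * c2) / (3.535f * c + 2.181f * c2);
    }

    int main()
    {
        const float a2 = 0.1f, NdotV = 0.5f, NdotV2 = NdotV * NdotV;
        const float c2 = NdotV2 / (a2 * (1.f - NdotV2)); // isotropic C2 from the struct above
        const float G1 = 1.f / (1.f + lambdaBeckmann(c2));
        std::printf("Lambda=%f G1=%f\n", lambdaBeckmann(c2), G1);
    }
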
lambdaL); + } + + scalar_type G2_over_G1(SAnisotropicParams params) + { + scalar_type c2 = C2(params.TdotL2, params.BdotL2, params.NdotL2, params.ax2, params.ay2); + scalar_type lambdaL = Lambda(c2); + return params.lambdaV_plus_one / (params.lambdaV_plus_one + lambdaL); + } +}; + + +// ggx +template) +struct GGX +{ + using scalar_type = T; + + scalar_type devsh_part(scalar_type NdotX2, scalar_type a2, scalar_type one_minus_a2) + { + return sqrt(a2 + one_minus_a2 * NdotX2); + } + + scalar_type devsh_part(scalar_type TdotX2, scalar_type BdotX2, scalar_type NdotX2, scalar_type ax2, scalar_type ay2) + { + return sqrt(TdotX2 * ax2 + BdotX2 * ay2 + NdotX2); + } + + scalar_type G1_wo_numerator(scalar_type NdotX, scalar_type NdotX2, scalar_type a2, scalar_type one_minus_a2) + { + return 1.0 / (NdotX + devsh_part(NdotX2,a2,one_minus_a2)); + } + + scalar_type G1_wo_numerator(scalar_type NdotX, scalar_type TdotX2, scalar_type BdotX2, scalar_type NdotX2, scalar_type ax2, scalar_type ay2) + { + return 1.0 / (NdotX + devsh_part(TdotX2, BdotX2, NdotX2, ax2, ay2)); + } + + scalar_type G1_wo_numerator(scalar_type NdotX, scalar_type devsh_part) + { + return 1.0 / (NdotX + devsh_part); + } + + scalar_type correlated_wo_numerator(SIsotropicParams params) + { + scalar_type Vterm = params.NdotL * devsh_part(params.NdotV2, params.a2, params.one_minus_a2); + scalar_type Lterm = params.NdotV * devsh_part(params.NdotL2, params.a2, params.one_minus_a2); + return 0.5 / (Vterm + Lterm); + } + + scalar_type correlated_wo_numerator(SAnisotropicParams params) + { + scalar_type Vterm = params.NdotL * devsh_part(params.TdotV2, params.BdotV2, params.NdotV2, params.ax2, params.ay2); + scalar_type Lterm = params.NdotV * devsh_part(params.TdotL2, params.BdotL2, params.NdotL2, params.ax2, params.ay2); + return 0.5 / (Vterm + Lterm); + } + + scalar_type G2_over_G1(SIsotropicParams params) + { + scalar_type devsh_v = devsh_part(params.NdotV2, params.a2, params.one_minus_a2); + scalar_type G2_over_G1 = params.NdotL * (devsh_v + params.NdotV); // alternative `Vterm+NdotL*NdotV /// NdotL*NdotV could come as a parameter + G2_over_G1 /= params.NdotV * devsh_part(params.NdotL2, params.a2, params.one_minus_a2) + params.NdotL * devsh_v; + + return G2_over_G1; + } + + scalar_type G2_over_G1(SAnisotropicParams params) + { + scalar_type devsh_v = devsh_part(params.TdotV2, params.BdotV2, params.NdotV2, params.ax2, params.ay2); + scalar_type G2_over_G1 = params.NdotL * (devsh_v + params.NdotV); + G2_over_G1 /= params.NdotV * devsh_part(params.TdotL2, params.BdotL2, params.NdotL2, params.ax2, params.ay2) + params.NdotL * devsh_v; + + return G2_over_G1; + } + +}; + +} +} +} +} + +#endif \ No newline at end of file diff --git a/include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl b/include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl index 0309b78e0d..679fecb697 100644 --- a/include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl +++ b/include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl @@ -240,13 +240,17 @@ struct mix_helper(e } }; -template NBL_PARTIAL_REQ_TOP(concepts::FloatingPointScalar) -struct mix_helper) > +template +NBL_PARTIAL_REQ_TOP((concepts::Scalar || concepts::Vectorial) && !concepts::Boolean && concepts::Boolean) +struct mix_helper || concepts::Vectorial) && !concepts::Boolean && concepts::Boolean) > { using return_t = conditional_t, vector::scalar_type, vector_traits::Dimension>, T>; - static inline return_t __call(const T x, const T y, const bool a) + // for a component of a that is false, the 
diff --git a/include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl b/include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl
index 0309b78e0d..679fecb697 100644
--- a/include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl
+++ b/include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl
@@ -240,13 +240,17 @@ struct mix_helper
     }
 };
 
-template<typename T> NBL_PARTIAL_REQ_TOP(concepts::FloatingPointScalar<T>)
-struct mix_helper<T, T, bool NBL_PARTIAL_REQ_BOT(concepts::FloatingPointScalar<T>) >
+template<typename T, typename U>
+NBL_PARTIAL_REQ_TOP((concepts::Scalar<T> || concepts::Vectorial<T>) && !concepts::Boolean<T> && concepts::Boolean<U>)
+struct mix_helper<T, U NBL_PARTIAL_REQ_BOT((concepts::Scalar<T> || concepts::Vectorial<T>) && !concepts::Boolean<T> && concepts::Boolean<U>) >
 {
     using return_t = conditional_t<is_vector_v<T>, vector<typename vector_traits<T>::scalar_type, vector_traits<T>::Dimension>, T>;
-    static inline return_t __call(const T x, const T y, const bool a)
+    // for a component of a that is false, the corresponding component of x is returned
+    // for a component of a that is true, the corresponding component of y is returned
+    // so we make sure this is correct when calling the operation
+    static inline return_t __call(const T x, const T y, const U a)
     {
-        return a ? x : y;
+        return spirv::select(a, y, x);
     }
 };
 
@@ -862,8 +866,8 @@ struct mix_helper
 };
 
 template<typename T, typename U>
-NBL_PARTIAL_REQ_TOP(concepts::Vectorial<T> && concepts::Boolean<U> && vector_traits<T>::Dimension == vector_traits<U>::Dimension)
-struct mix_helper<T, U NBL_PARTIAL_REQ_BOT(concepts::Vectorial<T> && concepts::Boolean<U> && vector_traits<T>::Dimension == vector_traits<U>::Dimension) >
+NBL_PARTIAL_REQ_TOP(VECTOR_SPECIALIZATION_CONCEPT && concepts::Boolean<U> && vector_traits<T>::Dimension == vector_traits<U>::Dimension)
+struct mix_helper<T, U NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT && concepts::Boolean<U> && vector_traits<T>::Dimension == vector_traits<U>::Dimension) >
 {
     using return_t = T;
     static return_t __call(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(U) a)
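The argument swap in `spirv::select(a, y, x)` above is the whole point of the comment: GLSL-style `mix(x, y, a)` returns x where the mask is false and y where it is true, while SPIR-V's OpSelect returns its first object where the condition is true. A component-wise C++ stand-in that makes the convention visible (illustrative, not the PR's code):

    // Emulates OpSelect semantics; calling it as select(a, y, x) reproduces mix(x, y, a).
    #include <array>
    #include <cstdio>

    template<typename T, std::size_t N>
    static std::array<T, N> select(const std::array<bool, N>& a, const std::array<T, N>& whenTrue, const std::array<T, N>& whenFalse)
    {
        std::array<T, N> r;
        for (std::size_t i = 0; i < N; i++)
            r[i] = a[i] ? whenTrue[i] : whenFalse[i];
        return r;
    }

    int main()
    {
        const std::array<float, 3> x = {1, 2, 3}, y = {10, 20, 30};
        const std::array<bool, 3> a = {false, true, false};
        const auto mixed = select(a, y, x); // mix(x, y, a): note y goes first
        std::printf("%g %g %g\n", mixed[0], mixed[1], mixed[2]); // 1 20 3
    }
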
diff --git a/include/nbl/builtin/hlsl/spirv_intrinsics/core.hlsl b/include/nbl/builtin/hlsl/spirv_intrinsics/core.hlsl
index 7da69c4a55..4885fc11f8 100644
--- a/include/nbl/builtin/hlsl/spirv_intrinsics/core.hlsl
+++ b/include/nbl/builtin/hlsl/spirv_intrinsics/core.hlsl
@@ -346,6 +346,20 @@ template<typename T NBL_FUNC_REQUIRES(is_integral_v<T>)
 [[vk::ext_instruction(spv::OpISubBorrow)]]
 SubBorrowOutput<T> subBorrow(T operand1, T operand2);
 
+
+template<typename T NBL_FUNC_REQUIRES(is_integral_v<T> && !is_matrix_v<T>)
+[[vk::ext_instruction(spv::OpIEqual)]]
+conditional_t<is_vector_v<T>, vector<bool, vector_traits<T>::Dimension>, bool> IEqual(T lhs, T rhs);
+
+template<typename T NBL_FUNC_REQUIRES(is_floating_point_v<T> && !is_matrix_v<T>)
+[[vk::ext_instruction(spv::OpFOrdEqual)]]
+conditional_t<is_vector_v<T>, vector<bool, vector_traits<T>::Dimension>, bool> FOrdEqual(T lhs, T rhs);
+
+
+template<typename T, typename U NBL_FUNC_REQUIRES(!is_matrix_v<T> && !is_matrix_v<U> && is_same_v<typename vector_traits<U>::scalar_type, bool>)
+[[vk::ext_instruction(spv::OpSelect)]]
+T select(U a, T x, T y);
+
 }
 
 #endif
diff --git a/include/nbl/builtin/hlsl/spirv_intrinsics/subgroup_arithmetic.hlsl b/include/nbl/builtin/hlsl/spirv_intrinsics/subgroup_arithmetic.hlsl
index c2a4b52ce5..306215415b 100644
--- a/include/nbl/builtin/hlsl/spirv_intrinsics/subgroup_arithmetic.hlsl
+++ b/include/nbl/builtin/hlsl/spirv_intrinsics/subgroup_arithmetic.hlsl
@@ -17,25 +17,23 @@ namespace hlsl
 namespace spirv
 {
 
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformIAdd )]]
-int32_t groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, int32_t value);
-[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
-[[vk::ext_instruction( spv::OpGroupNonUniformIAdd )]]
-uint32_t groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, uint32_t value);
+enable_if_t<!is_matrix_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformFAdd )]]
-float32_t groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, float32_t value);
+enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformIMul )]]
-int32_t groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, int32_t value);
-[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
-[[vk::ext_instruction( spv::OpGroupNonUniformIMul )]]
-uint32_t groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, uint32_t value);
+enable_if_t<!is_matrix_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformFMul )]]
-float32_t groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, float32_t value);
+enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
 
 template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
@@ -54,25 +52,31 @@ T groupBitwiseXor(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T
 
 // The MIN and MAX operations in SPIR-V have different Ops for each arithmetic type
 // so we implement them distinctly
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformSMin )]]
-int32_t groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, int32_t value);
+enable_if_t<!is_matrix_v<T> && is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformUMin )]]
-uint32_t groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, uint32_t value);
+enable_if_t<!is_matrix_v<T> && !is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformFMin )]]
-float32_t groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, float32_t value);
+enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformSMax )]]
-int32_t groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, int32_t value);
+enable_if_t<!is_matrix_v<T> && is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformUMax )]]
-uint32_t groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, uint32_t value);
+enable_if_t<!is_matrix_v<T> && !is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
+template<typename T>
 [[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
 [[vk::ext_instruction( spv::OpGroupNonUniformFMax )]]
-float32_t groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, float32_t value);
+enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
 
 }
 }
diff --git a/include/nbl/builtin/hlsl/subgroup/arithmetic_portability.hlsl b/include/nbl/builtin/hlsl/subgroup/arithmetic_portability.hlsl
index 5c87dcf828..cf48ab648f 100644
--- a/include/nbl/builtin/hlsl/subgroup/arithmetic_portability.hlsl
+++ b/include/nbl/builtin/hlsl/subgroup/arithmetic_portability.hlsl
@@ -9,6 +9,7 @@
 
 #include "nbl/builtin/hlsl/subgroup/basic.hlsl"
 #include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
+#include "nbl/builtin/hlsl/concepts.hlsl"
 
 
 namespace nbl
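The pattern in the subgroup intrinsic rewrite above is worth calling out: one pair of concrete overloads per SPIR-V opcode is replaced by a single template whose `enable_if_t` return type routes each element type to the right opcode. A plain C++14 stand-in for the dispatch mechanism (the function bodies are fake; only overload resolution is being demonstrated):

    #include <cstdio>
    #include <type_traits>

    // stand-ins for the real intrinsics; only the SFINAE dispatch pattern matters here
    template<typename T>
    std::enable_if_t<std::is_integral<T>::value, T> groupAdd(T value)
    { std::puts("OpGroupNonUniformIAdd"); return value; }

    template<typename T>
    std::enable_if_t<std::is_floating_point<T>::value, T> groupAdd(T value)
    { std::puts("OpGroupNonUniformFAdd"); return value; }

    int main()
    {
        groupAdd(1);    // picks the integer opcode
        groupAdd(1.f);  // picks the float opcode
    }
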
diff --git a/include/nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl b/include/nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl
new file mode 100644
index 0000000000..444b4c075c
--- /dev/null
+++ b/include/nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl
@@ -0,0 +1,45 @@
+// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
+// This file is part of the "Nabla Engine".
+// For conditions of distribution and use, see copyright notice in nabla.h
+#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_INCLUDED_
+#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_INCLUDED_
+
+
+#include "nbl/builtin/hlsl/device_capabilities_traits.hlsl"
+
+#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
+#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl"
+#include "nbl/builtin/hlsl/concepts.hlsl"
+
+
+namespace nbl
+{
+namespace hlsl
+{
+namespace subgroup2
+{
+
+template<class Config, class BinOp, int32_t _ItemsPerInvocation, class device_capabilities NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
+struct ArithmeticParams
+{
+    using config_t = Config;
+    using binop_t = BinOp;
+    using scalar_t = typename BinOp::type_t; // BinOp should be with scalar type
+    using type_t = vector<scalar_t, _ItemsPerInvocation>;
+
+    NBL_CONSTEXPR_STATIC_INLINE int32_t ItemsPerInvocation = _ItemsPerInvocation;
+    NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic /* && some heuristic for when it's faster */;
+};
+
+template<typename Params>
+struct reduction : impl::reduction<Params, typename Params::binop_t, Params::ItemsPerInvocation, Params::UseNativeIntrinsics> {};
+template<typename Params>
+struct inclusive_scan : impl::inclusive_scan<Params, typename Params::binop_t, Params::ItemsPerInvocation, Params::UseNativeIntrinsics> {};
+template<typename Params>
+struct exclusive_scan : impl::exclusive_scan<Params, typename Params::binop_t, Params::ItemsPerInvocation, Params::UseNativeIntrinsics> {};
+
+}
+}
+}
+
+#endif
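What the `ArithmeticParams` bundle buys is visible in a model of the semantics: with ItemsPerInvocation = 2, each invocation first folds its own vector, then one scalar reduction runs across the lanes. A host-side C++ model of that contract (the lane loop stands in for the subgroup; names and layout are assumptions, not the PR's API):

    // Models subgroup2::reduction for ItemsPerInvocation == 2 on the host.
    #include <cstdio>
    #include <vector>

    int main()
    {
        const int lanes = 8, itemsPerInvocation = 2;
        std::vector<float> items(lanes * itemsPerInvocation);
        for (int i = 0; i < lanes * itemsPerInvocation; i++) items[i] = float(i);

        float total = 0.f; // stand-in for the cross-lane subgroup reduction
        for (int lane = 0; lane < lanes; lane++)
        {
            float local = items[lane * itemsPerInvocation]; // retval = value[0]
            for (int c = 1; c < itemsPerInvocation; c++)    // fold remaining components
                local += items[lane * itemsPerInvocation + c];
            total += local;
        }
        std::printf("reduction: %g (expected %g)\n", total, 16.f * 15.f / 2.f); // 0..15 sums to 120
    }
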
diff --git a/include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl b/include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl
new file mode 100644
index 0000000000..15531da37f
--- /dev/null
+++ b/include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl
@@ -0,0 +1,228 @@
+// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
+// This file is part of the "Nabla Engine".
+// For conditions of distribution and use, see copyright notice in nabla.h
+#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
+#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
+
+#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
+#include "nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl"
+
+#include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
+#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
+
+#include "nbl/builtin/hlsl/functional.hlsl"
+#include "nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl"
+
+namespace nbl
+{
+namespace hlsl
+{
+namespace subgroup2
+{
+
+namespace impl
+{
+
+// forward declarations
+template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
+struct inclusive_scan;
+
+template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
+struct exclusive_scan;
+
+template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
+struct reduction;
+
+
+// BinOp needed to specialize native
+template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
+struct inclusive_scan
+{
+    using type_t = typename Params::type_t;
+    using scalar_t = typename Params::scalar_t;
+    using binop_t = typename Params::binop_t;
+    // assert binop_t == BinOp
+    using exclusive_scan_op_t = exclusive_scan<Params, BinOp, 1, native>;
+
+    // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<type_t>::Dimension;
+
+    type_t operator()(NBL_CONST_REF_ARG(type_t) value)
+    {
+        binop_t binop;
+        type_t retval;
+        retval[0] = value[0];
+        [unroll]
+        for (uint32_t i = 1; i < ItemsPerInvocation; i++)
+            retval[i] = binop(retval[i-1], value[i]);
+
+        exclusive_scan_op_t op;
+        scalar_t exclusive = op(retval[ItemsPerInvocation-1]);
+
+        [unroll]
+        for (uint32_t i = 0; i < ItemsPerInvocation; i++)
+            retval[i] = binop(retval[i], exclusive);
+        return retval;
+    }
+};
+
+template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
+struct exclusive_scan
+{
+    using type_t = typename Params::type_t;
+    using scalar_t = typename Params::scalar_t;
+    using binop_t = typename Params::binop_t;
+    using inclusive_scan_op_t = inclusive_scan<Params, BinOp, ItemsPerInvocation, native>;
+
+    // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<type_t>::Dimension;
+
+    type_t operator()(type_t value)
+    {
+        inclusive_scan_op_t op;
+        value = op(value);
+
+        type_t left = glsl::subgroupShuffleUp(value,1);
+
+        type_t retval;
+        retval[0] = hlsl::mix(binop_t::identity, left[ItemsPerInvocation-1], bool(glsl::gl_SubgroupInvocationID()));
+        [unroll]
+        for (uint32_t i = 1; i < ItemsPerInvocation; i++)
+            retval[i] = value[i-1];
+        return retval;
+    }
+};
+
+template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
+struct reduction
+{
+    using type_t = typename Params::type_t;
+    using scalar_t = typename Params::scalar_t;
+    using binop_t = typename Params::binop_t;
+    using op_t = reduction<Params, BinOp, 1, native>;
+
+    // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<type_t>::Dimension;
+
+    scalar_t operator()(NBL_CONST_REF_ARG(type_t) value)
+    {
+        binop_t binop;
+        op_t op;
+        scalar_t retval = value[0];
+        [unroll]
+        for (uint32_t i = 1; i < ItemsPerInvocation; i++)
+            retval = binop(retval, value[i]);
+        return op(retval);
+    }
+};
+
+
+// specs for N=1 use subgroup funcs
+// specialize native
+#define SPECIALIZE(NAME,BINOP,SUBGROUP_OP) template<class Params, typename T> struct NAME<Params,BINOP<T>,1,true> \
+{ \
+    using type_t = T; \
+    \
+    type_t operator()(NBL_CONST_REF_ARG(type_t) v) {return glsl::subgroup##SUBGROUP_OP(v);} \
+}
+
+#define SPECIALIZE_ALL(BINOP,SUBGROUP_OP) SPECIALIZE(reduction,BINOP,SUBGROUP_OP); \
+    SPECIALIZE(inclusive_scan,BINOP,Inclusive##SUBGROUP_OP); \
+    SPECIALIZE(exclusive_scan,BINOP,Exclusive##SUBGROUP_OP);
+
+SPECIALIZE_ALL(bit_and,And);
+SPECIALIZE_ALL(bit_or,Or);
+SPECIALIZE_ALL(bit_xor,Xor);
+
+SPECIALIZE_ALL(plus,Add);
+SPECIALIZE_ALL(multiplies,Mul);
+
+SPECIALIZE_ALL(minimum,Min);
+SPECIALIZE_ALL(maximum,Max);
+
+#undef SPECIALIZE_ALL
+#undef SPECIALIZE
+
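The multi-item `inclusive_scan` above composes two scans: a local inclusive scan inside each invocation's vector, plus the exclusive scan of the per-invocation totals added back onto every component. A host-side C++ model of exactly that composition (the `carry` variable plays the role of `exclusive_scan_op_t`; illustrative only):

    // Models impl::inclusive_scan for 4 "lanes" x 2 items; prints 1 3 6 10 15 21 28 36.
    #include <cstdio>

    int main()
    {
        const int lanes = 4, N = 2;
        const float v[4][2] = {{1,2},{3,4},{5,6},{7,8}};

        float carry = 0.f; // exclusive scan of the last component across lanes
        for (int lane = 0; lane < lanes; lane++)
        {
            float scanned[N];
            scanned[0] = v[lane][0];
            for (int i = 1; i < N; i++)              // local inclusive scan
                scanned[i] = scanned[i-1] + v[lane][i];
            const float exclusive = carry;           // what exclusive_scan_op_t returns here
            carry += scanned[N-1];
            for (int i = 0; i < N; i++)
                std::printf("%g ", scanned[i] + exclusive);
        }
        std::printf("\n");
    }
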
+// specialize portability
+template<class Params, class BinOp>
+struct inclusive_scan<Params,BinOp,1,false>
+{
+    using type_t = typename Params::type_t;
+    using scalar_t = typename Params::scalar_t;
+    using binop_t = typename Params::binop_t;
+    // assert T == scalar type, binop::type == T
+    using config_t = typename Params::config_t;
+
+    // affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
+    // NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
+
+    scalar_t operator()(scalar_t value)
+    {
+        return __call(value);
+    }
+
+    static scalar_t __call(scalar_t value)
+    {
+        binop_t op;
+        const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
+
+        scalar_t rhs = glsl::subgroupShuffleUp(value, 1u); // all invocations must execute the shuffle, even if we don't apply the op() to all of them
+        value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < 1u));
+
+        const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
+        [unroll]
+        for (uint32_t i = 1; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
+        {
+            const uint32_t step = 1u << i;
+            rhs = glsl::subgroupShuffleUp(value, step);
+            value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < step));
+        }
+        return value;
+    }
+};
+
+template<class Params, class BinOp>
+struct exclusive_scan<Params,BinOp,1,false>
+{
+    using type_t = typename Params::type_t;
+    using scalar_t = typename Params::scalar_t;
+    using binop_t = typename Params::binop_t;
+
+    scalar_t operator()(scalar_t value)
+    {
+        value = inclusive_scan<Params,BinOp,1,false>::__call(value);
+        // can't risk getting short-circuited, need to store to a var
+        scalar_t left = glsl::subgroupShuffleUp(value,1);
+        // the first invocation doesn't have anything in its left so we set to the binop's identity value for exclusive scan
+        return hlsl::mix(binop_t::identity, left, bool(glsl::gl_SubgroupInvocationID()));
+    }
+};
+
+template<class Params, class BinOp>
+struct reduction<Params,BinOp,1,false>
+{
+    using type_t = typename Params::type_t;
+    using scalar_t = typename Params::scalar_t;
+    using binop_t = typename Params::binop_t;
+    using config_t = typename Params::config_t;
+
+    // affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
+    // NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
+
+    scalar_t operator()(scalar_t value)
+    {
+        binop_t op;
+
+        const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
+        [unroll]
+        for (uint32_t i = 0; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
+            value = op(glsl::subgroupShuffleXor(value,0x1u<<i), value);
+        return value;
+    }
+};
+
+}
+
+}
+}
+}
+
+#endif
diff --git a/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl b/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl
new file mode 100644
--- /dev/null
+++ b/include/nbl/builtin/hlsl/subgroup2/ballot.hlsl
+// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
+// This file is part of the "Nabla Engine".
+// For conditions of distribution and use, see copyright notice in nabla.h
+#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_BALLOT_INCLUDED_
+#define _NBL_BUILTIN_HLSL_SUBGROUP2_BALLOT_INCLUDED_
+
+namespace nbl
+{
+namespace hlsl
+{
+namespace subgroup2
+{
+
+template<uint32_t SubgroupSizeLog2>
+struct Configuration
+{
+    using mask_t = conditional_t<(SubgroupSizeLog2 < 7), vector<uint32_t, 2>, uint32_t4>;
+
+    NBL_CONSTEXPR_STATIC_INLINE uint16_t SizeLog2 = uint16_t(SubgroupSizeLog2);
+    NBL_CONSTEXPR_STATIC_INLINE uint16_t Size = uint16_t(0x1u) << SubgroupSizeLog2;
+};
+
+template<typename T>
+struct is_configuration : bool_constant<false> {};
+
+template<uint32_t N>
+struct is_configuration<Configuration<N> > : bool_constant<true> {};
+
+template<typename T>
+NBL_CONSTEXPR bool is_configuration_v = is_configuration<T>::value;
+
+}
+}
+}
+
+#endif
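The two portability kernels above are the classic lane-level primitives: a Hillis-Steele/Kogge-Stone inclusive scan built from `subgroupShuffleUp` with doubling step sizes, and a butterfly reduction built from `subgroupShuffleXor`. A host-side C++ simulation of both on a 16-lane "subgroup" (illustrative; the lane arrays stand in for the shuffles):

    #include <cstdio>

    int main()
    {
        const unsigned log2Size = 4, size = 1u << log2Size;
        float scan[16], red[16], next[16];
        for (unsigned i = 0; i < size; i++) { scan[i] = 1.f; red[i] = float(i); }

        // shuffleUp scan: steps 1, 2, 4, ...; lanes below the step add the identity (0)
        for (unsigned step = 1; step < size; step <<= 1)
        {
            for (unsigned lane = 0; lane < size; lane++)
                next[lane] = scan[lane] + (lane >= step ? scan[lane - step] : 0.f);
            for (unsigned lane = 0; lane < size; lane++) scan[lane] = next[lane];
        }
        // shuffleXor butterfly: after log2Size rounds every lane holds the full reduction
        for (unsigned i = 0; i < log2Size; i++)
        {
            for (unsigned lane = 0; lane < size; lane++)
                next[lane] = red[lane] + red[lane ^ (1u << i)];
            for (unsigned lane = 0; lane < size; lane++) red[lane] = next[lane];
        }
        std::printf("scan[last]=%g (expect 16), reduction=%g (expect 120)\n", scan[size-1], red[0]);
    }
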