Skip to content

Commit 50fd2e2

Browse files
Merge pull request #873 from Devsh-Graphics-Programming/improve-subgroup-scan
Improve subgroup scan
2 parents a9395de + d93172f commit 50fd2e2

File tree

7 files changed

+354
-22
lines changed

7 files changed

+354
-22
lines changed

include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl

+10-6
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,17 @@ struct mix_helper<T, T NBL_PARTIAL_REQ_BOT(always_true<decltype(spirv::fMix<T>(e
240240
}
241241
};
242242

243-
template<typename T> NBL_PARTIAL_REQ_TOP(concepts::FloatingPointScalar<T>)
244-
struct mix_helper<T, bool NBL_PARTIAL_REQ_BOT(concepts::FloatingPointScalar<T>) >
243+
template<typename T, typename U>
244+
NBL_PARTIAL_REQ_TOP((concepts::Scalar<T> || concepts::Vectorial<T>) && !concepts::Boolean<T> && concepts::Boolean<U>)
245+
struct mix_helper<T, U NBL_PARTIAL_REQ_BOT((concepts::Scalar<T> || concepts::Vectorial<T>) && !concepts::Boolean<T> && concepts::Boolean<U>) >
245246
{
246247
using return_t = conditional_t<is_vector_v<T>, vector<typename vector_traits<T>::scalar_type, vector_traits<T>::Dimension>, T>;
247-
static inline return_t __call(const T x, const T y, const bool a)
248+
// for a component of a that is false, the corresponding component of x is returned
249+
// for a component of a that is true, the corresponding component of y is returned
250+
// so we make sure this is correct when calling the operation
251+
static inline return_t __call(const T x, const T y, const U a)
248252
{
249-
return a ? x : y;
253+
return spirv::select<T, U>(a, y, x);
250254
}
251255
};
252256

@@ -862,8 +866,8 @@ struct mix_helper<T, T NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT) >
862866
};
863867

864868
template<typename T, typename U>
865-
NBL_PARTIAL_REQ_TOP(concepts::Vectorial<T> && concepts::Boolean<U> && vector_traits<T>::Dimension == vector_traits<U>::Dimension)
866-
struct mix_helper<T, U NBL_PARTIAL_REQ_BOT(concepts::Vectorial<T> && concepts::Boolean<U> && vector_traits<T>::Dimension == vector_traits<U>::Dimension) >
869+
NBL_PARTIAL_REQ_TOP(VECTOR_SPECIALIZATION_CONCEPT && concepts::Boolean<U> && vector_traits<T>::Dimension == vector_traits<U>::Dimension)
870+
struct mix_helper<T, U NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT && concepts::Boolean<U> && vector_traits<T>::Dimension == vector_traits<U>::Dimension) >
867871
{
868872
using return_t = T;
869873
static return_t __call(NBL_CONST_REF_ARG(T) x, NBL_CONST_REF_ARG(T) y, NBL_CONST_REF_ARG(U) a)

include/nbl/builtin/hlsl/spirv_intrinsics/core.hlsl

+14
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,20 @@ template<typename T NBL_FUNC_REQUIRES(concepts::UnsignedIntegral<T>)
346346
[[vk::ext_instruction(spv::OpISubBorrow)]]
347347
SubBorrowOutput<T> subBorrow(T operand1, T operand2);
348348

349+
350+
template<typename T NBL_FUNC_REQUIRES(is_integral_v<T> && !is_matrix_v<T>)
351+
[[vk::ext_instruction(spv::OpIEqual)]]
352+
conditional_t<is_vector_v<T>, vector<bool, vector_traits<T>::Dimension>, bool> IEqual(T lhs, T rhs);
353+
354+
template<typename T NBL_FUNC_REQUIRES(is_floating_point_v<T> && !is_matrix_v<T>)
355+
[[vk::ext_instruction(spv::OpFOrdEqual)]]
356+
conditional_t<is_vector_v<T>, vector<bool, vector_traits<T>::Dimension>, bool> FOrdEqual(T lhs, T rhs);
357+
358+
359+
template<typename T, typename U NBL_FUNC_REQUIRES(!is_matrix_v<T> && !is_matrix_v<U> && is_same_v<typename vector_traits<U>::scalar_type, bool>)
360+
[[vk::ext_instruction(spv::OpSelect)]]
361+
T select(U a, T x, T y);
362+
349363
}
350364

351365
#endif

include/nbl/builtin/hlsl/spirv_intrinsics/subgroup_arithmetic.hlsl

+20-16
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,23 @@ namespace hlsl
1717
namespace spirv
1818
{
1919

20+
template<typename T>
2021
[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
2122
[[vk::ext_instruction( spv::OpGroupNonUniformIAdd )]]
22-
int32_t groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, int32_t value);
23-
[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
24-
[[vk::ext_instruction( spv::OpGroupNonUniformIAdd )]]
25-
uint32_t groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, uint32_t value);
23+
enable_if_t<!is_matrix_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
24+
template<typename T>
2625
[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
2726
[[vk::ext_instruction( spv::OpGroupNonUniformFAdd )]]
28-
float32_t groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, float32_t value);
27+
enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> groupAdd(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
2928

29+
template<typename T>
3030
[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
3131
[[vk::ext_instruction( spv::OpGroupNonUniformIMul )]]
32-
int32_t groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, int32_t value);
33-
[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
34-
[[vk::ext_instruction( spv::OpGroupNonUniformIMul )]]
35-
uint32_t groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, uint32_t value);
32+
enable_if_t<!is_matrix_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
33+
template<typename T>
3634
[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
3735
[[vk::ext_instruction( spv::OpGroupNonUniformFMul )]]
38-
float32_t groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, float32_t value);
36+
enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> groupMul(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
3937

4038
template<typename T>
4139
[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
@@ -54,25 +52,31 @@ T groupBitwiseXor(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T
5452

5553
// The MIN and MAX operations in SPIR-V have different Ops for each arithmetic type
5654
// so we implement them distinctly
55+
template<typename T>
5756
[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
5857
[[vk::ext_instruction( spv::OpGroupNonUniformSMin )]]
59-
int32_t groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, int32_t value);
58+
enable_if_t<!is_matrix_v<T> && is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupSMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
59+
template<typename T>
6060
[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
6161
[[vk::ext_instruction( spv::OpGroupNonUniformUMin )]]
62-
uint32_t groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, uint32_t value);
62+
enable_if_t<!is_matrix_v<T> && !is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupUMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
63+
template<typename T>
6364
[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
6465
[[vk::ext_instruction( spv::OpGroupNonUniformFMin )]]
65-
float32_t groupBitwiseMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, float32_t value);
66+
enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> groupFMin(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
6667

68+
template<typename T>
6769
[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
6870
[[vk::ext_instruction( spv::OpGroupNonUniformSMax )]]
69-
int32_t groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, int32_t value);
71+
enable_if_t<!is_matrix_v<T> && is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupSMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
72+
template<typename T>
7073
[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
7174
[[vk::ext_instruction( spv::OpGroupNonUniformUMax )]]
72-
uint32_t groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, uint32_t value);
75+
enable_if_t<!is_matrix_v<T> && !is_signed_v<T> && is_integral_v<typename vector_traits<T>::scalar_type>, T> groupUMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
76+
template<typename T>
7377
[[vk::ext_capability( spv::CapabilityGroupNonUniformArithmetic )]]
7478
[[vk::ext_instruction( spv::OpGroupNonUniformFMax )]]
75-
float32_t groupBitwiseMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, float32_t value);
79+
enable_if_t<!is_matrix_v<T> && is_floating_point_v<typename vector_traits<T>::scalar_type>, T> groupFMax(uint32_t groupScope, [[vk::ext_literal]] uint32_t operation, T value);
7680

7781
}
7882
}

include/nbl/builtin/hlsl/subgroup/arithmetic_portability.hlsl

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "nbl/builtin/hlsl/subgroup/basic.hlsl"
1111
#include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
12+
#include "nbl/builtin/hlsl/concepts.hlsl"
1213

1314

1415
namespace nbl
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_INCLUDED_
5+
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_INCLUDED_
6+
7+
8+
#include "nbl/builtin/hlsl/device_capabilities_traits.hlsl"
9+
10+
#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
11+
#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl"
12+
#include "nbl/builtin/hlsl/concepts.hlsl"
13+
14+
15+
namespace nbl
16+
{
17+
namespace hlsl
18+
{
19+
namespace subgroup2
20+
{
21+
22+
template<typename Config, class BinOp, int32_t _ItemsPerInvocation=1, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config> && is_scalar_v<typename BinOp::type_t>)
23+
struct ArithmeticParams
24+
{
25+
using config_t = Config;
26+
using binop_t = BinOp;
27+
using scalar_t = typename BinOp::type_t;
28+
using type_t = vector<scalar_t, _ItemsPerInvocation>;
29+
using device_traits = device_capabilities_traits<device_capabilities>;
30+
31+
NBL_CONSTEXPR_STATIC_INLINE int32_t ItemsPerInvocation = _ItemsPerInvocation;
32+
NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic /*&& /*some heuristic for when its faster*/;
33+
// TODO add a IHV enum to device_capabilities_traits to check !is_nvidia
34+
};
35+
36+
template<typename Params>
37+
struct reduction : impl::reduction<Params,typename Params::binop_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
38+
template<typename Params>
39+
struct inclusive_scan : impl::inclusive_scan<Params,typename Params::binop_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
40+
template<typename Params>
41+
struct exclusive_scan : impl::exclusive_scan<Params,typename Params::binop_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
42+
43+
}
44+
}
45+
}
46+
47+
#endif

0 commit comments

Comments
 (0)