Skip to content

Commit 9abcd89

Browse files
committed
Fix bug in map_policy_impl where policy is incorrectly called
1 parent 3a88b56 commit 9abcd89

File tree

3 files changed: 21 additions and 15 deletions.

include/kernel_float/apply.h

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -153,19 +153,23 @@ namespace detail {
153153
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
154154
struct map_policy_impl {
155155
static constexpr size_t packet_size = preferred_vector_size<Output>::value;
156+
static constexpr size_t remainder = N % packet_size;
156157

157158
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
158159
if constexpr (N / packet_size > 0) {
159160
#pragma unroll
160-
for (size_t i = 0; i < N - N % packet_size; i += packet_size) {
161-
Policy::template type<F, N, Output, Args...>::call(fun, output + i, (args + i)...);
161+
for (size_t i = 0; i < N - remainder; i += packet_size) {
162+
Policy::template type<F, packet_size, Output, Args...>::call(
163+
fun,
164+
output + i,
165+
(args + i)...);
162166
}
163167
}
164168

165-
if constexpr (N % packet_size > 0) {
169+
if constexpr (remainder > 0) {
166170
#pragma unroll
167-
for (size_t i = N - N % packet_size; i < N; i++) {
168-
Policy::template type<F, N, Output, Args...>::call(fun, output + i, (args + i)...);
171+
for (size_t i = N - remainder; i < N; i++) {
172+
Policy::template type<F, 1, Output, Args...>::call(fun, output + i, (args + i)...);
169173
}
170174
}
171175
}

include/kernel_float/base.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,6 @@ struct alignas(Alignment) aligned_array<T, 1, Alignment> {
4545
};
4646

4747
template<typename T, size_t Alignment>
48-
4948
struct aligned_array<T, 0, Alignment> {
5049
KERNEL_FLOAT_INLINE
5150
T* data() {

single_include/kernel_float.h

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2024-07-24 15:35:29.178410
20-
// git hash: 986ca557aa59f869d68fe1e7184c2228517ea52d
19+
// date: 2024-09-23 14:12:25.024358
20+
// git hash: 3a88b56a57cce5e1f3365aa6e8efb76a14f7f865
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -85,7 +85,7 @@
8585

8686
#define KERNEL_FLOAT_MAX_ALIGNMENT (32)
8787

88-
#ifndef KERNEL_FLOAT_FAST_MATH
88+
#if KERNEL_FLOAT_FAST_MATH
8989
#define KERNEL_FLOAT_POLICY ::kernel_float::fast_policy;
9090
#endif
9191

@@ -424,7 +424,6 @@ struct alignas(Alignment) aligned_array<T, 1, Alignment> {
424424
};
425425

426426
template<typename T, size_t Alignment>
427-
428427
struct aligned_array<T, 0, Alignment> {
429428
KERNEL_FLOAT_INLINE
430429
T* data() {
@@ -807,19 +806,23 @@ namespace detail {
807806
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
808807
struct map_policy_impl {
809808
static constexpr size_t packet_size = preferred_vector_size<Output>::value;
809+
static constexpr size_t remainder = N % packet_size;
810810

811811
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
812812
if constexpr (N / packet_size > 0) {
813813
#pragma unroll
814-
for (size_t i = 0; i < N - N % packet_size; i += packet_size) {
815-
Policy::template type<F, N, Output, Args...>::call(fun, output + i, (args + i)...);
814+
for (size_t i = 0; i < N - remainder; i += packet_size) {
815+
Policy::template type<F, packet_size, Output, Args...>::call(
816+
fun,
817+
output + i,
818+
(args + i)...);
816819
}
817820
}
818821

819-
if constexpr (N % packet_size > 0) {
822+
if constexpr (remainder > 0) {
820823
#pragma unroll
821-
for (size_t i = N - N % packet_size; i < N; i++) {
822-
Policy::template type<F, N, Output, Args...>::call(fun, output + i, (args + i)...);
824+
for (size_t i = N - remainder; i < N; i++) {
825+
Policy::template type<F, 1, Output, Args...>::call(fun, output + i, (args + i)...);
823826
}
824827
}
825828
}

0 commit comments

Comments
 (0)