Skip to content

Commit d0d92a7

Browse files
committed
Change implementation of map_impl to support policies
1 parent 986ca55 commit d0d92a7

File tree

7 files changed

+126
-150
lines changed

7 files changed

+126
-150
lines changed

docs/guides.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ Guides
55

66
guides/introduction.rst
77
guides/promotion.rst
8+
guides/accuracy.rst
89
guides/constant.rst

docs/guides/accuracy.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
Accuracy level
2+
===
3+
4+
Many of the functions in Kernel Float take an additional `Accuracy` option as a template parameter.
5+
This option can be used to increase the performance of certain operations, at the cost of lower accuracy.
6+
7+
There are four possible values for this parameter:
8+
9+
* `accurate_policy`: Use the most accurate version available.
10+
* `fast_policy`: Use the "fast math" version (for example, `__sinf` for sin on CUDA devices). Falls back to `accurate_policy` if unavailable.
11+
* `approx_policy<N>`: Rough approximation using a polynomial of degree `N`. Falls back to `fast_policy` if no polynomial exists.
12+
* `default_policy`: Use a global default policy (see the next section).
13+
14+
15+
For example, consider this code:
16+
17+
```C++
18+
19+
#include "kernel_float.h"
20+
namespace kf = kernel_float;
21+
22+
23+
int main() {
24+
kf::vec<float, 2> input = {1.0f, 2.0f};
25+
26+
// Use the default policy
27+
kf::vec<float, 2> A = kf::cos(input);
28+
29+
// Use the most accuracy policy
30+
kf::vec<float, 2> B = kf::cos<kf::accurate_policy>(input);
31+
32+
// Use the fastest policy
33+
kf::vec<float, 2> C = kf::cos<kf::fast_policy>(input);
34+
35+
printf("A = %f, %f", A[0], A[1]);
36+
printf("B = %f, %f", B[0], B[1]);
37+
printf("C = %f, %f", C[0], C[1]);
38+
39+
return EXIT_SUCCESS;
40+
}
41+
42+
```
43+
44+
45+
Setting `default_policy`
46+
---
47+
By default, the value for `default_policy` is `accurate_policy`.
48+
49+
Set the preprocessor option `KERNEL_FLOAT_FAST_MATH=1` to change the default policy to `fast_policy`.

include/kernel_float/apply.h

Lines changed: 26 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -130,51 +130,50 @@ struct apply_impl {
130130

131131
template<typename F, size_t N, typename Output, typename... Args>
132132
struct apply_fastmath_impl: apply_impl<F, N, Output, Args...> {};
133+
} // namespace detail
133134

134-
template<typename F, size_t N, typename Output, typename... Args>
135-
struct map_impl {
136-
static constexpr size_t packet_size = preferred_vector_size<Output>::value;
137-
138-
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
139-
if constexpr (N / packet_size > 0) {
140-
#pragma unroll
141-
for (size_t i = 0; i < N - N % packet_size; i += packet_size) {
142-
apply_impl<F, packet_size, Output, Args...>::call(fun, output + i, (args + i)...);
143-
}
144-
}
135+
struct accurate_policy {
136+
template<typename F, size_t N, typename Output, typename... Args>
137+
using type = detail::apply_impl<F, N, Output, Args...>;
138+
};
145139

146-
if constexpr (N % packet_size > 0) {
147-
#pragma unroll
148-
for (size_t i = N - N % packet_size; i < N; i++) {
149-
apply_impl<F, 1, Output, Args...>::call(fun, output + i, (args + i)...);
150-
}
151-
}
152-
}
140+
struct fast_policy {
141+
template<typename F, size_t N, typename Output, typename... Args>
142+
using type = detail::apply_fastmath_impl<F, N, Output, Args...>;
153143
};
154144

155-
template<typename F, size_t N, typename Output, typename... Args>
156-
struct fast_map_impl {
145+
#ifdef KERNEL_FLOAT_POLICY
146+
using default_policy = KERNEL_FLOAT_POLICY;
147+
#else
148+
using default_policy = accurate_policy;
149+
#endif
150+
151+
namespace detail {
152+
153+
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
154+
struct map_policy_impl {
157155
static constexpr size_t packet_size = preferred_vector_size<Output>::value;
158156

159157
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
160158
if constexpr (N / packet_size > 0) {
161159
#pragma unroll
162160
for (size_t i = 0; i < N - N % packet_size; i += packet_size) {
163-
apply_fastmath_impl<F, packet_size, Output, Args...>::call(
164-
fun,
165-
output + i,
166-
(args + i)...);
161+
Policy::template type<F, N, Output, Args...>::call(fun, output + i, (args + i)...);
167162
}
168163
}
169164

170165
if constexpr (N % packet_size > 0) {
171166
#pragma unroll
172167
for (size_t i = N - N % packet_size; i < N; i++) {
173-
apply_fastmath_impl<F, 1, Output, Args...>::call(fun, output + i, (args + i)...);
168+
Policy::template type<F, N, Output, Args...>::call(fun, output + i, (args + i)...);
174169
}
175170
}
176171
}
177172
};
173+
174+
template<typename F, size_t N, typename Output, typename... Args>
175+
using map_impl = map_policy_impl<default_policy, F, N, Output, Args...>;
176+
178177
} // namespace detail
179178

180179
template<typename F, typename... Args>
@@ -191,41 +190,13 @@ using map_type =
191190
* vec<float, 4> squared = map([](auto x) { return x * x; }, input); // [1.0f, 4.0f, 9.0f, 16.0f]
192191
* ```
193192
*/
194-
template<typename F, typename... Args>
193+
template<typename Accuracy = default_policy, typename F, typename... Args>
195194
KERNEL_FLOAT_INLINE map_type<F, Args...> map(F fun, const Args&... args) {
196195
using Output = result_t<F, vector_value_type<Args>...>;
197196
using E = broadcast_vector_extent_type<Args...>;
198197
vector_storage<Output, extent_size<E>> result;
199198

200-
// Use the `apply_fastmath_impl` if KERNEL_FLOAT_FAST_MATH is enabled
201-
#if KERNEL_FLOAT_FAST_MATH
202-
using apply_impl =
203-
detail::fast_math_impl<F, extent_size<E>, Output, vector_value_type<Args>...>;
204-
#else
205-
using map_impl = detail::map_impl<F, extent_size<E>, Output, vector_value_type<Args>...>;
206-
#endif
207-
208-
map_impl::call(
209-
fun,
210-
result.data(),
211-
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call(
212-
into_vector_storage(args))
213-
.data())...);
214-
215-
return result;
216-
}
217-
218-
/**
219-
* Apply the function `F` to each element from the vector `input` and return the results as a new vector. This
220-
* uses fast-math if available for the given function `F`, otherwise this function behaves like `map`.
221-
*/
222-
template<typename F, typename... Args>
223-
KERNEL_FLOAT_INLINE map_type<F, Args...> fast_map(F fun, const Args&... args) {
224-
using Output = result_t<F, vector_value_type<Args>...>;
225-
using E = broadcast_vector_extent_type<Args...>;
226-
vector_storage<Output, extent_size<E>> result;
227-
228-
detail::fast_map_impl<F, extent_size<E>, Output, vector_value_type<Args>...>::call(
199+
detail::map_policy_impl<Accuracy, F, extent_size<E>, Output, vector_value_type<Args>...>::call(
229200
fun,
230201
result.data(),
231202
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call(

include/kernel_float/binops.h

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,7 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
5252

5353
vector_storage<O, extent_size<E>> result;
5454

55-
// Use the `apply_fastmath_impl` if KERNEL_FLOAT_FAST_MATH is enabled
56-
#if KERNEL_FLOAT_FAST_MATH
57-
using map_impl = detail::fast_map_impl<F, extent_size<E>, O, T, T>;
58-
#else
59-
using map_impl = detail::map_impl<F, extent_size<E>, O, T, T>;
60-
#endif
61-
62-
map_impl::call(
55+
detail::map_impl<F, extent_size<E>, O, T, T>::call(
6356
fun,
6457
result.data(),
6558
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
@@ -304,7 +297,7 @@ struct apply_fastmath_impl<ops::divide<T>, N, T, T, T> {
304297
T rhs_rcp[N];
305298

306299
// Fast way to perform division is to multiply by the reciprocal
307-
apply_fastmath_impl<ops::rcp<T>, N, T, T, T>::call({}, rhs_rcp, rhs);
300+
apply_fastmath_impl<ops::rcp<T>, N, T, T>::call({}, rhs_rcp, rhs);
308301
apply_fastmath_impl<ops::multiply<T>, N, T, T, T>::call({}, result, lhs, rhs_rcp);
309302
}
310303
};
@@ -326,7 +319,7 @@ fast_divide(const L& left, const R& right) {
326319
using E = broadcast_vector_extent_type<L, R>;
327320
vector_storage<T, extent_size<E>> result;
328321

329-
detail::fast_map_impl<ops::divide<T>, extent_size<E>, T, T, T>::call(
322+
detail::map_policy_impl<fast_policy, ops::divide<T>, extent_size<E>, T, T, T>::call(
330323
ops::divide<T> {},
331324
result.data(),
332325
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(

include/kernel_float/macros.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
#define KERNEL_FLOAT_MAX_ALIGNMENT (32)
6565

6666
#ifndef KERNEL_FLOAT_FAST_MATH
67-
#define KERNEL_FLOAT_FAST_MATH (0)
67+
#define KERNEL_FLOAT_POLICY ::kernel_float::fast_policy;
6868
#endif
6969

7070
#endif //KERNEL_FLOAT_MACROS_H

include/kernel_float/unops.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,10 @@ KERNEL_FLOAT_INLINE vector<R, vector_extent_type<V>> cast(const V& input) {
8585
}
8686

8787
#define KERNEL_FLOAT_DEFINE_UNARY_FUN(NAME) \
88-
template<typename V> \
88+
template<typename Accuracy = default_policy, typename V> \
8989
KERNEL_FLOAT_INLINE vector<vector_value_type<V>, vector_extent_type<V>> NAME(const V& input) { \
9090
using F = ops::NAME<vector_value_type<V>>; \
91-
return map(F {}, input); \
91+
return ::kernel_float::map<Accuracy>(F {}, input); \
9292
}
9393

9494
#define KERNEL_FLOAT_DEFINE_UNARY(NAME, EXPR) \
@@ -193,12 +193,11 @@ KERNEL_FLOAT_DEFINE_UNARY_STRUCT(rcp, 1.0 / input, 1.0f / input)
193193

194194
KERNEL_FLOAT_DEFINE_UNARY_FUN(rcp)
195195

196-
#define KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(NAME) \
197-
template<typename V> \
198-
KERNEL_FLOAT_INLINE vector<vector_value_type<V>, vector_extent_type<V>> fast_##NAME( \
199-
const V& input) { \
200-
using F = ops::NAME<vector_value_type<V>>; \
201-
return fast_map(F {}, input); \
196+
#define KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(NAME) \
197+
template<typename V> \
198+
KERNEL_FLOAT_INLINE vector<vector_value_type<V>, vector_extent_type<V>> fast_##NAME( \
199+
const V& input) { \
200+
return ::kernel_float::map<fast_policy>(ops::NAME<vector_value_type<V>> {}, input); \
202201
}
203202

204203
KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp)

0 commit comments

Comments
 (0)