@@ -130,51 +130,50 @@ struct apply_impl {
130
130
131
131
template <typename F, size_t N, typename Output, typename ... Args>
132
132
struct apply_fastmath_impl : apply_impl<F, N, Output, Args...> {};
133
+ } // namespace detail
133
134
134
- template <typename F, size_t N, typename Output, typename ... Args>
135
- struct map_impl {
136
- static constexpr size_t packet_size = preferred_vector_size<Output>::value;
137
-
138
- KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
139
- if constexpr (N / packet_size > 0 ) {
140
- #pragma unroll
141
- for (size_t i = 0 ; i < N - N % packet_size; i += packet_size) {
142
- apply_impl<F, packet_size, Output, Args...>::call (fun, output + i, (args + i)...);
143
- }
144
- }
135
+ struct accurate_policy {
136
+ template <typename F, size_t N, typename Output, typename ... Args>
137
+ using type = detail::apply_impl<F, N, Output, Args...>;
138
+ };
145
139
146
- if constexpr (N % packet_size > 0 ) {
147
- #pragma unroll
148
- for (size_t i = N - N % packet_size; i < N; i++) {
149
- apply_impl<F, 1 , Output, Args...>::call (fun, output + i, (args + i)...);
150
- }
151
- }
152
- }
140
+ struct fast_policy {
141
+ template <typename F, size_t N, typename Output, typename ... Args>
142
+ using type = detail::apply_fastmath_impl<F, N, Output, Args...>;
153
143
};
154
144
155
- template <typename F, size_t N, typename Output, typename ... Args>
156
- struct fast_map_impl {
145
+ #ifdef KERNEL_FLOAT_POLICY
146
+ using default_policy = KERNEL_FLOAT_POLICY;
147
+ #else
148
+ using default_policy = accurate_policy;
149
+ #endif
150
+
151
+ namespace detail {
152
+
153
+ template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
154
+ struct map_policy_impl {
157
155
static constexpr size_t packet_size = preferred_vector_size<Output>::value;
158
156
159
157
KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
160
158
if constexpr (N / packet_size > 0 ) {
161
159
#pragma unroll
162
160
for (size_t i = 0 ; i < N - N % packet_size; i += packet_size) {
163
- apply_fastmath_impl<F, packet_size, Output, Args...>::call (
164
- fun,
165
- output + i,
166
- (args + i)...);
161
+ Policy::template type<F, N, Output, Args...>::call (fun, output + i, (args + i)...);
167
162
}
168
163
}
169
164
170
165
if constexpr (N % packet_size > 0 ) {
171
166
#pragma unroll
172
167
for (size_t i = N - N % packet_size; i < N; i++) {
173
- apply_fastmath_impl <F, 1 , Output, Args...>::call (fun, output + i, (args + i)...);
168
+ Policy:: template type <F, N , Output, Args...>::call (fun, output + i, (args + i)...);
174
169
}
175
170
}
176
171
}
177
172
};
173
+
174
+ template <typename F, size_t N, typename Output, typename ... Args>
175
+ using map_impl = map_policy_impl<default_policy, F, N, Output, Args...>;
176
+
178
177
} // namespace detail
179
178
180
179
template <typename F, typename ... Args>
@@ -191,41 +190,13 @@ using map_type =
191
190
* vec<float, 4> squared = map([](auto x) { return x * x; }, input); // [1.0f, 4.0f, 9.0f, 16.0f]
192
191
* ```
193
192
*/
194
- template <typename F, typename ... Args>
193
+ template <typename Accuracy = default_policy, typename F, typename ... Args>
195
194
KERNEL_FLOAT_INLINE map_type<F, Args...> map (F fun, const Args&... args) {
196
195
using Output = result_t <F, vector_value_type<Args>...>;
197
196
using E = broadcast_vector_extent_type<Args...>;
198
197
vector_storage<Output, extent_size<E>> result;
199
198
200
- // Use the `apply_fastmath_impl` if KERNEL_FLOAT_FAST_MATH is enabled
201
- #if KERNEL_FLOAT_FAST_MATH
202
- using apply_impl =
203
- detail::fast_math_impl<F, extent_size<E>, Output, vector_value_type<Args>...>;
204
- #else
205
- using map_impl = detail::map_impl<F, extent_size<E>, Output, vector_value_type<Args>...>;
206
- #endif
207
-
208
- map_impl::call (
209
- fun,
210
- result.data (),
211
- (detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call (
212
- into_vector_storage (args))
213
- .data ())...);
214
-
215
- return result;
216
- }
217
-
218
- /* *
219
- * Apply the function `F` to each element from the vector `input` and return the results as a new vector. This
220
- * uses fast-math if available for the given function `F`, otherwise this function behaves like `map`.
221
- */
222
- template <typename F, typename ... Args>
223
- KERNEL_FLOAT_INLINE map_type<F, Args...> fast_map (F fun, const Args&... args) {
224
- using Output = result_t <F, vector_value_type<Args>...>;
225
- using E = broadcast_vector_extent_type<Args...>;
226
- vector_storage<Output, extent_size<E>> result;
227
-
228
- detail::fast_map_impl<F, extent_size<E>, Output, vector_value_type<Args>...>::call (
199
+ detail::map_policy_impl<Accuracy, F, extent_size<E>, Output, vector_value_type<Args>...>::call (
229
200
fun,
230
201
result.data (),
231
202
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call (
0 commit comments