Skip to content

Commit 06c10e3

Browse files
committed
Simplify implementation of map_impl
1 parent b59ee5f commit 06c10e3

File tree

10 files changed

+172
-112
lines changed

10 files changed

+172
-112
lines changed

include/kernel_float/apply.h

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -118,43 +118,63 @@ broadcast_like(const V& input, const R& other) {
118118

119119
namespace detail {
120120

121-
template<size_t N>
122-
struct apply_recur_impl;
123-
124121
template<typename F, size_t N, typename Output, typename... Args>
125122
struct apply_impl {
126-
KERNEL_FLOAT_INLINE static void call(F fun, Output* result, const Args*... inputs) {
127-
apply_recur_impl<N>::call(fun, result, inputs...);
123+
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
124+
#pragma unroll
125+
for (size_t i = 0; i < N; i++) {
126+
output[i] = fun(args[i]...);
127+
}
128128
}
129129
};
130130

131-
template<size_t N>
132-
struct apply_recur_impl {
133-
static constexpr size_t K = round_up_to_power_of_two(N) / 2;
131+
template<typename F, size_t N, typename Output, typename... Args>
132+
struct apply_fastmath_impl: apply_impl<F, N, Output, Args...> {};
133+
134+
template<typename F, size_t N, typename Output, typename... Args>
135+
struct map_impl {
136+
static constexpr size_t packet_size = preferred_vector_size<Output>::value;
137+
138+
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
139+
if constexpr (N / packet_size > 0) {
140+
#pragma unroll
141+
for (size_t i = 0; i < N - N % packet_size; i += packet_size) {
142+
apply_impl<F, packet_size, Output, Args...>::call(fun, output + i, (args + i)...);
143+
}
144+
}
134145

135-
template<typename F, typename Output, typename... Args>
136-
KERNEL_FLOAT_INLINE static void call(F fun, Output* result, const Args*... inputs) {
137-
apply_impl<F, K, Output, Args...>::call(fun, result, inputs...);
138-
apply_impl<F, N - K, Output, Args...>::call(fun, result + K, (inputs + K)...);
146+
if constexpr (N % packet_size > 0) {
147+
#pragma unroll
148+
for (size_t i = N - N % packet_size; i < N; i++) {
149+
apply_impl<F, 1, Output, Args...>::call(fun, output + i, (args + i)...);
150+
}
151+
}
139152
}
140153
};
141154

142-
template<>
143-
struct apply_recur_impl<0> {
144-
template<typename F, typename Output, typename... Args>
145-
KERNEL_FLOAT_INLINE static void call(F fun, Output* result, const Args*... inputs) {}
146-
};
155+
template<typename F, size_t N, typename Output, typename... Args>
156+
struct fast_map_impl {
157+
static constexpr size_t packet_size = preferred_vector_size<Output>::value;
158+
159+
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
160+
if constexpr (N / packet_size > 0) {
161+
#pragma unroll
162+
for (size_t i = 0; i < N - N % packet_size; i += packet_size) {
163+
apply_fastmath_impl<F, packet_size, Output, Args...>::call(
164+
fun,
165+
output + i,
166+
(args + i)...);
167+
}
168+
}
147169

148-
template<>
149-
struct apply_recur_impl<1> {
150-
template<typename F, typename Output, typename... Args>
151-
KERNEL_FLOAT_INLINE static void call(F fun, Output* result, const Args*... inputs) {
152-
result[0] = fun(inputs[0]...);
170+
if constexpr (N % packet_size > 0) {
171+
#pragma unroll
172+
for (size_t i = N - N % packet_size; i < N; i++) {
173+
apply_fastmath_impl<F, 1, Output, Args...>::call(fun, output + i, (args + i)...);
174+
}
175+
}
153176
}
154177
};
155-
156-
template<typename F, size_t N, typename Output, typename... Args>
157-
struct apply_fastmath_impl: apply_impl<F, N, Output, Args...> {};
158178
} // namespace detail
159179

160180
template<typename F, typename... Args>
@@ -180,12 +200,12 @@ KERNEL_FLOAT_INLINE map_type<F, Args...> map(F fun, const Args&... args) {
180200
// Use the `apply_fastmath_impl` if KERNEL_FLOAT_FAST_MATH is enabled
181201
#if KERNEL_FLOAT_FAST_MATH
182202
using apply_impl =
183-
detail::apply_fastmath_impl<F, extent_size<E>, Output, vector_value_type<Args>...>;
203+
detail::fast_math_impl<F, extent_size<E>, Output, vector_value_type<Args>...>;
184204
#else
185-
using apply_impl = detail::apply_impl<F, extent_size<E>, Output, vector_value_type<Args>...>;
205+
using map_impl = detail::map_impl<F, extent_size<E>, Output, vector_value_type<Args>...>;
186206
#endif
187207

188-
apply_impl::call(
208+
map_impl::call(
189209
fun,
190210
result.data(),
191211
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call(
@@ -205,7 +225,7 @@ KERNEL_FLOAT_INLINE map_type<F, Args...> fast_map(F fun, const Args&... args) {
205225
using E = broadcast_vector_extent_type<Args...>;
206226
vector_storage<Output, extent_size<E>> result;
207227

208-
detail::apply_fastmath_impl<F, extent_size<E>, Output, vector_value_type<Args>...>::call(
228+
detail::fast_map_impl<F, extent_size<E>, Output, vector_value_type<Args>...>::call(
209229
fun,
210230
result.data(),
211231
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call(

include/kernel_float/base.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,11 @@ struct into_vector_impl<vector<T, E, S>> {
231231
}
232232
};
233233

234+
template<typename T>
235+
struct preferred_vector_size {
236+
static constexpr size_t value = 1;
237+
};
238+
234239
template<typename V>
235240
struct vector_traits;
236241

include/kernel_float/bf16.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111
#include "vector.h"
1212

1313
namespace kernel_float {
14+
15+
template<>
16+
struct preferred_vector_size<__nv_bfloat16> {
17+
static constexpr size_t value = 2;
18+
};
19+
1420
KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__nv_bfloat16)
1521
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_bfloat16)
1622
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __nv_bfloat16)

include/kernel_float/binops.h

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,12 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
5454

5555
// Use the `apply_fastmath_impl` if KERNEL_FLOAT_FAST_MATH is enabled
5656
#if KERNEL_FLOAT_FAST_MATH
57-
using apply_impl = detail::apply_fastmath_impl<F, extent_size<E>, O, T, T>;
57+
using map_impl = detail::fast_map_impl<F, extent_size<E>, O, T, T>;
5858
#else
59-
using apply_impl = detail::apply_impl<F, extent_size<E>, O, T, T>;
59+
using map_impl = detail::map_impl<F, extent_size<E>, O, T, T>;
6060
#endif
6161

62-
apply_impl::call(
62+
map_impl::call(
6363
fun,
6464
result.data(),
6565
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
@@ -310,14 +310,11 @@ struct apply_fastmath_impl<ops::divide<T>, N, T, T, T> {
310310
};
311311

312312
#if KERNEL_FLOAT_IS_DEVICE
313-
template<size_t N>
314-
struct apply_fastmath_impl<ops::divide<float>, N, float, float, float> {
313+
template<>
314+
struct apply_fastmath_impl<ops::divide<float>, 1, float, float, float> {
315315
KERNEL_FLOAT_INLINE static void
316316
call(ops::divide<float> fun, float* result, const float* lhs, const float* rhs) {
317-
#pragma unroll
318-
for (size_t i = 0; i < N; i++) {
319-
result[i] = __fdividef(lhs[i], rhs[i]);
320-
}
317+
*result = __fdividef(*lhs, *rhs);
321318
}
322319
};
323320
#endif
@@ -329,7 +326,7 @@ fast_divide(const L& left, const R& right) {
329326
using E = broadcast_vector_extent_type<L, R>;
330327
vector_storage<T, extent_size<E>> result;
331328

332-
detail::apply_fastmath_impl<ops::divide<T>, extent_size<E>, T, T, T>::call(
329+
detail::fast_map_impl<ops::divide<T>, extent_size<E>, T, T, T>::call(
333330
ops::divide<T> {},
334331
result.data(),
335332
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(

include/kernel_float/conversion.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct convert_impl {
1717
static vector_storage<T2, extent_size<E2>> call(vector_storage<T, extent_size<E>> input) {
1818
using F = ops::cast<T, T2, M>;
1919
vector_storage<T2, extent_size<E>> intermediate;
20-
detail::apply_impl<F, extent_size<E>, T2, T>::call(F {}, intermediate.data(), input.data());
20+
detail::map_impl<F, extent_size<E>, T2, T>::call(F {}, intermediate.data(), input.data());
2121
return detail::broadcast_impl<T2, E, E2>::call(intermediate);
2222
}
2323
};
@@ -48,7 +48,7 @@ struct convert_impl<T, E, T2, E, M> {
4848
using F = ops::cast<T, T2, M>;
4949

5050
vector_storage<T2, extent_size<E>> result;
51-
detail::apply_impl<F, extent_size<E>, T2, T>::call(F {}, result.data(), input.data());
51+
detail::map_impl<F, extent_size<E>, T2, T>::call(F {}, result.data(), input.data());
5252
return result;
5353
}
5454
};

include/kernel_float/fp16.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
#include "vector.h"
1010

1111
namespace kernel_float {
12+
13+
template<>
14+
struct preferred_vector_size<__half> {
15+
static constexpr size_t value = 2;
16+
};
17+
1218
KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__half)
1319
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __half)
1420
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __half)

include/kernel_float/reduce.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct reduce_recur_impl {
2323
template<typename F, typename T>
2424
KERNEL_FLOAT_INLINE static T call(F fun, const T* input) {
2525
vector_storage<T, K> temp;
26-
apply_impl<F, N - K, T, T, T>::call(fun, temp.data(), input, input + K);
26+
map_impl<F, N - K, T, T, T>::call(fun, temp.data(), input, input + K);
2727

2828
if constexpr (N < 2 * K) {
2929
#pragma unroll
@@ -178,7 +178,7 @@ struct dot_impl {
178178
KERNEL_FLOAT_INLINE
179179
static T call(const T* left, const T* right) {
180180
vector_storage<T, N> intermediate;
181-
detail::apply_impl<ops::multiply<T>, N, T, T, T>::call(
181+
detail::map_impl<ops::multiply<T>, N, T, T, T>::call(
182182
ops::multiply<T>(),
183183
intermediate.data(),
184184
left,

include/kernel_float/triops.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ KERNEL_FLOAT_INLINE vector<T, E> where(const C& cond, const L& true_values, cons
4141
using F = ops::conditional<T>;
4242
vector_storage<T, extent_size<E>> result;
4343

44-
detail::apply_impl<F, extent_size<E>, T, bool, T, T>::call(
44+
detail::map_impl<F, extent_size<E>, T, bool, T, T>::call(
4545
F {},
4646
result.data(),
4747
detail::convert_impl<vector_value_type<C>, vector_extent_type<C>, bool, E>::call(
@@ -126,7 +126,7 @@ KERNEL_FLOAT_INLINE vector<T, E> fma(const A& a, const B& b, const C& c) {
126126
using F = ops::fma<T>;
127127
vector_storage<T, extent_size<E>> result;
128128

129-
detail::apply_impl<F, extent_size<E>, T, T, T, T>::call(
129+
detail::map_impl<F, extent_size<E>, T, T, T, T>::call(
130130
F {},
131131
result.data(),
132132
detail::convert_impl<vector_value_type<A>, vector_extent_type<A>, T, E>::call(

include/kernel_float/unops.h

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -214,12 +214,10 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan)
214214

215215
#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, FAST_FUN) \
216216
namespace detail { \
217-
template<size_t N> \
218-
struct apply_fastmath_impl<ops::F<T>, N, T, T> { \
217+
template<> \
218+
struct apply_fastmath_impl<ops::F<T>, 1, T, T> { \
219219
KERNEL_FLOAT_INLINE static void call(ops::F<T>, T* result, const T* inputs) { \
220-
for (size_t i = 0; i < N; i++) { \
221-
result[i] = FAST_FUN(inputs[i]); \
222-
} \
220+
*result = FAST_FUN(*inputs); \
223221
} \
224222
}; \
225223
}
@@ -229,12 +227,10 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf)
229227

230228
#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(T, F, INSTR, REG) \
231229
namespace detail { \
232-
template<size_t N> \
233-
struct apply_fastmath_impl<ops::F<T>, N, T, T> { \
230+
template<> \
231+
struct apply_fastmath_impl<ops::F<T>, 1, T, T> { \
234232
KERNEL_FLOAT_INLINE static void call(ops::F<T> fun, T* result, const T* inputs) { \
235-
for (size_t i = 0; i < N; i++) { \
236-
asm(INSTR : "=" REG(result[i]) : REG(inputs[i])); \
237-
} \
233+
asm(INSTR : "=" REG(*result) : REG(*inputs)); \
238234
} \
239235
}; \
240236
}

0 commit comments

Comments
 (0)