 */

#include <cudf/column/column_device_view_base.cuh>
#include <cudf/jit/span.cuh>
#include <cudf/strings/string_view.cuh>
#include <cudf/types.hpp>
#include <cudf/utilities/traits.hpp>
namespace transformation {
namespace jit {

3940template <typename T, int32_t Index>
40- struct accessor {
41+ struct column_accessor {
4142 using type = T;
4243 static constexpr int32_t index = Index;
4344
44- static __device__ decltype (auto ) element(cudf::mutable_column_device_view_core const * views ,
45+ static __device__ decltype (auto ) element(cudf::mutable_column_device_view_core const * outputs ,
4546 cudf::size_type row)
4647 {
47- return views [index].element <T>(row);
48+ return outputs [index].element <T>(row);
4849 }
4950
50- static __device__ decltype (auto ) element(cudf::column_device_view_core const * views ,
51+ static __device__ decltype (auto ) element(cudf::column_device_view_core const * inputs ,
5152 cudf::size_type row)
5253 {
53- return views [index].element <T>(row);
54+ return inputs [index].element <T>(row);
5455 }
5556
56- static __device__ void assign (cudf::mutable_column_device_view_core const * views ,
57+ static __device__ void assign (cudf::mutable_column_device_view_core const * outputs ,
5758 cudf::size_type row,
5859 T value)
5960 {
60- views[index].assign <T>(row, value);
61+ outputs[index].assign <T>(row, value);
62+ }
63+ };
64+
65+ template <typename T, int32_t Index>
66+ struct span_accessor {
67+ using type = T;
68+ static constexpr int32_t index = Index;
69+
70+ static __device__ type& element (cudf::jit::device_span<T> const * spans, cudf::size_type row)
71+ {
72+ return spans[index][row];
73+ }
74+
75+ static __device__ void assign (cudf::jit::device_span<T> const * outputs,
76+ cudf::size_type row,
77+ T value)
78+ {
79+ outputs[index][row] = value;
6180 }
6281};
6382
@@ -66,59 +85,94 @@ struct scalar {
6685 using type = typename Accessor::type;
6786 static constexpr int32_t index = Accessor::index;
6887
69- static __device__ decltype (auto ) element(cudf::mutable_column_device_view_core const * views ,
88+ static __device__ decltype (auto ) element(cudf::mutable_column_device_view_core const * outputs ,
7089 cudf::size_type row)
7190 {
72- return Accessor::element (views , 0 );
91+ return Accessor::element (outputs , 0 );
7392 }
7493
75- static __device__ decltype (auto ) element(cudf::column_device_view_core const * views ,
94+ static __device__ decltype (auto ) element(cudf::column_device_view_core const * inputs ,
7695 cudf::size_type row)
7796 {
78- return Accessor::element (views , 0 );
97+ return Accessor::element (inputs , 0 );
7998 }
8099
81- static __device__ void assign (cudf::mutable_column_device_view_core const * views ,
100+ static __device__ void assign (cudf::mutable_column_device_view_core const * outputs ,
82101 cudf::size_type row,
83102 type value)
84103 {
85- return Accessor::assign (views , 0 , value);
104+ return Accessor::assign (outputs , 0 , value);
86105 }
87106};
88107
89- template <typename Out, typename ... In>
90- CUDF_KERNEL void kernel (cudf::mutable_column_device_view_core const * output,
91- cudf::column_device_view_core const * inputs)
108+ template <bool has_user_data, typename Out, typename ... In>
109+ CUDF_KERNEL void kernel (cudf::mutable_column_device_view_core const * outputs,
110+ cudf::column_device_view_core const * inputs,
111+ void * user_data)
92112{
113+ // inputs to JITIFY kernels have to be either sized-integral types or pointers. Structs or
114+ // references can't be passed directly/correctly as they will be crossing an ABI boundary
115+
93116 // cannot use global_thread_id utility due to a JIT build issue by including
94117 // the `cudf/detail/utilities/cuda.cuh` header
95118 auto const block_size = static_cast <thread_index_type>(blockDim .x );
96119 thread_index_type const start = threadIdx .x + blockIdx .x * block_size;
97120 thread_index_type const stride = block_size * gridDim .x ;
98- thread_index_type const size = output-> size ();
121+ thread_index_type const size = outputs[ 0 ]. size ();
99122
100123 for (auto i = start; i < size; i += stride) {
101- GENERIC_TRANSFORM_OP (&Out::element (output, i), In::element (inputs, i)...);
124+ if constexpr (has_user_data) {
125+ GENERIC_TRANSFORM_OP (user_data, i, &Out::element (outputs, i), In::element (inputs, i)...);
126+ } else {
127+ GENERIC_TRANSFORM_OP (&Out::element (outputs, i), In::element (inputs, i)...);
128+ }
102129 }
103130}
104131
105- template <typename Out, typename ... In>
106- CUDF_KERNEL void fixed_point_kernel (cudf::mutable_column_device_view_core const * output,
107- cudf::column_device_view_core const * inputs)
132+ template <bool has_user_data, typename Out, typename ... In>
133+ CUDF_KERNEL void fixed_point_kernel (cudf::mutable_column_device_view_core const * outputs,
134+ cudf::column_device_view_core const * inputs,
135+ void * user_data)
108136{
109137 // cannot use global_thread_id utility due to a JIT build issue by including
110138 // the `cudf/detail/utilities/cuda.cuh` header
111139 auto const block_size = static_cast <thread_index_type>(blockDim .x );
112140 thread_index_type const start = threadIdx .x + blockIdx .x * block_size;
113141 thread_index_type const stride = block_size * gridDim .x ;
114- thread_index_type const size = output->size ();
115-
116- numeric::scale_type const output_scale = static_cast <numeric::scale_type>(output->type ().scale ());
142+ thread_index_type const size = outputs[0 ].size ();
143+ auto const output_scale = static_cast <numeric::scale_type>(outputs[0 ].type ().scale ());
117144
118145 for (auto i = start; i < size; i += stride) {
119146 typename Out::type result{numeric::scaled_integer<typename Out::type::rep>{0 , output_scale}};
120- GENERIC_TRANSFORM_OP (&result, In::element (inputs, i)...);
121- Out::assign (output, i, result);
147+
148+ if constexpr (has_user_data) {
149+ GENERIC_TRANSFORM_OP (user_data, i, &result, In::element (inputs, i)...);
150+ } else {
151+ GENERIC_TRANSFORM_OP (&result, In::element (inputs, i)...);
152+ }
153+
154+ Out::assign (outputs, i, result);
155+ }
156+ }
157+
158+ template <bool has_user_data, typename Out, typename ... In>
159+ CUDF_KERNEL void span_kernel (cudf::jit::device_span<typename Out::type> const * outputs,
160+ cudf::column_device_view_core const * inputs,
161+ void * user_data)
162+ {
163+ // cannot use global_thread_id utility due to a JIT build issue by including
164+ // the `cudf/detail/utilities/cuda.cuh` header
165+ auto const block_size = static_cast <thread_index_type>(blockDim .x );
166+ thread_index_type const start = threadIdx .x + blockIdx .x * block_size;
167+ thread_index_type const stride = block_size * gridDim .x ;
168+ thread_index_type const size = outputs[0 ].size ();
169+
170+ for (auto i = start; i < size; i += stride) {
171+ if constexpr (has_user_data) {
172+ GENERIC_TRANSFORM_OP (user_data, i, &Out::element (outputs, i), In::element (inputs, i)...);
173+ } else {
174+ GENERIC_TRANSFORM_OP (&Out::element (outputs, i), In::element (inputs, i)...);
175+ }
122176 }
123177}
124178
0 commit comments