@@ -22,7 +22,8 @@ struct has_mlas_transpose<uint32_t> : std::true_type {};
22
22
template <typename T>
23
23
typename std::enable_if<!has_mlas_transpose<T>::value, void >::type SimpleTransposeSingleAxisOutwards (
24
24
const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop,
25
- int64_t writes_per_writer_per_loop) {
25
+ int64_t writes_per_writer_per_loop, concurrency::ThreadPool* tp = nullptr ) {
26
+ ORT_UNUSED_PARAMETER (tp);
26
27
const T* end;
27
28
for (int64_t l = 0 ; l < num_loops; ++l) {
28
29
T* output_for_first_writer = output_data;
@@ -48,10 +49,10 @@ typename std::enable_if<!has_mlas_transpose<T>::value, void>::type SimpleTranspo
48
49
template <typename T>
49
50
typename std::enable_if<has_mlas_transpose<T>::value, void >::type SimpleTransposeSingleAxisOutwards (
50
51
const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop,
51
- int64_t writes_per_writer_per_loop) {
52
+ int64_t writes_per_writer_per_loop, concurrency::ThreadPool* tp = nullptr ) {
52
53
for (int64_t l = 0 ; l < num_loops; ++l) {
53
54
MlasTranspose (input_data, output_data, static_cast <size_t >(writes_per_writer_per_loop),
54
- static_cast <size_t >(num_writers));
55
+ static_cast <size_t >(num_writers), tp );
55
56
input_data += writes_per_loop;
56
57
output_data += writes_per_loop;
57
58
}
@@ -82,25 +83,25 @@ void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Ten
82
83
switch (bytes_per_write) {
83
84
case (sizeof (uint8_t )): {
84
85
SimpleTransposeSingleAxisOutwards (input_data, output_data, num_loops, num_writers, writes_per_loop,
85
- writes_per_writer_per_loop);
86
+ writes_per_writer_per_loop, tp );
86
87
break ;
87
88
}
88
89
case (sizeof (uint16_t )): {
89
90
SimpleTransposeSingleAxisOutwards (reinterpret_cast <const uint16_t *>(input_data),
90
91
reinterpret_cast <uint16_t *>(output_data), num_loops, num_writers,
91
- writes_per_loop, writes_per_writer_per_loop);
92
+ writes_per_loop, writes_per_writer_per_loop, tp );
92
93
break ;
93
94
}
94
95
case (sizeof (uint32_t )): {
95
96
SimpleTransposeSingleAxisOutwards (reinterpret_cast <const uint32_t *>(input_data),
96
97
reinterpret_cast <uint32_t *>(output_data), num_loops, num_writers,
97
- writes_per_loop, writes_per_writer_per_loop);
98
+ writes_per_loop, writes_per_writer_per_loop, tp );
98
99
break ;
99
100
}
100
101
case (sizeof (uint64_t )): {
101
102
SimpleTransposeSingleAxisOutwards (reinterpret_cast <const uint64_t *>(input_data),
102
103
reinterpret_cast <uint64_t *>(output_data), num_loops, num_writers,
103
- writes_per_loop, writes_per_writer_per_loop);
104
+ writes_per_loop, writes_per_writer_per_loop, tp );
104
105
break ;
105
106
}
106
107
default : {
@@ -125,7 +126,8 @@ void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Ten
125
126
template <typename T>
126
127
typename std::enable_if<!has_mlas_transpose<T>::value, void >::type SimpleTransposeSingleAxisInwards (
127
128
const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop,
128
- int64_t reads_per_reader_per_loop) {
129
+ int64_t reads_per_reader_per_loop, concurrency::ThreadPool* tp = nullptr ) {
130
+ ORT_UNUSED_PARAMETER (tp);
129
131
T* end;
130
132
for (int64_t l = 0 ; l < num_loops; ++l) {
131
133
const T* input_for_first_reader = input_data;
@@ -150,10 +152,10 @@ typename std::enable_if<!has_mlas_transpose<T>::value, void>::type SimpleTranspo
150
152
template <typename T>
151
153
typename std::enable_if<has_mlas_transpose<T>::value, void >::type SimpleTransposeSingleAxisInwards (
152
154
const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop,
153
- int64_t reads_per_reader_per_loop) {
155
+ int64_t reads_per_reader_per_loop, concurrency::ThreadPool* tp = nullptr ) {
154
156
for (int64_t l = 0 ; l < num_loops; ++l) {
155
157
MlasTranspose (input_data, output_data, static_cast <size_t >(num_readers),
156
- static_cast <size_t >(reads_per_reader_per_loop));
158
+ static_cast <size_t >(reads_per_reader_per_loop), tp );
157
159
input_data += reads_per_loop;
158
160
output_data += reads_per_loop;
159
161
}
@@ -162,7 +164,8 @@ typename std::enable_if<has_mlas_transpose<T>::value, void>::type SimpleTranspos
162
164
// moving a single axis inwards where the read/write size is a power of 2 and between 8 and 64 bits.
163
165
// `input_shape_override` overrides the shape of `input` for compute purposes.
164
166
void TransposeSingleAxisInwards (gsl::span<const size_t > permutations, const Tensor& input, Tensor& output,
165
- size_t from, size_t to, const TensorShape* input_shape_override = nullptr ) {
167
+ size_t from, size_t to, const TensorShape* input_shape_override = nullptr ,
168
+ concurrency::ThreadPool* tp = nullptr ) {
166
169
ORT_UNUSED_PARAMETER (permutations);
167
170
168
171
const auto & input_shape = input_shape_override ? *input_shape_override : input.Shape ();
@@ -184,25 +187,25 @@ void TransposeSingleAxisInwards(gsl::span<const size_t> permutations, const Tens
184
187
switch (bytes_per_read) {
185
188
case (sizeof (uint8_t )): {
186
189
SimpleTransposeSingleAxisInwards (input_data, output_data, num_loops, num_readers, reads_per_loop,
187
- reads_per_reader_per_loop);
190
+ reads_per_reader_per_loop, tp );
188
191
break ;
189
192
}
190
193
case (sizeof (uint16_t )): {
191
194
SimpleTransposeSingleAxisInwards (reinterpret_cast <const uint16_t *>(input_data),
192
195
reinterpret_cast <uint16_t *>(output_data), num_loops, num_readers, reads_per_loop,
193
- reads_per_reader_per_loop);
196
+ reads_per_reader_per_loop, tp );
194
197
break ;
195
198
}
196
199
case (sizeof (uint32_t )): {
197
200
SimpleTransposeSingleAxisInwards (reinterpret_cast <const uint32_t *>(input_data),
198
201
reinterpret_cast <uint32_t *>(output_data), num_loops, num_readers, reads_per_loop,
199
- reads_per_reader_per_loop);
202
+ reads_per_reader_per_loop, tp );
200
203
break ;
201
204
}
202
205
case (sizeof (uint64_t )): {
203
206
SimpleTransposeSingleAxisInwards (reinterpret_cast <const uint64_t *>(input_data),
204
207
reinterpret_cast <uint64_t *>(output_data), num_loops, num_readers, reads_per_loop,
205
- reads_per_reader_per_loop);
208
+ reads_per_reader_per_loop, tp );
206
209
break ;
207
210
}
208
211
default : {
@@ -236,7 +239,7 @@ void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& inp
236
239
if (from > to) {
237
240
TransposeSingleAxisOutwards (permutations, input, output, from, to, input_shape_override, tp);
238
241
} else {
239
- TransposeSingleAxisInwards (permutations, input, output, from, to, input_shape_override);
242
+ TransposeSingleAxisInwards (permutations, input, output, from, to, input_shape_override, tp );
240
243
}
241
244
}
242
245
0 commit comments