Skip to content

Commit 287e310

Browse files
committed
MlasTranspose multi-threads support.
1 parent f1d790c commit 287e310

File tree

10 files changed

+285
-168
lines changed

10 files changed

+285
-168
lines changed

onnxruntime/core/framework/transpose_helper.cc

+19-16
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ struct has_mlas_transpose<uint32_t> : std::true_type {};
2222
template <typename T>
2323
typename std::enable_if<!has_mlas_transpose<T>::value, void>::type SimpleTransposeSingleAxisOutwards(
2424
const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop,
25-
int64_t writes_per_writer_per_loop) {
25+
int64_t writes_per_writer_per_loop, concurrency::ThreadPool* tp = nullptr) {
26+
ORT_UNUSED_PARAMETER(tp);
2627
const T* end;
2728
for (int64_t l = 0; l < num_loops; ++l) {
2829
T* output_for_first_writer = output_data;
@@ -48,10 +49,10 @@ typename std::enable_if<!has_mlas_transpose<T>::value, void>::type SimpleTranspo
4849
template <typename T>
4950
typename std::enable_if<has_mlas_transpose<T>::value, void>::type SimpleTransposeSingleAxisOutwards(
5051
const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop,
51-
int64_t writes_per_writer_per_loop) {
52+
int64_t writes_per_writer_per_loop, concurrency::ThreadPool* tp = nullptr) {
5253
for (int64_t l = 0; l < num_loops; ++l) {
5354
MlasTranspose(input_data, output_data, static_cast<size_t>(writes_per_writer_per_loop),
54-
static_cast<size_t>(num_writers));
55+
static_cast<size_t>(num_writers), tp);
5556
input_data += writes_per_loop;
5657
output_data += writes_per_loop;
5758
}
@@ -82,25 +83,25 @@ void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Ten
8283
switch (bytes_per_write) {
8384
case (sizeof(uint8_t)): {
8485
SimpleTransposeSingleAxisOutwards(input_data, output_data, num_loops, num_writers, writes_per_loop,
85-
writes_per_writer_per_loop);
86+
writes_per_writer_per_loop, tp);
8687
break;
8788
}
8889
case (sizeof(uint16_t)): {
8990
SimpleTransposeSingleAxisOutwards(reinterpret_cast<const uint16_t*>(input_data),
9091
reinterpret_cast<uint16_t*>(output_data), num_loops, num_writers,
91-
writes_per_loop, writes_per_writer_per_loop);
92+
writes_per_loop, writes_per_writer_per_loop, tp);
9293
break;
9394
}
9495
case (sizeof(uint32_t)): {
9596
SimpleTransposeSingleAxisOutwards(reinterpret_cast<const uint32_t*>(input_data),
9697
reinterpret_cast<uint32_t*>(output_data), num_loops, num_writers,
97-
writes_per_loop, writes_per_writer_per_loop);
98+
writes_per_loop, writes_per_writer_per_loop, tp);
9899
break;
99100
}
100101
case (sizeof(uint64_t)): {
101102
SimpleTransposeSingleAxisOutwards(reinterpret_cast<const uint64_t*>(input_data),
102103
reinterpret_cast<uint64_t*>(output_data), num_loops, num_writers,
103-
writes_per_loop, writes_per_writer_per_loop);
104+
writes_per_loop, writes_per_writer_per_loop, tp);
104105
break;
105106
}
106107
default: {
@@ -125,7 +126,8 @@ void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Ten
125126
template <typename T>
126127
typename std::enable_if<!has_mlas_transpose<T>::value, void>::type SimpleTransposeSingleAxisInwards(
127128
const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop,
128-
int64_t reads_per_reader_per_loop) {
129+
int64_t reads_per_reader_per_loop, concurrency::ThreadPool* tp = nullptr) {
130+
ORT_UNUSED_PARAMETER(tp);
129131
T* end;
130132
for (int64_t l = 0; l < num_loops; ++l) {
131133
const T* input_for_first_reader = input_data;
@@ -150,10 +152,10 @@ typename std::enable_if<!has_mlas_transpose<T>::value, void>::type SimpleTranspo
150152
template <typename T>
151153
typename std::enable_if<has_mlas_transpose<T>::value, void>::type SimpleTransposeSingleAxisInwards(
152154
const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop,
153-
int64_t reads_per_reader_per_loop) {
155+
int64_t reads_per_reader_per_loop, concurrency::ThreadPool* tp = nullptr) {
154156
for (int64_t l = 0; l < num_loops; ++l) {
155157
MlasTranspose(input_data, output_data, static_cast<size_t>(num_readers),
156-
static_cast<size_t>(reads_per_reader_per_loop));
158+
static_cast<size_t>(reads_per_reader_per_loop), tp);
157159
input_data += reads_per_loop;
158160
output_data += reads_per_loop;
159161
}
@@ -162,7 +164,8 @@ typename std::enable_if<has_mlas_transpose<T>::value, void>::type SimpleTranspos
162164
// moving a single axis inwards where the read/write size is a power of 2 and between 8 and 64 bits.
163165
// `input_shape_override` overrides the shape of `input` for compute purposes.
164166
void TransposeSingleAxisInwards(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output,
165-
size_t from, size_t to, const TensorShape* input_shape_override = nullptr) {
167+
size_t from, size_t to, const TensorShape* input_shape_override = nullptr,
168+
concurrency::ThreadPool* tp = nullptr) {
166169
ORT_UNUSED_PARAMETER(permutations);
167170

168171
const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape();
@@ -184,25 +187,25 @@ void TransposeSingleAxisInwards(gsl::span<const size_t> permutations, const Tens
184187
switch (bytes_per_read) {
185188
case (sizeof(uint8_t)): {
186189
SimpleTransposeSingleAxisInwards(input_data, output_data, num_loops, num_readers, reads_per_loop,
187-
reads_per_reader_per_loop);
190+
reads_per_reader_per_loop, tp);
188191
break;
189192
}
190193
case (sizeof(uint16_t)): {
191194
SimpleTransposeSingleAxisInwards(reinterpret_cast<const uint16_t*>(input_data),
192195
reinterpret_cast<uint16_t*>(output_data), num_loops, num_readers, reads_per_loop,
193-
reads_per_reader_per_loop);
196+
reads_per_reader_per_loop, tp);
194197
break;
195198
}
196199
case (sizeof(uint32_t)): {
197200
SimpleTransposeSingleAxisInwards(reinterpret_cast<const uint32_t*>(input_data),
198201
reinterpret_cast<uint32_t*>(output_data), num_loops, num_readers, reads_per_loop,
199-
reads_per_reader_per_loop);
202+
reads_per_reader_per_loop, tp);
200203
break;
201204
}
202205
case (sizeof(uint64_t)): {
203206
SimpleTransposeSingleAxisInwards(reinterpret_cast<const uint64_t*>(input_data),
204207
reinterpret_cast<uint64_t*>(output_data), num_loops, num_readers, reads_per_loop,
205-
reads_per_reader_per_loop);
208+
reads_per_reader_per_loop, tp);
206209
break;
207210
}
208211
default: {
@@ -236,7 +239,7 @@ void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& inp
236239
if (from > to) {
237240
TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override, tp);
238241
} else {
239-
TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override);
242+
TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override, tp);
240243
}
241244
}
242245

onnxruntime/core/mlas/inc/mlas.h

+10-42
Original file line numberDiff line numberDiff line change
@@ -1053,49 +1053,15 @@ MlasComputeTanh(
10531053
// Transpose routines.
10541054
//
10551055

1056+
template<typename DataType>
10561057
void
10571058
MLASCALL
10581059
MlasTranspose(
1059-
const uint8_t* Input,
1060-
uint8_t* Output,
1061-
size_t M,
1062-
size_t N
1063-
);
1064-
1065-
void
1066-
MLASCALL
1067-
MlasTranspose(
1068-
const int8_t* Input,
1069-
int8_t* Output,
1070-
size_t M,
1071-
size_t N
1072-
);
1073-
1074-
void
1075-
MLASCALL
1076-
MlasTranspose(
1077-
const uint16_t* Input,
1078-
uint16_t* Output,
1079-
size_t M,
1080-
size_t N
1081-
);
1082-
1083-
void
1084-
MLASCALL
1085-
MlasTranspose(
1086-
const uint32_t* Input,
1087-
uint32_t* Output,
1060+
const DataType* Input,
1061+
DataType* Output,
10881062
size_t M,
1089-
size_t N
1090-
);
1091-
1092-
void
1093-
MLASCALL
1094-
MlasTranspose(
1095-
const float* Input,
1096-
float* Output,
1097-
size_t M,
1098-
size_t N
1063+
size_t N,
1064+
MLAS_THREADPOOL* ThreadPool
10991065
);
11001066

11011067
//
@@ -1937,20 +1903,22 @@ MlasConvDepthwise(
19371903
MLAS_HALF_GEMM_POSTPROCESSOR* PostProc
19381904
);
19391905

1940-
19411906
inline
19421907
void
19431908
MlasTranspose(
19441909
const MLAS_FP16* Input,
19451910
MLAS_FP16* Output,
19461911
size_t M,
1947-
size_t N
1912+
size_t N,
1913+
MLAS_THREADPOOL* ThreadPool
19481914
)
19491915
{
19501916
MlasTranspose(
19511917
reinterpret_cast<const uint16_t*>(Input),
19521918
reinterpret_cast<uint16_t*>(Output),
1953-
M, N);
1919+
M,
1920+
N,
1921+
ThreadPool);
19541922
}
19551923

19561924

0 commit comments

Comments
 (0)