Skip to content

Commit a60b90a

Browse files
committed
fix a bug in reduce_Op.cuh
1 parent 2e8ad8f commit a60b90a

File tree

6 files changed

+47
-27
lines changed

6 files changed

+47
-27
lines changed

paddle/fluid/operators/reduce_ops/reduce_all_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.

paddle/fluid/operators/reduce_ops/reduce_any_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2021 PaddlePaddle Authors. Any Rights Reserved.
1+
// Copyright (c) 2018 PaddlePaddle Authors. Any Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.

paddle/fluid/operators/reduce_ops/reduce_max_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.

paddle/fluid/operators/reduce_ops/reduce_min_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.

paddle/fluid/operators/reduce_ops/reduce_op.cuh

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,8 @@ constexpr int kMaxBlock = 512;
8686

8787
// get blockDim for reduceLastDim and reduceAny
8888
static inline int GetBlockDim(int block_dim) {
89-
return block_dim >= kMaxBlock
90-
? kMaxBlock
91-
: (1 << static_cast<int>(std::log2(block_dim)));
89+
return block_dim >= kMaxBlock ? kMaxBlock
90+
: (1 << static_cast<int>(std::log2(block_dim)));
9291
}
9392

9493
// check reduce rand is valid
@@ -177,50 +176,71 @@ struct ReduceConfig {
177176
// --SetReduceDim--> x_dim = [8,6], reduce_dim = [0], left_dim = [1]
178177
void SetReduceDim() {
179178
std::set<int> reduce_set;
180-
181179
for (auto e : reduce_dims_origin) {
182180
auto pos = e >= 0 ? e : e + x_dim.size();
183181
reduce_set.insert(pos);
184182
}
183+
185184
std::vector<int> reduce_dim_temp(reduce_set.begin(), reduce_set.end());
186185
std::sort(reduce_dim_temp.begin(), reduce_dim_temp.end());
187-
// get reduce_dim
186+
187+
// update reduce_dim and x_dim
188+
std::vector<int> x_new_dim;
189+
190+
reduce_dim.push_back(reduce_dim_temp[0]);
191+
x_new_dim.push_back(x_dim[0]);
192+
193+
int idx_reduce = 1;
194+
int num = 0;
195+
188196
if (reduce_dim_temp.size() > 1) {
189-
int num = 0; // for update axis
190-
reduce_dim.push_back(reduce_dim_temp[0]);
191-
for (int idx = 1; idx < reduce_dim_temp.size(); idx++) {
192-
// update x_dim
193-
if (reduce_dim_temp[idx] - reduce_dim_temp[idx - 1] == 1) {
194-
x_dim[reduce_dim_temp[idx - 1]] *= x_dim[reduce_dim_temp[idx]];
195-
x_dim.erase(x_dim.begin() + reduce_dim_temp[idx]);
196-
num++;
197+
for (int i = 1; i < x_dim.size(); i++) {
198+
if (idx_reduce < reduce_dim_temp.size() &&
199+
i == reduce_dim_temp[idx_reduce]) {
200+
int result =
201+
reduce_dim_temp[idx_reduce] - reduce_dim[reduce_dim.size() - 1];
202+
bool is_equal = (result - num == 1);
203+
if (is_equal) {
204+
x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
205+
num++;
206+
} else {
207+
reduce_dim.push_back(reduce_dim_temp[idx_reduce] - num);
208+
x_new_dim.push_back(x_dim[i]);
209+
}
210+
idx_reduce++;
197211
} else {
198-
reduce_dim.push_back(reduce_dim_temp[idx] - num);
212+
x_new_dim.push_back(x_dim[i]);
199213
}
200214
}
201215
} else {
202-
reduce_dim = reduce_dim_temp;
216+
x_new_dim = x_dim;
203217
}
204218

205-
// update new_x_dim and new_reduce_dim
206-
std::vector<int> new_x_dim, new_reduce_dim_temp;
219+
// update x_dim
220+
x_dim = x_new_dim;
221+
std::vector<int>().swap(x_new_dim);
222+
223+
std::vector<int> reduce_dim_new;
207224
int is_reduced = 0;
208225
for (auto e : reduce_dim) {
226+
auto pos = e >= 0 ? e : e + x_dim.size();
209227
is_reduced |= 1 << e;
210228
}
211229

230+
std::vector<int>().swap(reduce_dim);
231+
212232
for (int i = 0; i < x_dim.size(); i++) {
213233
if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
214-
new_x_dim.push_back(x_dim[i]);
234+
x_new_dim.push_back(x_dim[i]);
215235
if ((is_reduced >> i) & 1)
216-
new_reduce_dim_temp.push_back(new_x_dim.size() - 1);
236+
reduce_dim_new.push_back(x_new_dim.size() - 1);
217237
} else {
218-
new_x_dim[new_x_dim.size() - 1] *= x_dim[i];
238+
x_new_dim[x_new_dim.size() - 1] *= x_dim[i];
219239
}
220240
}
221241

222-
x_dim = new_x_dim;
223-
reduce_dim = new_reduce_dim_temp;
242+
x_dim = x_new_dim;
243+
reduce_dim = reduce_dim_new;
224244

225245
int x_rank = static_cast<int>(x_dim.size());
226246
std::set<int> left_set;

paddle/fluid/operators/reduce_ops/reduce_prod_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.

0 commit comments

Comments
 (0)