@@ -86,9 +86,8 @@ constexpr int kMaxBlock = 512;
8686
8787// get blockDim for reduceLastDim and reduceAny
8888static inline int GetBlockDim (int block_dim) {
89- return block_dim >= kMaxBlock
90- ? kMaxBlock
91- : (1 << static_cast <int >(std::log2 (block_dim)));
89+ return block_dim >= kMaxBlock ? kMaxBlock
90+ : (1 << static_cast <int >(std::log2 (block_dim)));
9291}
9392
9493// check reduce rank is valid
@@ -177,50 +176,71 @@ struct ReduceConfig {
177176 // --SetReduceDim--> x_dim = [8,6], reduce_dim = [0], left_dim = [1]
178177 void SetReduceDim () {
179178 std::set<int > reduce_set;
180-
181179 for (auto e : reduce_dims_origin) {
182180 auto pos = e >= 0 ? e : e + x_dim.size ();
183181 reduce_set.insert (pos);
184182 }
183+
185184 std::vector<int > reduce_dim_temp (reduce_set.begin (), reduce_set.end ());
186185 std::sort (reduce_dim_temp.begin (), reduce_dim_temp.end ());
187- // get reduce_dim
186+
187+ // update reduce_dim and x_dim
188+ std::vector<int > x_new_dim;
189+
190+ reduce_dim.push_back (reduce_dim_temp[0 ]);
191+ x_new_dim.push_back (x_dim[0 ]);
192+
193+ int idx_reduce = 1 ;
194+ int num = 0 ;
195+
188196 if (reduce_dim_temp.size () > 1 ) {
189- int num = 0 ; // for update axis
190- reduce_dim.push_back (reduce_dim_temp[0 ]);
191- for (int idx = 1 ; idx < reduce_dim_temp.size (); idx++) {
192- // update x_dim
193- if (reduce_dim_temp[idx] - reduce_dim_temp[idx - 1 ] == 1 ) {
194- x_dim[reduce_dim_temp[idx - 1 ]] *= x_dim[reduce_dim_temp[idx]];
195- x_dim.erase (x_dim.begin () + reduce_dim_temp[idx]);
196- num++;
197+ for (int i = 1 ; i < x_dim.size (); i++) {
198+ if (idx_reduce < reduce_dim_temp.size () &&
199+ i == reduce_dim_temp[idx_reduce]) {
200+ int result =
201+ reduce_dim_temp[idx_reduce] - reduce_dim[reduce_dim.size () - 1 ];
202+ bool is_equal = (result - num == 1 );
203+ if (is_equal) {
204+ x_new_dim[x_new_dim.size () - 1 ] *= x_dim[i];
205+ num++;
206+ } else {
207+ reduce_dim.push_back (reduce_dim_temp[idx_reduce] - num);
208+ x_new_dim.push_back (x_dim[i]);
209+ }
210+ idx_reduce++;
197211 } else {
198- reduce_dim .push_back (reduce_dim_temp[idx] - num );
212+ x_new_dim .push_back (x_dim[i] );
199213 }
200214 }
201215 } else {
202- reduce_dim = reduce_dim_temp ;
216+ x_new_dim = x_dim ;
203217 }
204218
205- // update new_x_dim and new_reduce_dim
206- std::vector<int > new_x_dim, new_reduce_dim_temp;
219+ // update x_dim
220+ x_dim = x_new_dim;
221+ std::vector<int >().swap (x_new_dim);
222+
223+ std::vector<int > reduce_dim_new;
207224 int is_reduced = 0 ;
208225 for (auto e : reduce_dim) {
226+ auto pos = e >= 0 ? e : e + x_dim.size ();
209227 is_reduced |= 1 << e;
210228 }
211229
230+ std::vector<int >().swap (reduce_dim);
231+
212232 for (int i = 0 ; i < x_dim.size (); i++) {
213233 if ((i == 0 ) || (((is_reduced >> i) ^ (is_reduced >> (i - 1 ))) & 1 )) {
214- new_x_dim .push_back (x_dim[i]);
234+ x_new_dim .push_back (x_dim[i]);
215235 if ((is_reduced >> i) & 1 )
216- new_reduce_dim_temp .push_back (new_x_dim .size () - 1 );
236+ reduce_dim_new .push_back (x_new_dim .size () - 1 );
217237 } else {
218- new_x_dim[new_x_dim .size () - 1 ] *= x_dim[i];
238+ x_new_dim[x_new_dim .size () - 1 ] *= x_dim[i];
219239 }
220240 }
221241
222- x_dim = new_x_dim ;
223- reduce_dim = new_reduce_dim_temp ;
242+ x_dim = x_new_dim ;
243+ reduce_dim = reduce_dim_new ;
224244
225245 int x_rank = static_cast <int >(x_dim.size ());
226246 std::set<int > left_set;
0 commit comments