@@ -58,7 +58,6 @@ void *thr_ctx_t::get_interop_obj() const {
5858#endif
5959
6060#include " oneapi/dnnl/dnnl_threadpool_iface.hpp"
61- #include " src/common/counting_barrier.hpp"
6261
6362#if !defined(DNNL_TEST_THREADPOOL_USE_TBB)
6463
@@ -101,6 +100,10 @@ inline int read_num_threads_from_env() {
101100#include " unsupported/Eigen/CXX11/Tensor"
102101#include " unsupported/Eigen/CXX11/ThreadPool"
103102
103+ #include " absl/synchronization/blocking_counter.h"
104+
105+ #include " common/compiler_workarounds.hpp"
106+
104107#include < memory>
105108
106109namespace dnnl {
@@ -110,6 +113,30 @@ class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
110113private:
111114 std::unique_ptr<Eigen::ThreadPool> tp_;
112115
116+ static void balance211 (int n, int team, int tid, int *n_start, int *n_end) {
117+ if (team <= 1 || n == 0 ) {
118+ *n_start = 0 ;
119+ *n_end = n;
120+ return ;
121+ }
122+ int min_per_team = n / team;
123+ int remainder = n - min_per_team * team; // i.e., n % teams.
124+ *n_start = tid * min_per_team + std::min (tid, remainder);
125+ *n_end = *n_start + min_per_team + (tid < remainder);
126+ }
127+
128+ static void run_jobs (bool balance, int i, int n, int njobs,
129+ const std::function<void (int , int )> &fn) {
130+ if (balance) {
131+ int start, end;
132+ balance211 (n, njobs, i, &start, &end);
133+ for (int j = start; j < end; j++)
134+ fn (j, n);
135+ } else {
136+ fn (i, n);
137+ }
138+ }
139+
113140public:
114141 explicit threadpool_t (int num_threads = 0 ) {
115142 if (num_threads <= 0 ) num_threads = read_num_threads_from_env ();
@@ -119,19 +146,45 @@ class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
119146 bool get_in_parallel () const override {
120147 return tp_->CurrentThreadId () != -1 ;
121148 }
122- uint64_t get_flags () const override { return ASYNCHRONOUS ; }
149+ uint64_t get_flags () const override { return 0 ; }
123150 void parallel_for (int n, const std::function<void (int , int )> &fn) override {
124- int nthr = get_num_threads ();
125- int njobs = std::min (n, nthr);
151+ // Should never happen.
152+ if (n == 0 ) { return ; }
126153
127- for (int i = 0 ; i < njobs; i++) {
128- tp_->Schedule ([i, n, njobs, fn]() {
129- int start, end;
130- impl::balance211 (n, njobs, i, start, end);
131- for (int j = start; j < end; j++)
132- fn (j, n);
133- });
154+ // Should never happen.
155+ if (n == 1 ) {
156+ fn (0 , 1 );
157+ return ;
134158 }
159+
160+ int nthr = get_num_threads ();
161+ int njobs = std::min (n, nthr);
162+ bool balance = (nthr < n);
163+
164+ absl::BlockingCounter counter (njobs);
165+ std::function<void (int , int )> handle_range
166+ = [= WA_THIS_COPY_CAPTURE, &handle_range, &counter](
167+ int first, int last) {
168+ while (last - first > 1 ) {
169+ const auto mid = first + (last - first) / 2 ;
170+ // Find something near the midpoint which is a
171+ // multiple of block size.
172+ tp_->ScheduleWithHint (
173+ [=]() { handle_range (mid, last); }, mid,
174+ mid + 1 );
175+ last = mid;
176+ }
177+ run_jobs (balance, first, n, njobs, fn);
178+ counter.DecrementCount ();
179+ };
180+
181+ // Eigen avoids a thread hop by running the root of the tree on the main
182+ // thread. We have disabled this because it actually slows things down
183+ // relative to base because base cheats and uses n threads while letting
184+ // main continue doing other work
185+ tp_->ScheduleWithHint ([=]() { handle_range (0 , njobs); }, 0 , 1 );
186+
187+ counter.Wait ();
135188 };
136189};
137190
@@ -164,6 +217,8 @@ class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
164217
165218#else
166219
220+ #include " src/common/counting_barrier.hpp"
221+
167222#include < atomic>
168223#include < thread>
169224#include < vector>
0 commit comments