Skip to content

Commit c494f27

Browse files
committed
src, tests: align synchronous Eigen threadpool implementation
1 parent 7b3f2af commit c494f27

File tree

4 files changed

+80
-25
lines changed

4 files changed

+80
-25
lines changed

cmake/Threadpool.cmake

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,16 @@ if("${DNNL_CPU_THREADING_RUNTIME}" STREQUAL "THREADPOOL")
3333
endif()
3434

3535
if("${_DNNL_TEST_THREADPOOL_IMPL}" STREQUAL "EIGEN")
36-
find_package(Eigen3 3.3...<5.1 REQUIRED NO_MODULE)
36+
find_package(Eigen3 5.0 REQUIRED NO_MODULE)
3737
if(Eigen3_FOUND)
3838
list(APPEND EXTRA_STATIC_LIBS Eigen3::Eigen)
39-
message(STATUS "Threadpool testing: Eigen (${PACKAGE_PREFIX_DIR})")
39+
message(STATUS "Found Eigen: ${PACKAGE_PREFIX_DIR}")
40+
endif()
41+
42+
find_package(absl REQUIRED CONFIG)
43+
if(absl_FOUND)
44+
list(APPEND EXTRA_STATIC_LIBS absl::synchronization)
45+
message(STATUS "Found abseil-cpp: ${PACKAGE_PREFIX_DIR}")
4046
endif()
4147
endif()
4248

cmake/utils.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#===============================================================================
2-
# Copyright 2018-2024 Intel Corporation
2+
# Copyright 2018-2025 Intel Corporation
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -47,7 +47,7 @@ endfunction()
4747
# arg4 -- (optional) list of extra library dependencies
4848
function(register_exe name srcs test)
4949
add_executable(${name} ${srcs})
50-
target_link_libraries(${name} ${LIB_PACKAGE_NAME} ${EXTRA_SHARED_LIBS} ${ARGV3})
50+
target_link_libraries(${name} ${LIB_PACKAGE_NAME} ${EXTRA_SHARED_LIBS} ${EXTRA_STATIC_LIBS} ${ARGV3})
5151
if("x${test}" STREQUAL "xtest")
5252
add_dnnl_test(${name} ${name})
5353
maybe_configure_windows_test(${name} TEST)

src/common/dnnl_thread.hpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@
4444
// due to linker optimizations. The newer compiler and C++ standard, the less
4545
// binary size will be achieved.
4646

47-
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
48-
#include "counting_barrier.hpp"
49-
#endif
50-
5147
#if defined(DNNL_ENABLE_ITT_TASKS)
5248
#include "common/ittnotify.hpp"
5349
#endif
@@ -282,10 +278,14 @@ static inline void parallel(int nthr, const std::function<void(int, int)> &f) {
282278
auto task_primitive_kind = itt::primitive_task_get_current_kind();
283279
bool itt_enable = itt::get_itt(itt::__itt_task_level_high);
284280
#endif
281+
#if DNNL_CPU_THREADING_RUNTIME != DNNL_RUNTIME_THREADPOOL
282+
// Tasks must always be submitted to the threadpool; it will handle them
283+
// properly.
285284
if (nthr == 1) {
286285
f(0, 1);
287286
return;
288287
}
288+
#endif
289289
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
290290
#pragma omp parallel num_threads(nthr)
291291
{
@@ -326,10 +326,6 @@ static inline void parallel(int nthr, const std::function<void(int, int)> &f) {
326326
}
327327
threadpool_utils::activate_threadpool(tp);
328328
} else {
329-
bool async = tp->get_flags()
330-
& dnnl::threadpool_interop::threadpool_iface::ASYNCHRONOUS;
331-
counting_barrier_t b;
332-
if (async) b.init(nthr);
333329
tp->parallel_for(nthr, [&, tp](int ithr, int nthr) {
334330
bool is_master = threadpool_utils::get_active_threadpool() == tp;
335331
if (!is_master) {
@@ -345,9 +341,7 @@ static inline void parallel(int nthr, const std::function<void(int, int)> &f) {
345341
#endif
346342
threadpool_utils::deactivate_threadpool();
347343
}
348-
if (async) b.notify();
349344
});
350-
if (async) b.wait();
351345
}
352346
#endif
353347
#endif

tests/test_thread.cpp

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ void *thr_ctx_t::get_interop_obj() const {
5858
#endif
5959

6060
#include "oneapi/dnnl/dnnl_threadpool_iface.hpp"
61-
#include "src/common/counting_barrier.hpp"
6261

6362
#if !defined(DNNL_TEST_THREADPOOL_USE_TBB)
6463

@@ -101,6 +100,10 @@ inline int read_num_threads_from_env() {
101100
#include "unsupported/Eigen/CXX11/Tensor"
102101
#include "unsupported/Eigen/CXX11/ThreadPool"
103102

103+
#include "absl/synchronization/blocking_counter.h"
104+
105+
#include "common/compiler_workarounds.hpp"
106+
104107
#include <memory>
105108

106109
namespace dnnl {
@@ -110,6 +113,30 @@ class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
110113
private:
111114
std::unique_ptr<Eigen::ThreadPool> tp_;
112115

116+
static void balance211(int n, int team, int tid, int *n_start, int *n_end) {
117+
if (team <= 1 || n == 0) {
118+
*n_start = 0;
119+
*n_end = n;
120+
return;
121+
}
122+
int min_per_team = n / team;
123+
int remainder = n - min_per_team * team; // i.e., n % teams.
124+
*n_start = tid * min_per_team + std::min(tid, remainder);
125+
*n_end = *n_start + min_per_team + (tid < remainder);
126+
}
127+
128+
static void run_jobs(bool balance, int i, int n, int njobs,
129+
const std::function<void(int, int)> &fn) {
130+
if (balance) {
131+
int start, end;
132+
balance211(n, njobs, i, &start, &end);
133+
for (int j = start; j < end; j++)
134+
fn(j, n);
135+
} else {
136+
fn(i, n);
137+
}
138+
}
139+
113140
public:
114141
explicit threadpool_t(int num_threads = 0) {
115142
if (num_threads <= 0) num_threads = read_num_threads_from_env();
@@ -119,19 +146,45 @@ class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
119146
bool get_in_parallel() const override {
120147
return tp_->CurrentThreadId() != -1;
121148
}
122-
uint64_t get_flags() const override { return ASYNCHRONOUS; }
149+
uint64_t get_flags() const override { return 0; }
123150
void parallel_for(int n, const std::function<void(int, int)> &fn) override {
124-
int nthr = get_num_threads();
125-
int njobs = std::min(n, nthr);
151+
// Should never happen.
152+
if (n == 0) { return; }
126153

127-
for (int i = 0; i < njobs; i++) {
128-
tp_->Schedule([i, n, njobs, fn]() {
129-
int start, end;
130-
impl::balance211(n, njobs, i, start, end);
131-
for (int j = start; j < end; j++)
132-
fn(j, n);
133-
});
154+
// Should never happen.
155+
if (n == 1) {
156+
fn(0, 1);
157+
return;
134158
}
159+
160+
int nthr = get_num_threads();
161+
int njobs = std::min(n, nthr);
162+
bool balance = (nthr < n);
163+
164+
absl::BlockingCounter counter(njobs);
165+
std::function<void(int, int)> handle_range
166+
= [= WA_THIS_COPY_CAPTURE, &handle_range, &counter](
167+
int first, int last) {
168+
while (last - first > 1) {
169+
const auto mid = first + (last - first) / 2;
170+
// Hand the upper half [mid, last) to another worker; this task
171+
// keeps narrowing [first, mid).
172+
tp_->ScheduleWithHint(
173+
[=]() { handle_range(mid, last); }, mid,
174+
mid + 1);
175+
last = mid;
176+
}
177+
run_jobs(balance, first, n, njobs, fn);
178+
counter.DecrementCount();
179+
};
180+
181+
// Eigen avoids a thread hop by running the root of the tree on the main
182+
// thread. We have disabled this because it actually slows things down
183+
// relative to base, since base cheats and uses n threads while letting
184+
// main continue doing other work.
185+
tp_->ScheduleWithHint([=]() { handle_range(0, njobs); }, 0, 1);
186+
187+
counter.Wait();
135188
};
136189
};
137190

@@ -164,6 +217,8 @@ class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
164217

165218
#else
166219

220+
#include "src/common/counting_barrier.hpp"
221+
167222
#include <atomic>
168223
#include <thread>
169224
#include <vector>

0 commit comments

Comments
 (0)