Skip to content

Commit eb52f85

Browse files
committed
src, tests: align synchronous Eigen threadpool implementation
Aligns the implementation with the one from the TensorFlow legacy runtime and adjusts the internals to behave as a truly synchronous threadpool should. This change introduces a dependency on abseil-cpp in the threadpool implementation.
1 parent 5f36333 commit eb52f85

File tree

10 files changed

+93
-32
lines changed

10 files changed

+93
-32
lines changed

src/common/dnnl_thread.hpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@
4444
// due to linker optimizations. The newer the compiler and C++ standard, the
4545
// smaller the resulting binary size.
4646

47-
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
48-
#include "counting_barrier.hpp"
49-
#endif
50-
5147
#if defined(DNNL_ENABLE_ITT_TASKS)
5248
#include "common/ittnotify.hpp"
5349
#endif
@@ -282,10 +278,14 @@ static inline void parallel(int nthr, const std::function<void(int, int)> &f) {
282278
auto task_primitive_kind = itt::primitive_task_get_current_kind();
283279
bool itt_enable = itt::get_itt(itt::__itt_task_level_high);
284280
#endif
281+
#if DNNL_CPU_THREADING_RUNTIME != DNNL_RUNTIME_THREADPOOL
282+
// Tasks must always be submitted to a threadpool; it will handle them
283+
// properly.
285284
if (nthr == 1) {
286285
f(0, 1);
287286
return;
288287
}
288+
#endif
289289
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
290290
#pragma omp parallel num_threads(nthr)
291291
{
@@ -326,10 +326,6 @@ static inline void parallel(int nthr, const std::function<void(int, int)> &f) {
326326
}
327327
threadpool_utils::activate_threadpool(tp);
328328
} else {
329-
bool async = tp->get_flags()
330-
& dnnl::threadpool_interop::threadpool_iface::ASYNCHRONOUS;
331-
counting_barrier_t b;
332-
if (async) b.init(nthr);
333329
tp->parallel_for(nthr, [&, tp](int ithr, int nthr) {
334330
bool is_master = threadpool_utils::get_active_threadpool() == tp;
335331
if (!is_master) {
@@ -345,9 +341,7 @@ static inline void parallel(int nthr, const std::function<void(int, int)> &f) {
345341
#endif
346342
threadpool_utils::deactivate_threadpool();
347343
}
348-
if (async) b.notify();
349344
});
350-
if (async) b.wait();
351345
}
352346
#endif
353347
#endif

tests/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,15 @@ endif()
9393

9494
set(TEST_THREAD ${CMAKE_CURRENT_SOURCE_DIR}/test_thread.cpp)
9595

96+
set(LIBTHREADPOOL)
97+
if("${_DNNL_TEST_THREADPOOL_IMPL}" STREQUAL "EIGEN")
98+
find_package(absl REQUIRED CONFIG)
99+
if(absl_FOUND)
100+
list(APPEND LIBTHREADPOOL absl::synchronization)
101+
message(STATUS "Found abseil-cpp: ${PACKAGE_PREFIX_DIR}")
102+
endif()
103+
endif()
104+
96105
# Switch on threading layer for GPU only configurations to speed up testing.
97106
# For non-dpcpp compilers OpenMP threading will be used, which is handled in
98107
# OpenMP.cmake.

tests/benchdnn/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ elseif(QNXNTO)
4747
find_library(LIBREGEX regex)
4848
find_library(LIBSOCKET socket)
4949
endif()
50-
register_exe(benchdnn "${SOURCES}" "" "${LIBRT};${LIBREGEX};${LIBSOCKET}")
50+
register_exe(benchdnn "${SOURCES}" "" "${LIBRT};${LIBREGEX};${LIBSOCKET};${LIBTHREADPOOL}")
5151

5252
file(COPY inputs DESTINATION .)
5353

tests/gtests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ function(register_gtest exe src)
180180

181181
add_executable(${exe} ${MAIN_SRC_GTEST} ${src})
182182
add_definitions_with_host_compiler(-DNOMINMAX) # to allow std::max on Windows with parentheses
183-
target_link_libraries(${exe} ${LIB_PACKAGE_NAME} dnnl_gtest ${EXTRA_SHARED_LIBS})
183+
target_link_libraries(${exe} ${LIB_PACKAGE_NAME} dnnl_gtest ${EXTRA_SHARED_LIBS};${LIBTHREADPOOL})
184184

185185
get_source_file_property(no_engine_param ${src} NO_ENGINE_PARAM)
186186

tests/gtests/api/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#===============================================================================
2-
# Copyright 2019-2022 Intel Corporation
2+
# Copyright 2019-2025 Intel Corporation
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@ if(DNNL_USE_CLANG_SANITIZER)
2525
list(REMOVE_ITEM TEST_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test_memory.cpp)
2626
endif()
2727

28-
register_exe(${TEST_EXE} "${TEST_SOURCES}" "test" "dnnl_gtest")
28+
register_exe(${TEST_EXE} "${TEST_SOURCES}" "test" "dnnl_gtest;${LIBTHREADPOOL}")
2929

3030
# Create DPC++ buffer target.
3131
if(DNNL_WITH_SYCL)

tests/gtests/graph/api/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#===============================================================================
2-
# Copyright 2021-2024 Intel Corporation
2+
# Copyright 2021-2025 Intel Corporation
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -78,6 +78,7 @@ foreach(TEST_FILE ${API_TEST_ENGINE_INDEPENDENT_SOURCES})
7878
dnnl_gtest
7979
${DNNL_LIBRARY_NAME}
8080
${EXTRA_SHARED_LIBS}
81+
${LIBTHREADPOOL}
8182
)
8283
register_graph_api_test_suite(${exe_name} ${exe_name})
8384
endforeach()
@@ -90,6 +91,7 @@ foreach(TEST_FILE ${API_TEST_ENGINE_DEPENDENT_SOURCES})
9091
dnnl_gtest
9192
${DNNL_LIBRARY_NAME}
9293
${EXTRA_SHARED_LIBS}
94+
${LIBTHREADPOOL}
9395
)
9496
register_graph_api_test_suite(${exe_name} ${exe_name})
9597
endforeach()

tests/gtests/graph/unit/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ target_link_libraries(${BINARY_NAME}
6666
${EXTRA_STATIC_LIBS}
6767
${STATIC_LIB_DEPS}
6868
${SHARED_LIB_DEPS}
69+
${LIBTHREADPOOL}
6970
)
7071

7172
get_property(test_suite_names GLOBAL PROPERTY GRAPH_UNIT_TEST_SUITES)

tests/gtests/internals/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#===============================================================================
2-
# Copyright 2020-2023 Intel Corporation
2+
# Copyright 2020-2025 Intel Corporation
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -43,11 +43,11 @@ endif()
4343
# per binary run.
4444
register_exe(${TEST_EXE}_env_vars_dnnl
4545
"${MAIN_SRC_GTEST};${CMAKE_CURRENT_SOURCE_DIR}/test_env_vars_dnnl.cpp"
46-
"test" "dnnl_gtest")
46+
"test" "dnnl_gtest;${LIBTHREADPOOL}")
4747
list(REMOVE_ITEM TEST_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test_env_vars_dnnl.cpp)
4848
register_exe(${TEST_EXE}_env_vars_onednn
4949
"${MAIN_SRC_GTEST};${CMAKE_CURRENT_SOURCE_DIR}/test_env_vars_onednn.cpp"
50-
"test" "dnnl_gtest")
50+
"test" "dnnl_gtest;${LIBTHREADPOOL}")
5151
list(REMOVE_ITEM TEST_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test_env_vars_onednn.cpp)
5252

53-
register_exe(${TEST_EXE} "${TEST_SOURCES}" "test" "dnnl_gtest")
53+
register_exe(${TEST_EXE} "${TEST_SOURCES}" "test" "dnnl_gtest;${LIBTHREADPOOL}")

tests/gtests/regression/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#===============================================================================
2-
# Copyright 2021 Intel Corporation
2+
# Copyright 2021-2025 Intel Corporation
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -19,4 +19,4 @@ set(TEST_EXE test_regression)
1919
file(GLOB TEST_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test_*.cpp)
2020
list(APPEND TEST_SOURCES ${MAIN_SRC_GTEST})
2121

22-
register_exe(${TEST_EXE} "${TEST_SOURCES}" "test" "dnnl_gtest")
22+
register_exe(${TEST_EXE} "${TEST_SOURCES}" "test" "dnnl_gtest;${LIBTHREADPOOL}")

tests/test_thread.cpp

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ void *thr_ctx_t::get_interop_obj() const {
5858
#endif
5959

6060
#include "oneapi/dnnl/dnnl_threadpool_iface.hpp"
61-
#include "src/common/counting_barrier.hpp"
6261

6362
#if !defined(DNNL_TEST_THREADPOOL_USE_TBB)
6463

@@ -101,6 +100,10 @@ inline int read_num_threads_from_env() {
101100
#include "unsupported/Eigen/CXX11/Tensor"
102101
#include "unsupported/Eigen/CXX11/ThreadPool"
103102

103+
#include "absl/synchronization/blocking_counter.h"
104+
105+
#include "common/compiler_workarounds.hpp"
106+
104107
#include <memory>
105108

106109
namespace dnnl {
@@ -110,6 +113,30 @@ class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
110113
private:
111114
std::unique_ptr<Eigen::ThreadPool> tp_;
112115

116+
static void balance211(int n, int team, int tid, int *n_start, int *n_end) {
117+
if (team <= 1 || n == 0) {
118+
*n_start = 0;
119+
*n_end = n;
120+
return;
121+
}
122+
int min_per_team = n / team;
123+
int remainder = n - min_per_team * team; // i.e., n % teams.
124+
*n_start = tid * min_per_team + std::min(tid, remainder);
125+
*n_end = *n_start + min_per_team + (tid < remainder);
126+
}
127+
128+
static void run_jobs(bool balance, int i, int n, int njobs,
129+
const std::function<void(int, int)> &fn) {
130+
if (balance) {
131+
int start, end;
132+
balance211(n, njobs, i, &start, &end);
133+
for (int j = start; j < end; j++)
134+
fn(j, n);
135+
} else {
136+
fn(i, n);
137+
}
138+
}
139+
113140
public:
114141
explicit threadpool_t(int num_threads = 0) {
115142
if (num_threads <= 0) num_threads = read_num_threads_from_env();
@@ -119,19 +146,45 @@ class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
119146
bool get_in_parallel() const override {
120147
return tp_->CurrentThreadId() != -1;
121148
}
122-
uint64_t get_flags() const override { return ASYNCHRONOUS; }
149+
uint64_t get_flags() const override { return 0; }
123150
void parallel_for(int n, const std::function<void(int, int)> &fn) override {
124-
int nthr = get_num_threads();
125-
int njobs = std::min(n, nthr);
151+
// Should never happen.
152+
if (n == 0) { return; }
126153

127-
for (int i = 0; i < njobs; i++) {
128-
tp_->Schedule([i, n, njobs, fn]() {
129-
int start, end;
130-
impl::balance211(n, njobs, i, start, end);
131-
for (int j = start; j < end; j++)
132-
fn(j, n);
133-
});
154+
// Should never happen.
155+
if (n == 1) {
156+
fn(0, 1);
157+
return;
134158
}
159+
160+
int nthr = get_num_threads();
161+
int njobs = std::min(n, nthr);
162+
bool balance = (nthr < n);
163+
164+
absl::BlockingCounter counter(njobs);
165+
std::function<void(int, int)> handle_range
166+
= [= WA_THIS_COPY_CAPTURE, &handle_range, &counter](
167+
int first, int last) {
168+
while (last - first > 1) {
169+
const auto mid = first + (last - first) / 2;
170+
// Split the range at its midpoint and hand the upper half
171+
// to the pool; this thread keeps narrowing the lower half.
172+
tp_->ScheduleWithHint(
173+
[=]() { handle_range(mid, last); }, mid,
174+
mid + 1);
175+
last = mid;
176+
}
177+
run_jobs(balance, first, n, njobs, fn);
178+
counter.DecrementCount();
179+
};
180+
181+
// Eigen avoids a thread hop by running the root of the tree on the main
182+
// thread. We have disabled this because it actually slows things down
183+
// relative to base because base cheats and uses n threads while letting
184+
// main continue doing other work
185+
tp_->ScheduleWithHint([=]() { handle_range(0, njobs); }, 0, 1);
186+
187+
counter.Wait();
135188
};
136189
};
137190

@@ -164,6 +217,8 @@ class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
164217

165218
#else
166219

220+
#include "src/common/counting_barrier.hpp"
221+
167222
#include <atomic>
168223
#include <thread>
169224
#include <vector>

0 commit comments

Comments
 (0)