1
+ #include < iostream>
2
+ #include < mutex>
3
+ #include < condition_variable>
4
+ #include < atomic>
5
+ #include < queue>
6
+ #include < functional>
7
+ #include < thread>
8
+ #include < vector>
9
+ #include < future>
10
+ #include < memory>
11
+ #include < syncstream>
12
+ using namespace std ::chrono_literals;
13
+
14
+ inline std::size_t default_thread_pool_size () noexcept {
15
+ std::size_t num_threads = std::thread::hardware_concurrency ();
16
+ num_threads = num_threads == 0 ? 2 : num_threads; // 防止无法检测当前硬件,让我们线程池至少有 2 个线程
17
+ return num_threads;
18
+ }
19
+
20
+ class ThreadPool {
21
+ public:
22
+ using Task = std::packaged_task<void ()>;
23
+
24
+ ThreadPool (const ThreadPool&) = delete ;
25
+ ThreadPool& operator =(const ThreadPool&) = delete ;
26
+
27
+ ThreadPool (std::size_t num_thread = default_thread_pool_size()) :
28
+ stop_{ false }, num_thread_{ num_thread }
29
+ {
30
+ start ();
31
+ }
32
+ ~ThreadPool (){
33
+ stop ();
34
+ }
35
+
36
+ void stop (){
37
+ stop_ = true ;
38
+ cv_.notify_all ();
39
+ for (auto & thread : pool_){
40
+ if (thread.joinable ())
41
+ thread.join ();
42
+ }
43
+ pool_.clear ();
44
+ }
45
+
46
+ template <typename F, typename ...Args>
47
+ std::future<std::invoke_result_t <std::decay_t <F>, std::decay_t <Args>...>> submit (F&& f, Args&&...args){
48
+ using RetType = std::invoke_result_t <std::decay_t <F>, std::decay_t <Args>...>;
49
+ if (stop_){
50
+ throw std::runtime_error (" ThreadPool is stopped" );
51
+ }
52
+ auto task = std::make_shared<std::packaged_task<RetType ()>>(std::bind (std::forward<F>(f), std::forward<Args>(args)...));
53
+
54
+ std::future<RetType> ret = task->get_future ();
55
+
56
+ {
57
+ std::lock_guard<std::mutex> lc{ mutex_ };
58
+ tasks_.emplace ([task] {(*task)(); });
59
+ }
60
+ cv_.notify_one ();
61
+
62
+ return ret;
63
+ }
64
+
65
+ void start (){
66
+ for (std::size_t i = 0 ; i < num_thread_; ++i){
67
+ pool_.emplace_back ([this ]{
68
+ while (!stop_) {
69
+ Task task;
70
+ {
71
+ std::unique_lock<std::mutex> lock{ mutex_ };
72
+ cv_.wait (lock, [this ] {return stop_ || !tasks_.empty (); });
73
+ if (tasks_.empty ()) return ;
74
+ task = std::move (tasks_.front ());
75
+ tasks_.pop ();
76
+ }
77
+ task ();
78
+ }
79
+ });
80
+ }
81
+ }
82
+
83
+ private:
84
+ std::mutex mutex_;
85
+ std::condition_variable cv_;
86
+ std::atomic<bool > stop_;
87
+ std::atomic<std::size_t > num_thread_;
88
+ std::queue<Task> tasks_;
89
+ std::vector<std::thread> pool_;
90
+ };
91
+
92
+ int print_task (int n) {
93
+ std::osyncstream{ std::cout } << " Task " << n << " is running on thr: " <<
94
+ std::this_thread::get_id () << ' \n ' ;
95
+ return n;
96
+ }
97
+ int print_task2 (int n) {
98
+ std::osyncstream{ std::cout } << " 🐢🐢🐢 " << n << " 🐉🐉🐉" << std::endl;
99
+ return n;
100
+ }
101
+
102
+ struct X {
103
+ void f (const int & n) const {
104
+ std::osyncstream{ std::cout } << &n << ' \n ' ;
105
+ }
106
+ };
107
+
108
+ int main () {
109
+ ThreadPool pool{ 4 }; // 创建一个有 4 个线程的线程池
110
+
111
+ X x;
112
+ int n = 6 ;
113
+ std::cout << &n << ' \n ' ;
114
+ auto t = pool.submit (&X::f, &x, n); // 默认复制,地址不同
115
+ auto t2 = pool.submit (&X::f, &x, std::ref (n));
116
+ t.wait ();
117
+ t2.wait ();
118
+ } // 析构自动 stop()自动 stop()
0 commit comments