Skip to content

Commit b069be4

Browse files
committed
fixed bug: mutex may be unlocked before holder object destroyed.
1 parent 6a7c396 commit b069be4

File tree

2 files changed

+47
-36
lines changed

2 files changed

+47
-36
lines changed

include/promise-cpp/promise.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,14 @@ struct Mutex {
8888
* Task state in TaskList always be kPending
8989
*/
9090
struct PromiseHolder {
91+
PROMISE_API PromiseHolder();
9192
PROMISE_API ~PromiseHolder();
9293
std::list<std::weak_ptr<SharedPromise>> owners_;
9394
std::list<std::shared_ptr<Task>> pendingTasks_;
9495
TaskState state_;
9596
any value_;
9697
#if PROMISE_MULTITHREAD
97-
Mutex mutex_;
98+
std::shared_ptr<Mutex> mutex_;
9899
#endif
99100

100101
PROMISE_API void dump() const;
@@ -112,7 +113,7 @@ struct SharedPromise {
112113
std::shared_ptr<PromiseHolder> promiseHolder_;
113114
PROMISE_API void dump() const;
114115
#if PROMISE_MULTITHREAD
115-
PROMISE_API Mutex &obtainLock() const;
116+
PROMISE_API std::shared_ptr<Mutex> obtainLock() const;
116117
#endif
117118
};
118119

include/promise-cpp/promise_inl.hpp

+44-34
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,15 @@ static inline void join(const std::shared_ptr<PromiseHolder> &left, const std::s
125125
//Unlock and then lock
126126
#if PROMISE_MULTITHREAD
127127
struct unlock_guard_t {
128-
inline unlock_guard_t(Mutex *mutex)
128+
inline unlock_guard_t(std::shared_ptr<Mutex> mutex)
129129
: mutex_(mutex)
130130
, lock_count_(mutex->lock_count()) {
131131
mutex_->unlock(lock_count_);
132132
}
133133
inline ~unlock_guard_t() {
134134
mutex_->lock(lock_count_);
135135
}
136-
Mutex *mutex_;
136+
std::shared_ptr<Mutex> mutex_;
137137
size_t lock_count_;
138138
};
139139
#endif
@@ -147,8 +147,8 @@ static inline void call(std::shared_ptr<Task> task) {
147147
// lock for 1st stage
148148
{
149149
#if PROMISE_MULTITHREAD
150-
Mutex &mutex = promiseHolder->mutex_;
151-
std::lock_guard<Mutex> lock(mutex);
150+
std::shared_ptr<Mutex> mutex = promiseHolder->mutex_;
151+
std::lock_guard<Mutex> lock(*mutex);
152152
#endif
153153

154154
if (task->state_ != TaskState::kPending) return;
@@ -170,14 +170,14 @@ static inline void call(std::shared_ptr<Task> task) {
170170
}
171171
else {
172172
#if PROMISE_MULTITHREAD
173-
Mutex *mutex0 = nullptr;
173+
std::shared_ptr<Mutex> mutex0 = nullptr;
174174
auto call = [&]() -> any {
175-
unlock_guard_t lock(&mutex);
175+
unlock_guard_t lock(mutex);
176176
const any &value = task->onResolved_.call(promiseHolder->value_);
177177
// Make sure the returned promised is locked before than "mutex"
178178
if (value.type() == typeid(Promise)) {
179179
Promise &promise = value.cast<Promise &>();
180-
mutex0 = &promise.sharedPromise_->obtainLock();
180+
mutex0 = promise.sharedPromise_->obtainLock();
181181
}
182182
return value;
183183
};
@@ -220,14 +220,14 @@ static inline void call(std::shared_ptr<Task> task) {
220220
else {
221221
try {
222222
#if PROMISE_MULTITHREAD
223-
Mutex *mutex0 = nullptr;
223+
std::shared_ptr<Mutex> mutex0 = nullptr;
224224
auto call = [&]() -> any {
225-
unlock_guard_t lock(&mutex);
225+
unlock_guard_t lock(mutex);
226226
const any &value = task->onRejected_.call(promiseHolder->value_);
227227
// Make sure the returned promised is locked before than "mutex"
228228
if (value.type() == typeid(Promise)) {
229229
Promise &promise = value.cast<Promise &>();
230-
mutex0 = &promise.sharedPromise_->obtainLock();
230+
mutex0 = promise.sharedPromise_->obtainLock();
231231
}
232232
return value;
233233
};
@@ -279,8 +279,8 @@ static inline void call(std::shared_ptr<Task> task) {
279279
{
280280
// get next task
281281
#if PROMISE_MULTITHREAD
282-
Mutex &mutex = promiseHolder->mutex_;
283-
std::lock_guard<Mutex> lock(mutex);
282+
std::shared_ptr<Mutex> mutex = promiseHolder->mutex_;
283+
std::lock_guard<Mutex> lock(*mutex);
284284
#endif
285285
std::list<std::shared_ptr<Task>> &pendingTasks2 = promiseHolder->pendingTasks_;
286286
if (pendingTasks2.size() == 0) {
@@ -295,8 +295,8 @@ static inline void call(std::shared_ptr<Task> task) {
295295
Defer::Defer(const std::shared_ptr<Task> &task) {
296296
std::shared_ptr<SharedPromise> sharedPromise(new SharedPromise{ task->promiseHolder_.lock() });
297297
#if PROMISE_MULTITHREAD
298-
Mutex &mutex = sharedPromise->obtainLock();
299-
std::lock_guard<Mutex> lock(mutex, std::adopt_lock_t());
298+
std::shared_ptr<Mutex> mutex = sharedPromise->obtainLock();
299+
std::lock_guard<Mutex> lock(*mutex, std::adopt_lock_t());
300300
#endif
301301

302302
task_ = task;
@@ -306,8 +306,8 @@ Defer::Defer(const std::shared_ptr<Task> &task) {
306306

307307
void Defer::resolve(const any &arg) const {
308308
#if PROMISE_MULTITHREAD
309-
Mutex &mutex = this->sharedPromise_->obtainLock();
310-
std::lock_guard<Mutex> lock(mutex, std::adopt_lock_t());
309+
std::shared_ptr<Mutex> mutex = this->sharedPromise_->obtainLock();
310+
std::lock_guard<Mutex> lock(*mutex, std::adopt_lock_t());
311311
#endif
312312

313313
if (task_->state_ != TaskState::kPending) return;
@@ -319,8 +319,8 @@ void Defer::resolve(const any &arg) const {
319319

320320
void Defer::reject(const any &arg) const {
321321
#if PROMISE_MULTITHREAD
322-
Mutex &mutex = this->sharedPromise_->obtainLock();
323-
std::lock_guard<Mutex> lock(mutex, std::adopt_lock_t());
322+
std::shared_ptr<Mutex> mutex = this->sharedPromise_->obtainLock();
323+
std::lock_guard<Mutex> lock(*mutex, std::adopt_lock_t());
324324
#endif
325325

326326
if (task_->state_ != TaskState::kPending) return;
@@ -385,6 +385,17 @@ void Mutex::unlock(size_t lock_count) {
385385
}
386386
#endif
387387

388+
PromiseHolder::PromiseHolder()
389+
: owners_()
390+
, pendingTasks_()
391+
, state_(TaskState::kPending)
392+
, value_()
393+
#if PROMISE_MULTITHREAD
394+
, mutex_(std::make_shared<Mutex>())
395+
#endif
396+
{
397+
}
398+
388399
PromiseHolder::~PromiseHolder() {
389400
if (this->state_ == TaskState::kRejected) {
390401
PromiseHolder::onUncaughtException(this->value_);
@@ -418,19 +429,18 @@ void PromiseHolder::handleUncaughtException(const any &onUncaughtException) {
418429
}
419430

420431
#if PROMISE_MULTITHREAD
421-
Mutex &SharedPromise::obtainLock() const {
422-
Mutex *mutex = nullptr;
432+
std::shared_ptr<Mutex> SharedPromise::obtainLock() const {
423433
while (true) {
424-
mutex = &this->promiseHolder_->mutex_;
434+
std::shared_ptr<Mutex> mutex = this->promiseHolder_->mutex_;
425435
mutex->lock();
426436

427437
// pointer to mutex may be changed after locked,
428438
// in this case we should try to lock and test again
429-
if (mutex == &this->promiseHolder_->mutex_)
430-
break;
439+
if (mutex == this->promiseHolder_->mutex_)
440+
return mutex;
431441
mutex->unlock();
432442
}
433-
return *mutex;
443+
return nullptr;
434444
}
435445
#endif
436446

@@ -460,10 +470,10 @@ Promise &Promise::then(const any &deferOrPromiseOrOnResolved) {
460470
Promise &promise = deferOrPromiseOrOnResolved.cast<Promise &>();
461471

462472
#if PROMISE_MULTITHREAD
463-
Mutex &mutex0 = this->sharedPromise_->obtainLock();
464-
std::lock_guard<Mutex> lock0(mutex0, std::adopt_lock_t());
465-
Mutex &mutex1 = promise.sharedPromise_->obtainLock();
466-
std::lock_guard<Mutex> lock1(mutex1, std::adopt_lock_t());
473+
std::shared_ptr<Mutex> mutex0 = this->sharedPromise_->obtainLock();
474+
std::lock_guard<Mutex> lock0(*mutex0, std::adopt_lock_t());
475+
std::shared_ptr<Mutex> mutex1 = promise.sharedPromise_->obtainLock();
476+
std::lock_guard<Mutex> lock1(*mutex1, std::adopt_lock_t());
467477
#endif
468478

469479
if (promise.sharedPromise_ && promise.sharedPromise_->promiseHolder_) {
@@ -482,8 +492,8 @@ Promise &Promise::then(const any &deferOrPromiseOrOnResolved) {
482492

483493
Promise &Promise::then(const any &onResolved, const any &onRejected) {
484494
#if PROMISE_MULTITHREAD
485-
Mutex &mutex = this->sharedPromise_->obtainLock();
486-
std::lock_guard<Mutex> lock(mutex, std::adopt_lock_t());
495+
std::shared_ptr<Mutex> mutex = this->sharedPromise_->obtainLock();
496+
std::lock_guard<Mutex> lock(*mutex, std::adopt_lock_t());
487497
#endif
488498

489499
std::shared_ptr<Task> task = std::make_shared<Task>(Task {
@@ -528,8 +538,8 @@ Promise &Promise::finally(const any &onFinally) {
528538

529539
void Promise::resolve(const any &arg) const {
530540
#if PROMISE_MULTITHREAD
531-
Mutex &mutex = this->sharedPromise_->obtainLock();
532-
std::lock_guard<Mutex> lock(mutex, std::adopt_lock_t());
541+
std::shared_ptr<Mutex> mutex = this->sharedPromise_->obtainLock();
542+
std::lock_guard<Mutex> lock(*mutex, std::adopt_lock_t());
533543
#endif
534544

535545
std::list<std::shared_ptr<Task>> &pendingTasks_ = this->sharedPromise_->promiseHolder_->pendingTasks_;
@@ -542,8 +552,8 @@ void Promise::resolve(const any &arg) const {
542552

543553
void Promise::reject(const any &arg) const {
544554
#if PROMISE_MULTITHREAD
545-
Mutex &mutex = this->sharedPromise_->obtainLock();
546-
std::lock_guard<Mutex> lock(mutex, std::adopt_lock_t());
555+
std::shared_ptr<Mutex> mutex = this->sharedPromise_->obtainLock();
556+
std::lock_guard<Mutex> lock(*mutex, std::adopt_lock_t());
547557
#endif
548558

549559
std::list<std::shared_ptr<Task>> &pendingTasks_ = this->sharedPromise_->promiseHolder_->pendingTasks_;

0 commit comments

Comments
 (0)