Skip to content

Commit e9588a9

Browse files
committed
Implement ReentrantMutex based on spin
1 parent f24f3c7 commit e9588a9

File tree

2 files changed

+216
-0
lines changed

2 files changed

+216
-0
lines changed

sgx_trts/src/sync/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ mod lazy;
1919
mod lock_api;
2020
mod mutex;
2121
mod once;
22+
mod remutex;
2223
mod rwlock;
2324

2425
#[allow(unused_imports)]
@@ -27,4 +28,5 @@ pub(crate) use once::Once;
2728

2829
pub use lock_api::{RawMutex, RawRwLock};
2930
pub use mutex::{SpinMutex, SpinMutexGuard};
31+
pub use remutex::{SpinReentrantMutex, SpinReentrantMutexGuard};
3032
pub use rwlock::{SpinRwLock, SpinRwLockReadGuard, SpinRwLockWriteGuard};

sgx_trts/src/sync/remutex.rs

+214
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::sync::lock_api::RawMutex;
19+
use crate::tcs;
20+
use core::cell::UnsafeCell;
21+
use core::convert::From;
22+
use core::fmt;
23+
use core::hint;
24+
use core::mem;
25+
use core::ops::{Deref, DerefMut};
26+
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
27+
use sgx_types::marker::ContiguousMemory;
28+
29+
/// A spin-based reentrant (recursive) mutex.
///
/// The TCS (thread) that holds the lock may acquire it again without
/// deadlocking; the raw lock is released only when the matching number
/// of guards have been dropped.
pub struct SpinReentrantMutex<T: ?Sized> {
    lock: AtomicBool,       // true while some TCS holds the raw spin lock
    owner: AtomicUsize,     // tcs id of the current holder; reset to 0 on full release
    count: UnsafeCell<u32>, // recursion depth; only written while `lock` is held
    data: UnsafeCell<T>,    // protected value; last field because T may be unsized
}
35+
36+
// The mutex is atomics + UnsafeCells laid out inline, so it is contiguous
// whenever the protected value is.
unsafe impl<T: ContiguousMemory> ContiguousMemory for SpinReentrantMutex<T> {}

// NOTE(review): std's ReentrantLock (and spin::Mutex) require `T: Send` for
// `Sync`, because any thread that acquires the lock can reach `&mut T`
// through the guard and drop or replace the value. Bounding `Sync` on
// `T: Sync` alone looks too permissive — confirm whether `T: Send` is the
// intended requirement here.
unsafe impl<T: ?Sized + Sync> Sync for SpinReentrantMutex<T> {}
unsafe impl<T: ?Sized + Send> Send for SpinReentrantMutex<T> {}
40+
41+
impl<T> SpinReentrantMutex<T> {
42+
pub const fn new(data: T) -> Self {
43+
Self {
44+
lock: AtomicBool::new(false),
45+
owner: AtomicUsize::new(0),
46+
count: UnsafeCell::new(0),
47+
data: UnsafeCell::new(data),
48+
}
49+
}
50+
51+
#[inline]
52+
pub fn into_inner(self) -> T {
53+
let SpinReentrantMutex { data, .. } = self;
54+
data.into_inner()
55+
}
56+
}
57+
58+
impl<T: ?Sized> SpinReentrantMutex<T> {
59+
#[inline]
60+
pub fn lock(&self) -> SpinReentrantMutexGuard<'_, T> {
61+
let current_thread = tcs::current().id().as_usize();
62+
if self.owner.load(Ordering::Relaxed) == current_thread {
63+
self.increment_count()
64+
} else {
65+
self.acquire_lock();
66+
self.owner.store(current_thread, Ordering::Relaxed);
67+
unsafe {
68+
assert_eq!(*self.count.get(), 0);
69+
*self.count.get() = 1;
70+
}
71+
}
72+
73+
SpinReentrantMutexGuard { lock: self }
74+
}
75+
76+
#[inline]
77+
pub fn try_lock(&self) -> Option<SpinReentrantMutexGuard<'_, T>> {
78+
if self.try_acquire_lock() {
79+
Some(SpinReentrantMutexGuard { lock: self })
80+
} else {
81+
None
82+
}
83+
}
84+
85+
#[inline]
86+
pub fn unlock(guard: SpinReentrantMutexGuard<'_, T>) {
87+
drop(guard);
88+
}
89+
90+
#[inline]
91+
pub unsafe fn force_unlock(&self) {
92+
self.lock.store(false, Ordering::Release);
93+
}
94+
95+
#[inline]
96+
pub fn get_mut(&mut self) -> &mut T {
97+
unsafe { &mut *self.data.get() }
98+
}
99+
100+
#[inline]
101+
pub fn is_locked(&self) -> bool {
102+
self.lock.load(Ordering::Relaxed)
103+
}
104+
105+
#[inline]
106+
fn increment_count(&self) {
107+
unsafe {
108+
*self.count.get() = (*self.count.get())
109+
.checked_add(1)
110+
.expect("lock count overflow in reentrant mutex");
111+
}
112+
}
113+
114+
#[inline]
115+
fn acquire_lock(&self) {
116+
while self
117+
.lock
118+
.compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
119+
.is_err()
120+
{
121+
while self.is_locked() {
122+
hint::spin_loop();
123+
}
124+
}
125+
}
126+
127+
#[inline]
128+
fn try_acquire_lock(&self) -> bool {
129+
self.lock
130+
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
131+
.is_ok()
132+
}
133+
}
134+
135+
impl<T: ?Sized + fmt::Debug> fmt::Debug for SpinReentrantMutex<T> {
136+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
137+
match self.try_lock() {
138+
Some(guard) => write!(f, "SpinReentrantMutex {{ value: ")
139+
.and_then(|()| (*guard).fmt(f))
140+
.and_then(|()| write!(f, "}}")),
141+
None => write!(f, "SpinReentrantMutex {{ <locked> }}"),
142+
}
143+
}
144+
}
145+
146+
impl<T: ?Sized + Default> Default for SpinReentrantMutex<T> {
147+
fn default() -> SpinReentrantMutex<T> {
148+
SpinReentrantMutex::new(Default::default())
149+
}
150+
}
151+
152+
impl<T> From<T> for SpinReentrantMutex<T> {
    /// Wraps `value` in a new, unlocked mutex.
    fn from(value: T) -> SpinReentrantMutex<T> {
        Self::new(value)
    }
}
157+
158+
impl<T> RawMutex for SpinReentrantMutex<T> {
159+
#[inline]
160+
fn lock(&self) {
161+
mem::forget(SpinReentrantMutex::lock(self));
162+
}
163+
164+
#[inline]
165+
fn try_lock(&self) -> bool {
166+
SpinReentrantMutex::try_lock(self)
167+
.map(mem::forget)
168+
.is_some()
169+
}
170+
171+
#[inline]
172+
unsafe fn unlock(&self) {
173+
self.force_unlock();
174+
}
175+
}
176+
177+
/// Guard returned by [`SpinReentrantMutex::lock`]; releases one recursion
/// level of the mutex when dropped.
pub struct SpinReentrantMutexGuard<'a, T: 'a + ?Sized> {
    lock: &'a SpinReentrantMutex<T>,
}
180+
181+
// The guard must stay on the TCS that acquired the lock — ownership is
// tracked by tcs id — so it can never be sent to another thread.
impl<T: ?Sized> !Send for SpinReentrantMutexGuard<'_, T> {}
// Sharing `&guard` only hands out `&T`, which is safe when `T: Sync`.
unsafe impl<T: ?Sized + Sync> Sync for SpinReentrantMutexGuard<'_, T> {}
183+
184+
impl<T: ?Sized> Drop for SpinReentrantMutexGuard<'_, T> {
185+
fn drop(&mut self) {
186+
let remutx = self.lock;
187+
unsafe {
188+
*remutx.count.get() -= 1;
189+
if *remutx.count.get() == 0 {
190+
remutx.owner.store(0, Ordering::Relaxed);
191+
remutx.lock.store(false, Ordering::Release);
192+
}
193+
}
194+
}
195+
}
196+
197+
impl<T: ?Sized> Deref for SpinReentrantMutexGuard<'_, T> {
198+
type Target = T;
199+
fn deref(&self) -> &Self::Target {
200+
unsafe { &*self.lock.data.get() }
201+
}
202+
}
203+
204+
impl<T: ?Sized> DerefMut for SpinReentrantMutexGuard<'_, T> {
205+
fn deref_mut(&mut self) -> &mut Self::Target {
206+
unsafe { &mut *self.lock.data.get() }
207+
}
208+
}
209+
210+
impl<T: ?Sized + fmt::Debug> fmt::Debug for SpinReentrantMutexGuard<'_, T> {
211+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
212+
fmt::Debug::fmt(&**self, f)
213+
}
214+
}

0 commit comments

Comments
 (0)