Skip to content

Commit d4851b3

Browse files
committed
cass_future: refactor future state management
The refactor involves employing an enum to represent the state of the future: `CassFutureExecution`. This enum has three variants: - `RunningWithoutCallback`: Indicates that the future is currently running without a callback, - `RunningWithCallback`: Indicates that the future is running with a callback, - `Completed`: Indicates that the future has completed, and it contains the result of the future. I hope this refactor makes the state transitions clearer and easier to understand, and, even more importantly, more bug-resistant.
1 parent 87c064e commit d4851b3

File tree

1 file changed

+140
-53
lines changed

1 file changed

+140
-53
lines changed

scylla-rust-wrapper/src/future.rs

Lines changed: 140 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use futures::future;
1212
use std::future::Future;
1313
use std::mem;
1414
use std::os::raw::c_void;
15-
use std::sync::{Arc, Condvar, Mutex};
15+
use std::sync::{Arc, Condvar, Mutex, MutexGuard};
1616
use tokio::task::JoinHandle;
1717
use tokio::time::Duration;
1818

@@ -51,12 +51,113 @@ impl BoundCallback {
5151
/// State of the execution of the [CassFuture],
5252
/// together with a join handle of the tokio task that is executing it.
5353
struct CassFutureState {
54-
value: Option<CassFutureResult>,
55-
err_string: Option<String>,
56-
callback: Option<BoundCallback>,
54+
execution_state: CassFutureExecution,
55+
/// Presence of this handle while `execution_state` is not `Completed` indicates
56+
/// that no thread is currently blocked on the future. This means that it might
57+
/// not be executed (especially in case of the current-thread executor).
58+
/// Absence means that some thread has blocked on the future, so it is necessarily
59+
/// being executed.
5760
join_handle: Option<JoinHandle<()>>,
5861
}
5962

63+
/// State of the execution of the [CassFuture].
64+
enum CassFutureExecution {
65+
RunningWithoutCallback,
66+
RunningWithCallback { callback: BoundCallback },
67+
Completed(CassFutureCompleted),
68+
}
69+
70+
impl CassFutureExecution {
71+
fn completed(&self) -> bool {
72+
match self {
73+
Self::Completed(_) => true,
74+
Self::RunningWithCallback { .. } | Self::RunningWithoutCallback => false,
75+
}
76+
}
77+
78+
/// Sets callback for the [CassFuture]. If the future has not completed yet,
79+
/// the callback will be invoked once the future is completed, by the executor thread.
80+
/// If the future has already completed, the callback will be invoked immediately.
81+
unsafe fn set_callback(
82+
mut state_lock: MutexGuard<CassFutureState>,
83+
fut_ptr: CassBorrowedSharedPtr<CassFuture, CMut>,
84+
cb: CassFutureCallback,
85+
data: *mut c_void,
86+
) -> CassError {
87+
let bound_cb = BoundCallback { cb, data };
88+
89+
match state_lock.execution_state {
90+
Self::RunningWithoutCallback => {
91+
// Store the callback.
92+
state_lock.execution_state = Self::RunningWithCallback { callback: bound_cb };
93+
CassError::CASS_OK
94+
}
95+
Self::RunningWithCallback { .. } =>
96+
// Another callback has been already set.
97+
{
98+
return CassError::CASS_ERROR_LIB_CALLBACK_ALREADY_SET;
99+
}
100+
Self::Completed { .. } => {
101+
// The value is already available, we need to call the callback ourselves.
102+
mem::drop(state_lock);
103+
bound_cb.invoke(fut_ptr);
104+
return CassError::CASS_OK;
105+
}
106+
}
107+
}
108+
109+
/// Sets the [CassFuture] as completed. This function is called by the executor thread
110+
/// once it completes the underlying Rust future. If there's a callback set,
111+
/// it will be invoked immediately.
112+
fn complete(
113+
mut state_lock: MutexGuard<CassFutureState>,
114+
value: CassFutureResult,
115+
cass_fut: &Arc<CassFuture>,
116+
) {
117+
let prev_state = mem::replace(
118+
&mut state_lock.execution_state,
119+
Self::Completed(CassFutureCompleted::new(value)),
120+
);
121+
122+
// This is because we mustn't hold the lock while invoking the callback.
123+
mem::drop(state_lock);
124+
125+
let maybe_cb = match prev_state {
126+
Self::RunningWithoutCallback => None,
127+
Self::RunningWithCallback { callback } => Some(callback),
128+
Self::Completed { .. } => unreachable!(
129+
"Exactly one dedicated tokio task is expected to execute and complete the CassFuture."
130+
),
131+
};
132+
133+
if let Some(bound_cb) = maybe_cb {
134+
let fut_ptr = ArcFFI::as_ptr::<CMut>(cass_fut);
135+
// Safety: pointer is valid, because we get it from arc allocation.
136+
bound_cb.invoke(fut_ptr);
137+
}
138+
}
139+
}
140+
141+
/// The result of a completed [CassFuture].
142+
struct CassFutureCompleted {
143+
/// The result of the future, either a value or an error.
144+
value: CassFutureResult,
145+
/// Just a cache for the error message. Needed because the C API exposes a pointer to the
146+
/// error message, and we need to keep it alive until the future is freed.
147+
/// Initially, it's `None`, and it is set to `Some` when the error message is requested
148+
/// by `cass_future_error_message()`.
149+
cached_err_string: Option<String>,
150+
}
151+
152+
impl CassFutureCompleted {
153+
fn new(value: CassFutureResult) -> Self {
154+
Self {
155+
value,
156+
cached_err_string: None,
157+
}
158+
}
159+
}
160+
60161
/// The C-API representation of a future. Implemented as a wrapper around a Rust future
61162
/// that can be awaited and has a callback mechanism. It's **eager** in a way that
62163
/// its execution starts possibly immediately (unless the executor thread pool is nempty,
@@ -92,27 +193,17 @@ impl CassFuture {
92193
) -> Arc<CassFuture> {
93194
let cass_fut = Arc::new(CassFuture {
94195
state: Mutex::new(CassFutureState {
95-
value: None,
96-
err_string: None,
97-
callback: None,
98196
join_handle: None,
197+
execution_state: CassFutureExecution::RunningWithoutCallback,
99198
}),
100199
wait_for_value: Condvar::new(),
101200
});
102201
let cass_fut_clone = Arc::clone(&cass_fut);
103202
let join_handle = RUNTIME.spawn(async move {
104203
let r = fut.await;
105-
let maybe_cb = {
106-
let mut guard = cass_fut_clone.state.lock().unwrap();
107-
guard.value = Some(r);
108-
// Take the callback and call it after releasing the lock
109-
guard.callback.take()
110-
};
111-
if let Some(bound_cb) = maybe_cb {
112-
let fut_ptr = ArcFFI::as_ptr::<CMut>(&cass_fut_clone);
113-
// Safety: pointer is valid, because we get it from arc allocation.
114-
bound_cb.invoke(fut_ptr);
115-
}
204+
205+
let guard = cass_fut_clone.state.lock().unwrap();
206+
CassFutureExecution::complete(guard, r, &cass_fut_clone);
116207

117208
cass_fut_clone.wait_for_value.notify_all();
118209
});
@@ -126,17 +217,15 @@ impl CassFuture {
126217
pub fn new_ready(r: CassFutureResult) -> Arc<Self> {
127218
Arc::new(CassFuture {
128219
state: Mutex::new(CassFutureState {
129-
value: Some(r),
130-
err_string: None,
131-
callback: None,
132220
join_handle: None,
221+
execution_state: CassFutureExecution::Completed(CassFutureCompleted::new(r)),
133222
}),
134223
wait_for_value: Condvar::new(),
135224
})
136225
}
137226

138227
pub fn with_waited_result<T>(&self, f: impl FnOnce(&mut CassFutureResult) -> T) -> T {
139-
self.with_waited_state(|s| f(s.value.as_mut().unwrap()))
228+
self.with_waited_state(|s| f(&mut s.value))
140229
}
141230

142231
/// Awaits the future until completion.
@@ -152,7 +241,7 @@ impl CassFuture {
152241
/// - JoinHandle is Some -> some other thread was working on the future, but
153242
/// timed out (see [CassFuture::with_waited_state_timed]). We need to
154243
/// take the ownership of the handle, and complete the work.
155-
fn with_waited_state<T>(&self, f: impl FnOnce(&mut CassFutureState) -> T) -> T {
244+
fn with_waited_state<T>(&self, f: impl FnOnce(&mut CassFutureCompleted) -> T) -> T {
156245
let mut guard = self.state.lock().unwrap();
157246
loop {
158247
let handle = guard.join_handle.take();
@@ -165,7 +254,7 @@ impl CassFuture {
165254
guard = self
166255
.wait_for_value
167256
.wait_while(guard, |state| {
168-
state.value.is_none() && state.join_handle.is_none()
257+
!state.execution_state.completed() && state.join_handle.is_none()
169258
})
170259
// unwrap: Error appears only when mutex is poisoned.
171260
.unwrap();
@@ -177,7 +266,15 @@ impl CassFuture {
177266
continue;
178267
}
179268
}
180-
return f(&mut guard);
269+
270+
// If we had ended up in either the handle's or with the condvar's `if` branch,
271+
// we awaited the future and it is now completed.
272+
let completed = match &mut guard.execution_state {
273+
CassFutureExecution::RunningWithoutCallback
274+
| CassFutureExecution::RunningWithCallback { .. } => unreachable!(),
275+
CassFutureExecution::Completed(completed) => completed,
276+
};
277+
return f(completed);
181278
}
182279
}
183280

@@ -186,7 +283,7 @@ impl CassFuture {
186283
f: impl FnOnce(&mut CassFutureResult) -> T,
187284
timeout_duration: Duration,
188285
) -> Result<T, FutureError> {
189-
self.with_waited_state_timed(|s| f(s.value.as_mut().unwrap()), timeout_duration)
286+
self.with_waited_state_timed(|s| f(&mut s.value), timeout_duration)
190287
}
191288

192289
/// Tries to await the future with a given timeout.
@@ -206,7 +303,7 @@ impl CassFuture {
206303
/// take the ownership of the handle, and continue the work.
207304
fn with_waited_state_timed<T>(
208305
&self,
209-
f: impl FnOnce(&mut CassFutureState) -> T,
306+
f: impl FnOnce(&mut CassFutureCompleted) -> T,
210307
timeout_duration: Duration,
211308
) -> Result<T, FutureError> {
212309
let mut guard = self.state.lock().unwrap();
@@ -254,7 +351,7 @@ impl CassFuture {
254351
let (guard_result, timeout_result) = self
255352
.wait_for_value
256353
.wait_timeout_while(guard, remaining_timeout, |state| {
257-
state.value.is_none() && state.join_handle.is_none()
354+
!state.execution_state.completed() && state.join_handle.is_none()
258355
})
259356
// unwrap: Error appears only when mutex is poisoned.
260357
.unwrap();
@@ -271,7 +368,14 @@ impl CassFuture {
271368
}
272369
}
273370

274-
return Ok(f(&mut guard));
371+
// If we had ended up in either the handle's or with the condvar's `if` branch
372+
// and we didn't return `TimeoutError`, we awaited the future and it is now completed.
373+
let completed = match &mut guard.execution_state {
374+
CassFutureExecution::RunningWithoutCallback
375+
| CassFutureExecution::RunningWithCallback { .. } => unreachable!(),
376+
CassFutureExecution::Completed(completed) => completed,
377+
};
378+
return Ok(f(completed));
275379
}
276380
}
277381

@@ -281,21 +385,8 @@ impl CassFuture {
281385
cb: CassFutureCallback,
282386
data: *mut c_void,
283387
) -> CassError {
284-
let mut lock = self.state.lock().unwrap();
285-
if lock.callback.is_some() {
286-
// Another callback has been already set
287-
return CassError::CASS_ERROR_LIB_CALLBACK_ALREADY_SET;
288-
}
289-
let bound_cb = BoundCallback { cb, data };
290-
if lock.value.is_some() {
291-
// The value is already available, we need to call the callback ourselves
292-
mem::drop(lock);
293-
bound_cb.invoke(self_ptr);
294-
return CassError::CASS_OK;
295-
}
296-
// Store the callback
297-
lock.callback = Some(bound_cb);
298-
CassError::CASS_OK
388+
let lock = self.state.lock().unwrap();
389+
unsafe { CassFutureExecution::set_callback(lock, self_ptr, cb, data) }
299390
}
300391

301392
fn into_raw(self: Arc<Self>) -> CassOwnedSharedPtr<Self, CMut> {
@@ -358,10 +449,7 @@ pub unsafe extern "C" fn cass_future_ready(
358449
};
359450

360451
let state_guard = future.state.lock().unwrap();
361-
match state_guard.value {
362-
None => cass_false,
363-
Some(_) => cass_true,
364-
}
452+
state_guard.execution_state.completed() as cass_bool_t
365453
}
366454

367455
#[unsafe(no_mangle)]
@@ -391,11 +479,10 @@ pub unsafe extern "C" fn cass_future_error_message(
391479
return;
392480
};
393481

394-
future.with_waited_state(|state: &mut CassFutureState| {
395-
let value = &state.value;
396-
let msg = state
397-
.err_string
398-
.get_or_insert_with(|| match value.as_ref().unwrap() {
482+
future.with_waited_state(|completed: &mut CassFutureCompleted| {
483+
let msg = completed
484+
.cached_err_string
485+
.get_or_insert_with(|| match &completed.value {
399486
Ok(CassResultValue::QueryError(err)) => err.msg(),
400487
Err((_, s)) => s.msg(),
401488
_ => "".to_string(),

0 commit comments

Comments
 (0)