@@ -12,7 +12,7 @@ use futures::future;
12
12
use std:: future:: Future ;
13
13
use std:: mem;
14
14
use std:: os:: raw:: c_void;
15
- use std:: sync:: { Arc , Condvar , Mutex } ;
15
+ use std:: sync:: { Arc , Condvar , Mutex , MutexGuard } ;
16
16
use tokio:: task:: JoinHandle ;
17
17
use tokio:: time:: Duration ;
18
18
@@ -51,12 +51,113 @@ impl BoundCallback {
51
51
/// State of the execution of the [CassFuture],
52
52
/// together with a join handle of the tokio task that is executing it.
53
53
struct CassFutureState {
54
- value : Option < CassFutureResult > ,
55
- err_string : Option < String > ,
56
- callback : Option < BoundCallback > ,
54
+ execution_state : CassFutureExecution ,
55
+ /// Presence of this handle while `execution_state` is not `Completed` indicates
56
+ /// that no thread is currently blocked on the future. This means that it might
57
+ /// not be executed (especially in case of the current-thread executor).
58
+ /// Absence means that some thread has blocked on the future, so it is necessarily
59
+ /// being executed.
57
60
join_handle : Option < JoinHandle < ( ) > > ,
58
61
}
59
62
63
+ /// State of the execution of the [CassFuture].
64
+ enum CassFutureExecution {
65
+ RunningWithoutCallback ,
66
+ RunningWithCallback { callback : BoundCallback } ,
67
+ Completed ( CassFutureCompleted ) ,
68
+ }
69
+
70
+ impl CassFutureExecution {
71
+ fn completed ( & self ) -> bool {
72
+ match self {
73
+ Self :: Completed ( _) => true ,
74
+ Self :: RunningWithCallback { .. } | Self :: RunningWithoutCallback => false ,
75
+ }
76
+ }
77
+
78
+ /// Sets callback for the [CassFuture]. If the future has not completed yet,
79
+ /// the callback will be invoked once the future is completed, by the executor thread.
80
+ /// If the future has already completed, the callback will be invoked immediately.
81
+ unsafe fn set_callback (
82
+ mut state_lock : MutexGuard < CassFutureState > ,
83
+ fut_ptr : CassBorrowedSharedPtr < CassFuture , CMut > ,
84
+ cb : CassFutureCallback ,
85
+ data : * mut c_void ,
86
+ ) -> CassError {
87
+ let bound_cb = BoundCallback { cb, data } ;
88
+
89
+ match state_lock. execution_state {
90
+ Self :: RunningWithoutCallback => {
91
+ // Store the callback.
92
+ state_lock. execution_state = Self :: RunningWithCallback { callback : bound_cb } ;
93
+ CassError :: CASS_OK
94
+ }
95
+ Self :: RunningWithCallback { .. } =>
96
+ // Another callback has been already set.
97
+ {
98
+ return CassError :: CASS_ERROR_LIB_CALLBACK_ALREADY_SET ;
99
+ }
100
+ Self :: Completed { .. } => {
101
+ // The value is already available, we need to call the callback ourselves.
102
+ mem:: drop ( state_lock) ;
103
+ bound_cb. invoke ( fut_ptr) ;
104
+ return CassError :: CASS_OK ;
105
+ }
106
+ }
107
+ }
108
+
109
+ /// Sets the [CassFuture] as completed. This function is called by the executor thread
110
+ /// once it completes the underlying Rust future. If there's a callback set,
111
+ /// it will be invoked immediately.
112
+ fn complete (
113
+ mut state_lock : MutexGuard < CassFutureState > ,
114
+ value : CassFutureResult ,
115
+ cass_fut : & Arc < CassFuture > ,
116
+ ) {
117
+ let prev_state = mem:: replace (
118
+ & mut state_lock. execution_state ,
119
+ Self :: Completed ( CassFutureCompleted :: new ( value) ) ,
120
+ ) ;
121
+
122
+ // This is because we mustn't hold the lock while invoking the callback.
123
+ mem:: drop ( state_lock) ;
124
+
125
+ let maybe_cb = match prev_state {
126
+ Self :: RunningWithoutCallback => None ,
127
+ Self :: RunningWithCallback { callback } => Some ( callback) ,
128
+ Self :: Completed { .. } => unreachable ! (
129
+ "Exactly one dedicated tokio task is expected to execute and complete the CassFuture."
130
+ ) ,
131
+ } ;
132
+
133
+ if let Some ( bound_cb) = maybe_cb {
134
+ let fut_ptr = ArcFFI :: as_ptr :: < CMut > ( cass_fut) ;
135
+ // Safety: pointer is valid, because we get it from arc allocation.
136
+ bound_cb. invoke ( fut_ptr) ;
137
+ }
138
+ }
139
+ }
140
+
141
+ /// The result of a completed [CassFuture].
142
+ struct CassFutureCompleted {
143
+ /// The result of the future, either a value or an error.
144
+ value : CassFutureResult ,
145
+ /// Just a cache for the error message. Needed because the C API exposes a pointer to the
146
+ /// error message, and we need to keep it alive until the future is freed.
147
+ /// Initially, it's `None`, and it is set to `Some` when the error message is requested
148
+ /// by `cass_future_error_message()`.
149
+ cached_err_string : Option < String > ,
150
+ }
151
+
152
+ impl CassFutureCompleted {
153
+ fn new ( value : CassFutureResult ) -> Self {
154
+ Self {
155
+ value,
156
+ cached_err_string : None ,
157
+ }
158
+ }
159
+ }
160
+
60
161
/// The C-API representation of a future. Implemented as a wrapper around a Rust future
61
162
/// that can be awaited and has a callback mechanism. It's **eager** in a way that
62
163
/// its execution starts possibly immediately (unless the executor thread pool is nempty,
@@ -92,27 +193,17 @@ impl CassFuture {
92
193
) -> Arc < CassFuture > {
93
194
let cass_fut = Arc :: new ( CassFuture {
94
195
state : Mutex :: new ( CassFutureState {
95
- value : None ,
96
- err_string : None ,
97
- callback : None ,
98
196
join_handle : None ,
197
+ execution_state : CassFutureExecution :: RunningWithoutCallback ,
99
198
} ) ,
100
199
wait_for_value : Condvar :: new ( ) ,
101
200
} ) ;
102
201
let cass_fut_clone = Arc :: clone ( & cass_fut) ;
103
202
let join_handle = RUNTIME . spawn ( async move {
104
203
let r = fut. await ;
105
- let maybe_cb = {
106
- let mut guard = cass_fut_clone. state . lock ( ) . unwrap ( ) ;
107
- guard. value = Some ( r) ;
108
- // Take the callback and call it after releasing the lock
109
- guard. callback . take ( )
110
- } ;
111
- if let Some ( bound_cb) = maybe_cb {
112
- let fut_ptr = ArcFFI :: as_ptr :: < CMut > ( & cass_fut_clone) ;
113
- // Safety: pointer is valid, because we get it from arc allocation.
114
- bound_cb. invoke ( fut_ptr) ;
115
- }
204
+
205
+ let guard = cass_fut_clone. state . lock ( ) . unwrap ( ) ;
206
+ CassFutureExecution :: complete ( guard, r, & cass_fut_clone) ;
116
207
117
208
cass_fut_clone. wait_for_value . notify_all ( ) ;
118
209
} ) ;
@@ -126,17 +217,15 @@ impl CassFuture {
126
217
pub fn new_ready ( r : CassFutureResult ) -> Arc < Self > {
127
218
Arc :: new ( CassFuture {
128
219
state : Mutex :: new ( CassFutureState {
129
- value : Some ( r) ,
130
- err_string : None ,
131
- callback : None ,
132
220
join_handle : None ,
221
+ execution_state : CassFutureExecution :: Completed ( CassFutureCompleted :: new ( r) ) ,
133
222
} ) ,
134
223
wait_for_value : Condvar :: new ( ) ,
135
224
} )
136
225
}
137
226
138
227
pub fn with_waited_result < T > ( & self , f : impl FnOnce ( & mut CassFutureResult ) -> T ) -> T {
139
- self . with_waited_state ( |s| f ( s. value . as_mut ( ) . unwrap ( ) ) )
228
+ self . with_waited_state ( |s| f ( & mut s. value ) )
140
229
}
141
230
142
231
/// Awaits the future until completion.
@@ -152,7 +241,7 @@ impl CassFuture {
152
241
/// - JoinHandle is Some -> some other thread was working on the future, but
153
242
/// timed out (see [CassFuture::with_waited_state_timed]). We need to
154
243
/// take the ownership of the handle, and complete the work.
155
- fn with_waited_state < T > ( & self , f : impl FnOnce ( & mut CassFutureState ) -> T ) -> T {
244
+ fn with_waited_state < T > ( & self , f : impl FnOnce ( & mut CassFutureCompleted ) -> T ) -> T {
156
245
let mut guard = self . state . lock ( ) . unwrap ( ) ;
157
246
loop {
158
247
let handle = guard. join_handle . take ( ) ;
@@ -165,7 +254,7 @@ impl CassFuture {
165
254
guard = self
166
255
. wait_for_value
167
256
. wait_while ( guard, |state| {
168
- state. value . is_none ( ) && state. join_handle . is_none ( )
257
+ ! state. execution_state . completed ( ) && state. join_handle . is_none ( )
169
258
} )
170
259
// unwrap: Error appears only when mutex is poisoned.
171
260
. unwrap ( ) ;
@@ -177,7 +266,15 @@ impl CassFuture {
177
266
continue ;
178
267
}
179
268
}
180
- return f ( & mut guard) ;
269
+
270
+ // If we had ended up in either the handle's or with the condvar's `if` branch,
271
+ // we awaited the future and it is now completed.
272
+ let completed = match & mut guard. execution_state {
273
+ CassFutureExecution :: RunningWithoutCallback
274
+ | CassFutureExecution :: RunningWithCallback { .. } => unreachable ! ( ) ,
275
+ CassFutureExecution :: Completed ( completed) => completed,
276
+ } ;
277
+ return f ( completed) ;
181
278
}
182
279
}
183
280
@@ -186,7 +283,7 @@ impl CassFuture {
186
283
f : impl FnOnce ( & mut CassFutureResult ) -> T ,
187
284
timeout_duration : Duration ,
188
285
) -> Result < T , FutureError > {
189
- self . with_waited_state_timed ( |s| f ( s. value . as_mut ( ) . unwrap ( ) ) , timeout_duration)
286
+ self . with_waited_state_timed ( |s| f ( & mut s. value ) , timeout_duration)
190
287
}
191
288
192
289
/// Tries to await the future with a given timeout.
@@ -206,7 +303,7 @@ impl CassFuture {
206
303
/// take the ownership of the handle, and continue the work.
207
304
fn with_waited_state_timed < T > (
208
305
& self ,
209
- f : impl FnOnce ( & mut CassFutureState ) -> T ,
306
+ f : impl FnOnce ( & mut CassFutureCompleted ) -> T ,
210
307
timeout_duration : Duration ,
211
308
) -> Result < T , FutureError > {
212
309
let mut guard = self . state . lock ( ) . unwrap ( ) ;
@@ -254,7 +351,7 @@ impl CassFuture {
254
351
let ( guard_result, timeout_result) = self
255
352
. wait_for_value
256
353
. wait_timeout_while ( guard, remaining_timeout, |state| {
257
- state. value . is_none ( ) && state. join_handle . is_none ( )
354
+ ! state. execution_state . completed ( ) && state. join_handle . is_none ( )
258
355
} )
259
356
// unwrap: Error appears only when mutex is poisoned.
260
357
. unwrap ( ) ;
@@ -271,7 +368,14 @@ impl CassFuture {
271
368
}
272
369
}
273
370
274
- return Ok ( f ( & mut guard) ) ;
371
+ // If we had ended up in either the handle's or with the condvar's `if` branch
372
+ // and we didn't return `TimeoutError`, we awaited the future and it is now completed.
373
+ let completed = match & mut guard. execution_state {
374
+ CassFutureExecution :: RunningWithoutCallback
375
+ | CassFutureExecution :: RunningWithCallback { .. } => unreachable ! ( ) ,
376
+ CassFutureExecution :: Completed ( completed) => completed,
377
+ } ;
378
+ return Ok ( f ( completed) ) ;
275
379
}
276
380
}
277
381
@@ -281,21 +385,8 @@ impl CassFuture {
281
385
cb : CassFutureCallback ,
282
386
data : * mut c_void ,
283
387
) -> CassError {
284
- let mut lock = self . state . lock ( ) . unwrap ( ) ;
285
- if lock. callback . is_some ( ) {
286
- // Another callback has been already set
287
- return CassError :: CASS_ERROR_LIB_CALLBACK_ALREADY_SET ;
288
- }
289
- let bound_cb = BoundCallback { cb, data } ;
290
- if lock. value . is_some ( ) {
291
- // The value is already available, we need to call the callback ourselves
292
- mem:: drop ( lock) ;
293
- bound_cb. invoke ( self_ptr) ;
294
- return CassError :: CASS_OK ;
295
- }
296
- // Store the callback
297
- lock. callback = Some ( bound_cb) ;
298
- CassError :: CASS_OK
388
+ let lock = self . state . lock ( ) . unwrap ( ) ;
389
+ unsafe { CassFutureExecution :: set_callback ( lock, self_ptr, cb, data) }
299
390
}
300
391
301
392
fn into_raw ( self : Arc < Self > ) -> CassOwnedSharedPtr < Self , CMut > {
@@ -358,10 +449,7 @@ pub unsafe extern "C" fn cass_future_ready(
358
449
} ;
359
450
360
451
let state_guard = future. state . lock ( ) . unwrap ( ) ;
361
- match state_guard. value {
362
- None => cass_false,
363
- Some ( _) => cass_true,
364
- }
452
+ state_guard. execution_state . completed ( ) as cass_bool_t
365
453
}
366
454
367
455
#[ unsafe( no_mangle) ]
@@ -391,11 +479,10 @@ pub unsafe extern "C" fn cass_future_error_message(
391
479
return ;
392
480
} ;
393
481
394
- future. with_waited_state ( |state : & mut CassFutureState | {
395
- let value = & state. value ;
396
- let msg = state
397
- . err_string
398
- . get_or_insert_with ( || match value. as_ref ( ) . unwrap ( ) {
482
+ future. with_waited_state ( |completed : & mut CassFutureCompleted | {
483
+ let msg = completed
484
+ . cached_err_string
485
+ . get_or_insert_with ( || match & completed. value {
399
486
Ok ( CassResultValue :: QueryError ( err) ) => err. msg ( ) ,
400
487
Err ( ( _, s) ) => s. msg ( ) ,
401
488
_ => "" . to_string ( ) ,
0 commit comments