@@ -8,6 +8,8 @@ use diesel::QueryResult;
8
8
use scoped_futures:: ScopedBoxFuture ;
9
9
use std:: borrow:: Cow ;
10
10
use std:: num:: NonZeroU32 ;
11
+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
12
+ use std:: sync:: Arc ;
11
13
12
14
use crate :: AsyncConnection ;
13
15
// TODO: refactor this to share more code with diesel
@@ -88,24 +90,31 @@ pub trait TransactionManager<Conn: AsyncConnection>: Send {
88
90
/// in an error state.
89
91
#[ doc( hidden) ]
90
92
fn is_broken_transaction_manager ( conn : & mut Conn ) -> bool {
91
- match Self :: transaction_manager_status_mut ( conn) . transaction_state ( ) {
92
- // all transactions are closed
93
- // so we don't consider this connection broken
94
- Ok ( ValidTransactionManagerStatus {
95
- in_transaction : None ,
96
- ..
97
- } ) => false ,
98
- // The transaction manager is in an error state
99
- // Therefore we consider this connection broken
100
- Err ( _) => true ,
101
- // The transaction manager contains a open transaction
102
- // we do consider this connection broken
103
- // if that transaction was not opened by `begin_test_transaction`
104
- Ok ( ValidTransactionManagerStatus {
105
- in_transaction : Some ( s) ,
106
- ..
107
- } ) => !s. test_transaction ,
108
- }
93
+ check_broken_transaction_state ( conn)
94
+ }
95
+ }
96
+
97
+ fn check_broken_transaction_state < Conn > ( conn : & mut Conn ) -> bool
98
+ where
99
+ Conn : AsyncConnection ,
100
+ {
101
+ match Conn :: TransactionManager :: transaction_manager_status_mut ( conn) . transaction_state ( ) {
102
+ // all transactions are closed
103
+ // so we don't consider this connection broken
104
+ Ok ( ValidTransactionManagerStatus {
105
+ in_transaction : None ,
106
+ ..
107
+ } ) => false ,
108
+ // The transaction manager is in an error state
109
+ // Therefore we consider this connection broken
110
+ Err ( _) => true ,
111
+ // The transaction manager contains a open transaction
112
+ // we do consider this connection broken
113
+ // if that transaction was not opened by `begin_test_transaction`
114
+ Ok ( ValidTransactionManagerStatus {
115
+ in_transaction : Some ( s) ,
116
+ ..
117
+ } ) => !s. test_transaction ,
109
118
}
110
119
}
111
120
@@ -114,147 +123,23 @@ pub trait TransactionManager<Conn: AsyncConnection>: Send {
114
123
#[ derive( Default , Debug ) ]
115
124
pub struct AnsiTransactionManager {
116
125
pub ( crate ) status : TransactionManagerStatus ,
126
+ // this boolean flag tracks whether we are currently in the process
127
+ // of executing any transaction releated SQL (BEGIN, COMMIT, ROLLBACK)
128
+ // if we ever encounter a situation where this flag is set
129
+ // while the connection is returned to a pool
130
+ // that means the connection is broken as someone dropped the
131
+ // transaction future while these commands where executed
132
+ // and we cannot know the connection state anymore
133
+ //
134
+ // We ensure this by wrapping all calls to `.await`
135
+ // into `AnsiTransactionManager::critical_transaction_block`
136
+ // below
137
+ //
138
+ // See https://github.com/weiznich/diesel_async/issues/198 for
139
+ // details
140
+ pub ( crate ) is_broken : Arc < AtomicBool > ,
117
141
}
118
142
119
- // /// Status of the transaction manager
120
- // #[derive(Debug)]
121
- // pub enum TransactionManagerStatus {
122
- // /// Valid status, the manager can run operations
123
- // Valid(ValidTransactionManagerStatus),
124
- // /// Error status, probably following a broken connection. The manager will no longer run operations
125
- // InError,
126
- // }
127
-
128
- // impl Default for TransactionManagerStatus {
129
- // fn default() -> Self {
130
- // TransactionManagerStatus::Valid(ValidTransactionManagerStatus::default())
131
- // }
132
- // }
133
-
134
- // impl TransactionManagerStatus {
135
- // /// Returns the transaction depth if the transaction manager's status is valid, or returns
136
- // /// [`Error::BrokenTransactionManager`] if the transaction manager is in error.
137
- // pub fn transaction_depth(&self) -> QueryResult<Option<NonZeroU32>> {
138
- // match self {
139
- // TransactionManagerStatus::Valid(valid_status) => Ok(valid_status.transaction_depth()),
140
- // TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
141
- // }
142
- // }
143
-
144
- // /// If in transaction and transaction manager is not broken, registers that the
145
- // /// connection can not be used anymore until top-level transaction is rolled back
146
- // pub(crate) fn set_top_level_transaction_requires_rollback(&mut self) {
147
- // if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
148
- // in_transaction:
149
- // Some(InTransactionStatus {
150
- // top_level_transaction_requires_rollback,
151
- // ..
152
- // }),
153
- // }) = self
154
- // {
155
- // *top_level_transaction_requires_rollback = true;
156
- // }
157
- // }
158
-
159
- // /// Sets the transaction manager status to InError
160
- // ///
161
- // /// Subsequent attempts to use transaction-related features will result in a
162
- // /// [`Error::BrokenTransactionManager`] error
163
- // pub fn set_in_error(&mut self) {
164
- // *self = TransactionManagerStatus::InError
165
- // }
166
-
167
- // fn transaction_state(&mut self) -> QueryResult<&mut ValidTransactionManagerStatus> {
168
- // match self {
169
- // TransactionManagerStatus::Valid(valid_status) => Ok(valid_status),
170
- // TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
171
- // }
172
- // }
173
-
174
- // pub(crate) fn set_test_transaction_flag(&mut self) {
175
- // if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
176
- // in_transaction: Some(s),
177
- // }) = self
178
- // {
179
- // s.test_transaction = true;
180
- // }
181
- // }
182
- // }
183
-
184
- // /// Valid transaction status for the manager. Can return the current transaction depth
185
- // #[allow(missing_copy_implementations)]
186
- // #[derive(Debug, Default)]
187
- // pub struct ValidTransactionManagerStatus {
188
- // in_transaction: Option<InTransactionStatus>,
189
- // }
190
-
191
- // #[allow(missing_copy_implementations)]
192
- // #[derive(Debug)]
193
- // struct InTransactionStatus {
194
- // transaction_depth: NonZeroU32,
195
- // top_level_transaction_requires_rollback: bool,
196
- // test_transaction: bool,
197
- // }
198
-
199
- // impl ValidTransactionManagerStatus {
200
- // /// Return the current transaction depth
201
- // ///
202
- // /// This value is `None` if no current transaction is running
203
- // /// otherwise the number of nested transactions is returned.
204
- // pub fn transaction_depth(&self) -> Option<NonZeroU32> {
205
- // self.in_transaction.as_ref().map(|it| it.transaction_depth)
206
- // }
207
-
208
- // /// Update the transaction depth by adding the value of the `transaction_depth_change` parameter if the `query` is
209
- // /// `Ok(())`
210
- // pub fn change_transaction_depth(
211
- // &mut self,
212
- // transaction_depth_change: TransactionDepthChange,
213
- // ) -> QueryResult<()> {
214
- // match (&mut self.in_transaction, transaction_depth_change) {
215
- // (Some(in_transaction), TransactionDepthChange::IncreaseDepth) => {
216
- // // Can be replaced with saturating_add directly on NonZeroU32 once
217
- // // <https://github.com/rust-lang/rust/issues/84186> is stable
218
- // in_transaction.transaction_depth =
219
- // NonZeroU32::new(in_transaction.transaction_depth.get().saturating_add(1))
220
- // .expect("nz + nz is always non-zero");
221
- // Ok(())
222
- // }
223
- // (Some(in_transaction), TransactionDepthChange::DecreaseDepth) => {
224
- // // This sets `transaction_depth` to `None` as soon as we reach zero
225
- // match NonZeroU32::new(in_transaction.transaction_depth.get() - 1) {
226
- // Some(depth) => in_transaction.transaction_depth = depth,
227
- // None => self.in_transaction = None,
228
- // }
229
- // Ok(())
230
- // }
231
- // (None, TransactionDepthChange::IncreaseDepth) => {
232
- // self.in_transaction = Some(InTransactionStatus {
233
- // transaction_depth: NonZeroU32::new(1).expect("1 is non-zero"),
234
- // top_level_transaction_requires_rollback: false,
235
- // test_transaction: false,
236
- // });
237
- // Ok(())
238
- // }
239
- // (None, TransactionDepthChange::DecreaseDepth) => {
240
- // // We screwed up something somewhere
241
- // // we cannot decrease the transaction count if
242
- // // we are not inside a transaction
243
- // Err(Error::NotInTransaction)
244
- // }
245
- // }
246
- // }
247
- // }
248
-
249
- // /// Represents a change to apply to the depth of a transaction
250
- // #[derive(Debug, Clone, Copy)]
251
- // pub enum TransactionDepthChange {
252
- // /// Increase the depth of the transaction (corresponds to `BEGIN` or `SAVEPOINT`)
253
- // IncreaseDepth,
254
- // /// Decreases the depth of the transaction (corresponds to `COMMIT`/`RELEASE SAVEPOINT` or `ROLLBACK`)
255
- // DecreaseDepth,
256
- // }
257
-
258
143
impl AnsiTransactionManager {
259
144
fn get_transaction_state < Conn > (
260
145
conn : & mut Conn ,
@@ -274,17 +159,38 @@ impl AnsiTransactionManager {
274
159
where
275
160
Conn : AsyncConnection < TransactionManager = Self > ,
276
161
{
162
+ let is_broken = conn. transaction_state ( ) . is_broken . clone ( ) ;
277
163
let state = Self :: get_transaction_state ( conn) ?;
278
164
match state. transaction_depth ( ) {
279
165
None => {
280
- conn. batch_execute ( sql) . await ?;
166
+ Self :: critical_transaction_block ( & is_broken , conn. batch_execute ( sql) ) . await ?;
281
167
Self :: get_transaction_state ( conn) ?
282
168
. change_transaction_depth ( TransactionDepthChange :: IncreaseDepth ) ?;
283
169
Ok ( ( ) )
284
170
}
285
171
Some ( _depth) => Err ( Error :: AlreadyInTransaction ) ,
286
172
}
287
173
}
174
+
175
+ // This function should be used to await any connection
176
+ // related future in our transaction manager implementation
177
+ //
178
+ // It takes care of tracking entering and exiting executing the future
179
+ // which in turn is used to determine if it's safe to still use
180
+ // the connection in the event of a canceled transaction execution
181
+ async fn critical_transaction_block < F > ( is_broken : & AtomicBool , f : F ) -> F :: Output
182
+ where
183
+ F : std:: future:: Future ,
184
+ {
185
+ let was_broken = is_broken. swap ( true , Ordering :: Relaxed ) ;
186
+ debug_assert ! (
187
+ !was_broken,
188
+ "Tried to execute a transaction SQL on transaction manager that was previously cancled"
189
+ ) ;
190
+ let res = f. await ;
191
+ is_broken. store ( false , Ordering :: Relaxed ) ;
192
+ res
193
+ }
288
194
}
289
195
290
196
#[ async_trait:: async_trait]
@@ -308,7 +214,11 @@ where
308
214
. unwrap_or ( NonZeroU32 :: new ( 1 ) . expect ( "It's not 0" ) ) ;
309
215
conn. instrumentation ( )
310
216
. on_connection_event ( InstrumentationEvent :: begin_transaction ( depth) ) ;
311
- conn. batch_execute ( & start_transaction_sql) . await ?;
217
+ Self :: critical_transaction_block (
218
+ & conn. transaction_state ( ) . is_broken . clone ( ) ,
219
+ conn. batch_execute ( & start_transaction_sql) ,
220
+ )
221
+ . await ?;
312
222
Self :: get_transaction_state ( conn) ?
313
223
. change_transaction_depth ( TransactionDepthChange :: IncreaseDepth ) ?;
314
224
@@ -344,7 +254,10 @@ where
344
254
conn. instrumentation ( )
345
255
. on_connection_event ( InstrumentationEvent :: rollback_transaction ( depth) ) ;
346
256
347
- match conn. batch_execute ( & rollback_sql) . await {
257
+ let is_broken = conn. transaction_state ( ) . is_broken . clone ( ) ;
258
+
259
+ match Self :: critical_transaction_block ( & is_broken, conn. batch_execute ( & rollback_sql) ) . await
260
+ {
348
261
Ok ( ( ) ) => {
349
262
match Self :: get_transaction_state ( conn) ?
350
263
. change_transaction_depth ( TransactionDepthChange :: DecreaseDepth )
@@ -429,7 +342,9 @@ where
429
342
conn. instrumentation ( )
430
343
. on_connection_event ( InstrumentationEvent :: commit_transaction ( depth) ) ;
431
344
432
- match conn. batch_execute ( & commit_sql) . await {
345
+ let is_broken = conn. transaction_state ( ) . is_broken . clone ( ) ;
346
+
347
+ match Self :: critical_transaction_block ( & is_broken, conn. batch_execute ( & commit_sql) ) . await {
433
348
Ok ( ( ) ) => {
434
349
match Self :: get_transaction_state ( conn) ?
435
350
. change_transaction_depth ( TransactionDepthChange :: DecreaseDepth )
@@ -453,7 +368,12 @@ where
453
368
..
454
369
} ) = conn. transaction_state ( ) . status
455
370
{
456
- match Self :: rollback_transaction ( conn) . await {
371
+ match Self :: critical_transaction_block (
372
+ & is_broken,
373
+ Self :: rollback_transaction ( conn) ,
374
+ )
375
+ . await
376
+ {
457
377
Ok ( ( ) ) => { }
458
378
Err ( rollback_error) => {
459
379
conn. transaction_state ( ) . status . set_in_error ( ) ;
@@ -472,4 +392,9 @@ where
472
392
fn transaction_manager_status_mut ( conn : & mut Conn ) -> & mut TransactionManagerStatus {
473
393
& mut conn. transaction_state ( ) . status
474
394
}
395
+
396
+ fn is_broken_transaction_manager ( conn : & mut Conn ) -> bool {
397
+ conn. transaction_state ( ) . is_broken . load ( Ordering :: Relaxed )
398
+ || check_broken_transaction_state ( conn)
399
+ }
475
400
}
0 commit comments