Skip to content

Commit 5cbf7c1

Browse files
committed
Share the statement cache with diesel
This commit refactors diesel-async to use the same statement cache implementation as diesel. That brings in all the optimisations done to the diesel statement cache.
1 parent e3beac6 commit 5cbf7c1

File tree

14 files changed

+294
-203
lines changed

14 files changed

+294
-203
lines changed

Cargo.toml

+23-5
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@ description = "An async extension for Diesel the safe, extensible ORM and Query
1313
rust-version = "1.78.0"
1414

1515
[dependencies]
16-
diesel = { version = "~2.2.0", default-features = false, features = [
17-
"i-implement-a-third-party-backend-and-opt-into-breaking-changes",
18-
] }
1916
async-trait = "0.1.66"
2017
futures-channel = { version = "0.3.17", default-features = false, features = [
2118
"std",
@@ -39,14 +36,35 @@ deadpool = { version = "0.12", optional = true, default-features = false, featur
3936
mobc = { version = ">=0.7,<0.10", optional = true }
4037
scoped-futures = { version = "0.1", features = ["std"] }
4138

39+
[dependencies.diesel]
40+
version = "~2.2.0"
41+
default-features = false
42+
features = [
43+
"i-implement-a-third-party-backend-and-opt-into-breaking-changes",
44+
]
45+
git = "https://github.com/weiznich/diesel"
46+
rev = "d3c67851"
47+
4248
[dev-dependencies]
4349
tokio = { version = "1.12.0", features = ["rt", "macros", "rt-multi-thread"] }
4450
cfg-if = "1"
4551
chrono = "0.4"
46-
diesel = { version = "2.2.0", default-features = false, features = ["chrono"] }
47-
diesel_migrations = "2.2.0"
4852
assert_matches = "1.0.1"
4953

54+
[dev-dependencies.diesel]
55+
version = "~2.2.0"
56+
default-features = false
57+
features = [
58+
"chrono"
59+
]
60+
git = "https://github.com/weiznich/diesel"
61+
rev = "d3c67851"
62+
63+
[dev-dependencies.diesel_migrations]
64+
version = "2.2.0"
65+
git = "https://github.com/weiznich/diesel"
66+
rev = "d3c67851"
67+
5068
[features]
5169
default = []
5270
mysql = [

examples/postgres/pooled-with-rustls/Cargo.toml

+7-1
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@ edition = "2021"
66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
77

88
[dependencies]
9-
diesel = { version = "2.2.0", default-features = false, features = ["postgres"] }
109
diesel-async = { version = "0.5.0", path = "../../../", features = ["bb8", "postgres"] }
1110
futures-util = "0.3.21"
1211
rustls = "0.23.8"
1312
rustls-native-certs = "0.7.1"
1413
tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] }
1514
tokio-postgres = "0.7.7"
1615
tokio-postgres-rustls = "0.12.0"
16+
17+
18+
[dependencies.diesel]
19+
version = "2.2.0"
20+
default-features = false
21+
git = "https://github.com/weiznich/diesel"
22+
rev = "d3c67851"

examples/postgres/run-pending-migrations-with-rustls/Cargo.toml

+12-2
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,22 @@ edition = "2021"
66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
77

88
[dependencies]
9-
diesel = { version = "2.2.0", default-features = false, features = ["postgres"] }
109
diesel-async = { version = "0.5.0", path = "../../../", features = ["bb8", "postgres", "async-connection-wrapper"] }
11-
diesel_migrations = "2.2.0"
1210
futures-util = "0.3.21"
1311
rustls = "0.23.10"
1412
rustls-native-certs = "0.7.1"
1513
tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] }
1614
tokio-postgres = "0.7.7"
1715
tokio-postgres-rustls = "0.12.0"
16+
17+
[dependencies.diesel]
18+
version = "2.2.0"
19+
default-features = false
20+
git = "https://github.com/weiznich/diesel"
21+
rev = "d3c67851"
22+
23+
[dependencies.diesel_migrations]
24+
version = "2.2.0"
25+
git = "https://github.com/weiznich/diesel"
26+
rev = "d3c67851"
27+

examples/sync-wrapper/Cargo.toml

+12-2
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,22 @@ edition = "2021"
66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
77

88
[dependencies]
9-
diesel = { version = "2.2.0", default-features = false, features = ["returning_clauses_for_sqlite_3_35"] }
109
diesel-async = { version = "0.5.0", path = "../../", features = ["sync-connection-wrapper", "async-connection-wrapper"] }
11-
diesel_migrations = "2.2.0"
1210
futures-util = "0.3.21"
1311
tokio = { version = "1.2.0", default-features = false, features = ["macros", "rt-multi-thread"] }
1412

13+
[dependencies.diesel]
14+
version = "2.2.0"
15+
default-features = false
16+
features = ["returning_clauses_for_sqlite_3_35"]
17+
git = "https://github.com/weiznich/diesel"
18+
rev = "d3c67851"
19+
20+
[dependencies.diesel_migrations]
21+
version = "2.2.0"
22+
git = "https://github.com/weiznich/diesel"
23+
rev = "d3c67851"
24+
1525
[features]
1626
default = ["sqlite"]
1727
sqlite = ["diesel-async/sqlite"]

src/async_connection_wrapper.rs

+13-7
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> =
100100
pub use self::implementation::AsyncConnectionWrapper;
101101

102102
mod implementation {
103-
use diesel::connection::{Instrumentation, SimpleConnection};
103+
use diesel::connection::{CacheSize, Instrumentation, SimpleConnection};
104104
use std::ops::{Deref, DerefMut};
105105

106106
use super::*;
@@ -187,20 +187,26 @@ mod implementation {
187187
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
188188
self.inner.set_instrumentation(instrumentation);
189189
}
190+
191+
fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
192+
self.inner.set_prepared_statement_cache_size(size)
193+
}
190194
}
191195

192196
impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B>
193197
where
194198
C: crate::AsyncConnection,
195199
B: BlockOn + Send,
196200
{
197-
type Cursor<'conn, 'query> = AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B>
198-
where
199-
Self: 'conn;
201+
type Cursor<'conn, 'query>
202+
= AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B>
203+
where
204+
Self: 'conn;
200205

201-
type Row<'conn, 'query> = C::Row<'conn, 'query>
202-
where
203-
Self: 'conn;
206+
type Row<'conn, 'query>
207+
= C::Row<'conn, 'query>
208+
where
209+
Self: 'conn;
204210

205211
fn load<'conn, 'query, T>(
206212
&'conn mut self,

src/lib.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
)]
7575

7676
use diesel::backend::Backend;
77-
use diesel::connection::Instrumentation;
77+
use diesel::connection::{CacheSize, Instrumentation};
7878
use diesel::query_builder::{AsQuery, QueryFragment, QueryId};
7979
use diesel::result::Error;
8080
use diesel::row::Row;
@@ -354,4 +354,7 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send {
354354

355355
/// Set a specific [`Instrumentation`] implementation for this connection
356356
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation);
357+
358+
/// Set the prepared statement cache size to [`CacheSize`] for this connection
359+
fn set_prepared_statement_cache_size(&mut self, size: CacheSize);
357360
}

src/mysql/mod.rs

+62-53
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
use crate::stmt_cache::{PrepareCallback, StmtCache};
1+
use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
22
use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
3-
use diesel::connection::statement_cache::{MaybeCached, StatementCacheKey};
4-
use diesel::connection::Instrumentation;
5-
use diesel::connection::InstrumentationEvent;
3+
use diesel::connection::statement_cache::{
4+
MaybeCached, QueryFragmentForCachedStatement, StatementCache,
5+
};
66
use diesel::connection::StrQueryHelper;
7+
use diesel::connection::{CacheSize, Instrumentation};
8+
use diesel::connection::{DynInstrumentation, InstrumentationEvent};
79
use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
810
use diesel::query_builder::QueryBuilder;
911
use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
@@ -27,9 +29,9 @@ use self::serialize::ToSqlHelper;
2729
/// `mysql://[user[:password]@]host/database_name`
2830
pub struct AsyncMysqlConnection {
2931
conn: mysql_async::Conn,
30-
stmt_cache: StmtCache<Mysql, Statement>,
32+
stmt_cache: StatementCache<Mysql, Statement>,
3133
transaction_manager: AnsiTransactionManager,
32-
instrumentation: std::sync::Mutex<Option<Box<dyn Instrumentation>>>,
34+
instrumentation: DynInstrumentation,
3335
}
3436

3537
#[async_trait::async_trait]
@@ -72,7 +74,7 @@ impl AsyncConnection for AsyncMysqlConnection {
7274
type TransactionManager = AnsiTransactionManager;
7375

7476
async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
75-
let mut instrumentation = diesel::connection::get_default_instrumentation();
77+
let mut instrumentation = DynInstrumentation::default_instrumentation();
7678
instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
7779
database_url,
7880
));
@@ -82,7 +84,7 @@ impl AsyncConnection for AsyncMysqlConnection {
8284
r.as_ref().err(),
8385
));
8486
let mut conn = r?;
85-
conn.instrumentation = std::sync::Mutex::new(instrumentation);
87+
conn.instrumentation = instrumentation;
8688
Ok(conn)
8789
}
8890

@@ -177,16 +179,15 @@ impl AsyncConnection for AsyncMysqlConnection {
177179
}
178180

179181
fn instrumentation(&mut self) -> &mut dyn Instrumentation {
180-
self.instrumentation
181-
.get_mut()
182-
.unwrap_or_else(|p| p.into_inner())
182+
&mut *self.instrumentation
183183
}
184184

185185
fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
186-
*self
187-
.instrumentation
188-
.get_mut()
189-
.unwrap_or_else(|p| p.into_inner()) = Some(Box::new(instrumentation));
186+
self.instrumentation = instrumentation.into();
187+
}
188+
189+
fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
190+
self.stmt_cache.set_cache_size(size);
190191
}
191192
}
192193

@@ -207,17 +208,24 @@ fn update_transaction_manager_status<T>(
207208
query_result
208209
}
209210

210-
#[async_trait::async_trait]
211-
impl PrepareCallback<Statement, MysqlType> for &'_ mut mysql_async::Conn {
212-
async fn prepare(
213-
self,
214-
sql: &str,
215-
_metadata: &[MysqlType],
216-
_is_for_cache: diesel::connection::statement_cache::PrepareForCache,
217-
) -> QueryResult<(Statement, Self)> {
218-
let s = self.prep(sql).await.map_err(ErrorHelper)?;
219-
Ok((s, self))
220-
}
211+
fn prepare_statement_helper<'a, 'b>(
212+
conn: &'a mut mysql_async::Conn,
213+
sql: &'b str,
214+
_is_for_cache: diesel::connection::statement_cache::PrepareForCache,
215+
_metadata: &[MysqlType],
216+
) -> CallbackHelper<impl Future<Output = QueryResult<(Statement, &'a mut mysql_async::Conn)>> + Send>
217+
{
218+
// ideally we wouldn't clone the SQL string here
219+
// but as we usually cache statements anyway
220+
// this is a fixed one time const
221+
//
222+
// The probleme with not cloning it is that we then cannot express
223+
// the right result lifetime anymore (at least not easily)
224+
let sql = sql.to_owned();
225+
CallbackHelper(async move {
226+
let s = conn.prep(sql).await.map_err(ErrorHelper)?;
227+
Ok((s, conn))
228+
})
221229
}
222230

223231
impl AsyncMysqlConnection {
@@ -229,11 +237,9 @@ impl AsyncMysqlConnection {
229237
use crate::run_query_dsl::RunQueryDsl;
230238
let mut conn = AsyncMysqlConnection {
231239
conn,
232-
stmt_cache: StmtCache::new(),
240+
stmt_cache: StatementCache::new(),
233241
transaction_manager: AnsiTransactionManager::default(),
234-
instrumentation: std::sync::Mutex::new(
235-
diesel::connection::get_default_instrumentation(),
236-
),
242+
instrumentation: DynInstrumentation::default_instrumentation(),
237243
};
238244

239245
for stmt in CONNECTION_SETUP_QUERIES {
@@ -286,36 +292,29 @@ impl AsyncMysqlConnection {
286292
} = bind_collector?;
287293
let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
288294
let sql = sql?;
295+
let helper = QueryFragmentHelper {
296+
sql,
297+
safe_to_cache: is_safe_to_cache_prepared,
298+
};
289299
let inner = async {
290-
let cache_key = if let Some(query_id) = query_id {
291-
StatementCacheKey::Type(query_id)
292-
} else {
293-
StatementCacheKey::Sql {
294-
sql: sql.clone(),
295-
bind_types: metadata.clone(),
296-
}
297-
};
298-
299300
let (stmt, conn) = stmt_cache
300-
.cached_prepared_statement(
301-
cache_key,
302-
sql.clone(),
303-
is_safe_to_cache_prepared,
301+
.cached_statement_non_generic(
302+
query_id,
303+
&helper,
304+
&Mysql,
304305
&metadata,
305306
conn,
306-
instrumentation,
307+
prepare_statement_helper,
308+
&mut **instrumentation,
307309
)
308310
.await?;
309311
callback(conn, stmt, ToSqlHelper { metadata, binds }).await
310312
};
311313
let r = update_transaction_manager_status(inner.await, transaction_manager);
312-
instrumentation
313-
.get_mut()
314-
.unwrap_or_else(|p| p.into_inner())
315-
.on_connection_event(InstrumentationEvent::finish_query(
316-
&StrQueryHelper::new(&sql),
317-
r.as_ref().err(),
318-
));
314+
instrumentation.on_connection_event(InstrumentationEvent::finish_query(
315+
&StrQueryHelper::new(&helper.sql),
316+
r.as_ref().err(),
317+
));
319318
r
320319
}
321320
.boxed()
@@ -370,9 +369,9 @@ impl AsyncMysqlConnection {
370369

371370
Ok(AsyncMysqlConnection {
372371
conn,
373-
stmt_cache: StmtCache::new(),
372+
stmt_cache: StatementCache::new(),
374373
transaction_manager: AnsiTransactionManager::default(),
375-
instrumentation: std::sync::Mutex::new(None),
374+
instrumentation: DynInstrumentation::none(),
376375
})
377376
}
378377
}
@@ -427,3 +426,13 @@ mod tests {
427426
}
428427
}
429428
}
429+
430+
impl QueryFragmentForCachedStatement<Mysql> for QueryFragmentHelper {
431+
fn construct_sql(&self, _backend: &Mysql) -> QueryResult<String> {
432+
Ok(self.sql.clone())
433+
}
434+
435+
fn is_safe_to_cache_prepared(&self, _backend: &Mysql) -> QueryResult<bool> {
436+
Ok(self.safe_to_cache)
437+
}
438+
}

src/mysql/row.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@ impl RowSealed for MysqlRow {}
3737

3838
impl<'a> diesel::row::Row<'a, Mysql> for MysqlRow {
3939
type InnerPartialRow = Self;
40-
type Field<'b> = MysqlField<'b> where Self: 'b, 'a: 'b;
40+
type Field<'b>
41+
= MysqlField<'b>
42+
where
43+
Self: 'b,
44+
'a: 'b;
4145

4246
fn field_count(&self) -> usize {
4347
self.0.columns_ref().len()

0 commit comments

Comments
 (0)