diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md index 91e78b780..61f3e7117 100644 --- a/tokio-postgres/CHANGELOG.md +++ b/tokio-postgres/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## Unreleased + +### Added + +* Added `execute_prepared` functions. + ## v0.7.7 - 2022-08-21 ## Added diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index ad5aa2866..e8a5a2a26 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -6,7 +6,7 @@ use crate::connection::{Request, RequestMessages}; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; -use crate::query::RowStream; +use crate::query::{Execute, RowStream}; use crate::simple_query::SimpleQueryStream; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; @@ -422,6 +422,26 @@ impl Client { query::execute(self.inner(), statement, params).await } + /// A version of [`execute_raw`] that does not borrow its arguments. + /// + /// This function is identical to [`execute_raw`] except that: + /// + /// 1. The returned future does not borrow the parameters or `self`. + /// 2. The type of the returned future does not depend on the parameters. + /// 3. If multiple such futures are being used concurrently, then they are executed on the server in the order + /// in which this function was called, regardless of the order in which the futures are polled. + /// 4. This function can only be used with prepared statements. + /// + /// [`execute_raw`]: #method.execute_raw + pub fn execute_prepared(&self, statement: &Statement, params: I) -> Execute + where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + query::execute(self.inner(), statement.clone(), params) + } + /// Executes a `COPY FROM STDIN` statement, returning a sink used to write the copy data. /// /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any. The copy *must* diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index 50cff9712..d1040d42e 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -1,4 +1,4 @@ -use crate::query::RowStream; +use crate::query::{Execute, RowStream}; use crate::types::{BorrowToSql, ToSql, Type}; use crate::{Client, Error, Row, Statement, ToStatement, Transaction}; use async_trait::async_trait; @@ -25,6 +25,13 @@ pub trait GenericClient: private::Sealed { I: IntoIterator + Sync + Send, I::IntoIter: ExactSizeIterator; + /// Like `Client::execute_prepared`. + fn execute_prepared(&self, statement: &Statement, params: I) -> Execute + where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator; + /// Like `Client::query`. async fn query(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> where @@ -97,6 +104,15 @@ impl GenericClient for Client { self.execute_raw(statement, params).await } + fn execute_prepared(&self, statement: &Statement, params: I) -> Execute + where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + self.execute_prepared(statement, params) + } + async fn query(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> where T: ?Sized + ToStatement + Sync + Send, @@ -183,6 +199,15 @@ impl GenericClient for Transaction<'_> { self.execute_raw(statement, params).await } + fn execute_prepared(&self, statement: &Statement, params: I) -> Execute + where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + self.execute_prepared(statement, params) + } + async fn query(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> where T: ?Sized + ToStatement + Sync + Send, diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index a9ecba4f1..fa2c9943d 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -129,6 +129,7 @@ use crate::error::DbError; pub use crate::error::Error; pub use crate::generic_client::GenericClient; pub use crate::portal::Portal; +pub use crate::query::Execute; pub use crate::query::RowStream; pub use crate::row::{Row, SimpleQueryRow}; pub use crate::simple_query::SimpleQueryStream; diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 12176353b..f468fc660 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -10,6 +10,7 @@ use pin_project_lite::pin_project; use postgres_protocol::message::backend::{CommandCompleteBody, Message}; use postgres_protocol::message::frontend; use std::fmt; +use std::future::Future; use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; @@ -91,11 +92,68 @@ pub fn extract_row_affected(body: &CommandCompleteBody) -> Result { Ok(rows) } -pub async fn execute( - client: &InnerClient, - statement: Statement, - params: I, -) -> Result +/// A future that completes with the result of an `execute_prepared` function call. +// Once https://github.com/rust-lang/rust/issues/63063 is stable, we might want to replace this by +// type Execute = impl Future> + Sync + Send + Unpin + 'static +// and restore the simpler procedural logic instead of writing the state machine ourselves. +pub struct Execute { + responses: Result>, + bound: bool, + rows: u64, +} + +impl Future for Execute { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + use Poll::*; + let slf = self.get_mut(); + let responses = match &mut slf.responses { + Ok(r) => r, + Err(e) => match e.take() { + Some(e) => return Ready(Err(e)), + _ => panic!("Execute future polled after it has already completed"), + }, + }; + loop { + let message = match responses.poll_next(cx) { + Ready(Ok(msg)) => msg, + Ready(Err(e)) => { + slf.responses = Err(None); + return Ready(Err(e)); + } + Pending => return Pending, + }; + if !slf.bound { + match message { + Message::BindComplete => slf.bound = true, + _ => { + slf.responses = Err(None); + return Ready(Err(Error::unexpected_message())); + } + } + } else { + match message { + Message::DataRow(_) => {} + Message::CommandComplete(body) => { + slf.rows = extract_row_affected(&body)?; + } + Message::EmptyQueryResponse => slf.rows = 0, + Message::ReadyForQuery(_) => { + slf.responses = Err(None); + return Ready(Ok(slf.rows)); + } + _ => { + slf.responses = Err(None); + return Ready(Err(Error::unexpected_message())); + } + } + } + } + } +} + +pub fn execute(client: &InnerClient, statement: Statement, params: I) -> Execute where P: BorrowToSql, I: IntoIterator, @@ -108,23 +166,19 @@ where statement.name(), BorrowToSqlParamsDebug(params.as_slice()), ); - encode(client, &statement, params)? + encode(client, &statement, params) } else { - encode(client, &statement, params)? + encode(client, &statement, params) }; - let mut responses = start(client, buf).await?; - - let mut rows = 0; - loop { - match responses.next().await? { - Message::DataRow(_) => {} - Message::CommandComplete(body) => { - rows = extract_row_affected(&body)?; - } - Message::EmptyQueryResponse => rows = 0, - Message::ReadyForQuery(_) => return Ok(rows), - _ => return Err(Error::unexpected_message()), - } + + let responses = buf + .and_then(|buf| client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))) + .map_err(Some); + + Execute { + responses, + bound: false, + rows: 0, } } diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 96a324652..54628216d 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -1,7 +1,7 @@ use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::copy_out::CopyOutStream; -use crate::query::RowStream; +use crate::query::{Execute, RowStream}; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; @@ -172,6 +172,16 @@ impl<'a> Transaction<'a> { self.client.execute_raw(statement, params).await } + /// Like `Client::execute_prepared`. + pub fn execute_prepared(&self, statement: &Statement, params: I) -> Execute + where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + self.client.execute_prepared(statement, params) + } + /// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried. /// /// Portals only last for the duration of the transaction in which they are created, and can only be used on the diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0ab4a7bab..4d69d1ad2 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -6,7 +6,9 @@ use futures_util::{ future, join, pin_mut, stream, try_join, Future, FutureExt, SinkExt, StreamExt, TryStreamExt, }; use pin_project_lite::pin_project; +use postgres_types::ToSql; use std::fmt::Write; +use std::future::poll_fn; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; @@ -951,3 +953,90 @@ async fn deferred_constraint() { .await .unwrap_err(); } + +#[tokio::test] +async fn execute_prepared() { + let client = connect("user=postgres").await; + + client + .batch_execute( + " + CREATE TEMPORARY TABLE foo ( + id INT GENERATED ALWAYS AS IDENTITY, + a INT, + b INT + ); + ", + ) + .await + .unwrap(); + + let statement1 = client + .prepare("INSERT INTO foo (a, b) VALUES ($1, $2)") + .await + .unwrap(); + let statement2 = client + .prepare("INSERT INTO foo (a) VALUES ($1)") + .await + .unwrap(); + + let future1 = client.execute_prepared(&statement1, [&10 as &dyn ToSql, &11]); + let future2 = client.execute_prepared(&statement2, [&12 as &dyn ToSql]); + + fn same_type(_: &T, _: &T) {} + same_type(&future1, &future2); + + future2.await.unwrap(); + future1.await.unwrap(); + + let mut rows: Vec<(i32, i32, Option)> = client + .query("SELECT * FROM foo", &[]) + .await + .unwrap() + .into_iter() + .map(|row| (row.get("id"), row.get("a"), row.get("b"))) + .collect(); + + rows.sort_by_key(|row| row.0); + + assert_eq!(rows, vec![(1, 10, Some(11)), (2, 12, None)]); +} + +#[tokio::test] +async fn execute_error() { + let client = connect("user=postgres").await; + + client + .batch_execute( + " + CREATE TEMPORARY TABLE foo ( + id INT GENERATED ALWAYS AS IDENTITY, + a INT + ); + ", + ) + .await + .unwrap(); + + let statement = client + .prepare("INSERT INTO foo (a) VALUES ($1)") + .await + .unwrap(); + + let future = client.execute_prepared(&statement, [&"" as &dyn ToSql]); + + future.await.unwrap_err(); +} + +#[tokio::test] +#[should_panic] +async fn execute_poll_after_completion() { + let client = connect("user=postgres").await; + + let statement = client.prepare("SELECT 1").await.unwrap(); + + let mut future = client.execute_prepared::<&dyn ToSql, _>(&statement, []); + + poll_fn(|cx| Pin::new(&mut future).poll(cx)).await.unwrap(); + let _ = future.await; +}