Skip to content

tokio-postgres: add execute_prepared function #991

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tokio-postgres/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Change Log

## Unreleased

### Added

* Added `execute_prepared` functions.

## v0.7.7 - 2022-08-21

## Added
Expand Down
22 changes: 21 additions & 1 deletion tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::connection::{Request, RequestMessages};
use crate::copy_out::CopyOutStream;
#[cfg(feature = "runtime")]
use crate::keepalive::KeepaliveConfig;
use crate::query::RowStream;
use crate::query::{Execute, RowStream};
use crate::simple_query::SimpleQueryStream;
#[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect;
Expand Down Expand Up @@ -422,6 +422,26 @@ impl Client {
query::execute(self.inner(), statement, params).await
}

/// A version of [`execute_raw`] that does not borrow its arguments.
///
/// This function is identical to [`execute_raw`] except that:
///
/// 1. The returned future does not borrow the parameters or `self`.
/// 2. The type of the returned future does not depend on the parameters.
/// 3. If multiple such futures are being used concurrently, then they are executed on the server in the order
/// in which this function was called, regardless of the order in which the futures are polled.
/// 4. This function can only be used with prepared statements.
///
/// [`execute_raw`]: #method.execute_raw
pub fn execute_prepared<P, I>(&self, statement: &Statement, params: I) -> Execute
where
P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator,
{
query::execute(self.inner(), statement.clone(), params)
}

/// Executes a `COPY FROM STDIN` statement, returning a sink used to write the copy data.
///
/// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any. The copy *must*
Expand Down
27 changes: 26 additions & 1 deletion tokio-postgres/src/generic_client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::query::RowStream;
use crate::query::{Execute, RowStream};
use crate::types::{BorrowToSql, ToSql, Type};
use crate::{Client, Error, Row, Statement, ToStatement, Transaction};
use async_trait::async_trait;
Expand All @@ -25,6 +25,13 @@ pub trait GenericClient: private::Sealed {
I: IntoIterator<Item = P> + Sync + Send,
I::IntoIter: ExactSizeIterator;

/// Like `Client::execute_prepared`.
fn execute_prepared<P, I>(&self, statement: &Statement, params: I) -> Execute
where
P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator;

/// Like `Client::query`.
async fn query<T>(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>, Error>
where
Expand Down Expand Up @@ -97,6 +104,15 @@ impl GenericClient for Client {
self.execute_raw(statement, params).await
}

fn execute_prepared<P, I>(&self, statement: &Statement, params: I) -> Execute
where
P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator,
{
self.execute_prepared(statement, params)
}

async fn query<T>(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement + Sync + Send,
Expand Down Expand Up @@ -183,6 +199,15 @@ impl GenericClient for Transaction<'_> {
self.execute_raw(statement, params).await
}

fn execute_prepared<P, I>(&self, statement: &Statement, params: I) -> Execute
where
P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator,
{
self.execute_prepared(statement, params)
}

async fn query<T>(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement + Sync + Send,
Expand Down
1 change: 1 addition & 0 deletions tokio-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ use crate::error::DbError;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::portal::Portal;
pub use crate::query::Execute;
pub use crate::query::RowStream;
pub use crate::row::{Row, SimpleQueryRow};
pub use crate::simple_query::SimpleQueryStream;
Expand Down
94 changes: 74 additions & 20 deletions tokio-postgres/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use pin_project_lite::pin_project;
use postgres_protocol::message::backend::{CommandCompleteBody, Message};
use postgres_protocol::message::frontend;
use std::fmt;
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -91,11 +92,68 @@ pub fn extract_row_affected(body: &CommandCompleteBody) -> Result<u64, Error> {
Ok(rows)
}

pub async fn execute<P, I>(
client: &InnerClient,
statement: Statement,
params: I,
) -> Result<u64, Error>
/// A future that completes with the result of an `execute_prepared` function call.
// Once https://github.com/rust-lang/rust/issues/63063 is stable, we might want to replace this by
// type Execute = impl Future<Output=Result<u64, Error>> + Sync + Send + Unpin + 'static
// and restore the simpler procedural logic instead of writing the state machine ourselves.
pub struct Execute {
responses: Result<Responses, Option<Error>>,
bound: bool,
rows: u64,
}

impl Future for Execute {
type Output = Result<u64, Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
use Poll::*;
let slf = self.get_mut();
let responses = match &mut slf.responses {
Ok(r) => r,
Err(e) => match e.take() {
Some(e) => return Ready(Err(e)),
_ => panic!("Execute future polled after it has already completed"),
},
};
loop {
let message = match responses.poll_next(cx) {
Ready(Ok(msg)) => msg,
Ready(Err(e)) => {
slf.responses = Err(None);
return Ready(Err(e));
}
Pending => return Pending,
};
if !slf.bound {
match message {
Message::BindComplete => slf.bound = true,
_ => {
slf.responses = Err(None);
return Ready(Err(Error::unexpected_message()));
}
}
} else {
match message {
Message::DataRow(_) => {}
Message::CommandComplete(body) => {
slf.rows = extract_row_affected(&body)?;
}
Message::EmptyQueryResponse => slf.rows = 0,
Message::ReadyForQuery(_) => {
slf.responses = Err(None);
return Ready(Ok(slf.rows));
}
_ => {
slf.responses = Err(None);
return Ready(Err(Error::unexpected_message()));
}
}
}
}
}
}

pub fn execute<P, I>(client: &InnerClient, statement: Statement, params: I) -> Execute
where
P: BorrowToSql,
I: IntoIterator<Item = P>,
Expand All @@ -108,23 +166,19 @@ where
statement.name(),
BorrowToSqlParamsDebug(params.as_slice()),
);
encode(client, &statement, params)?
encode(client, &statement, params)
} else {
encode(client, &statement, params)?
encode(client, &statement, params)
};
let mut responses = start(client, buf).await?;

let mut rows = 0;
loop {
match responses.next().await? {
Message::DataRow(_) => {}
Message::CommandComplete(body) => {
rows = extract_row_affected(&body)?;
}
Message::EmptyQueryResponse => rows = 0,
Message::ReadyForQuery(_) => return Ok(rows),
_ => return Err(Error::unexpected_message()),
}

let responses = buf
.and_then(|buf| client.send(RequestMessages::Single(FrontendMessage::Raw(buf))))
.map_err(Some);

Execute {
responses,
bound: false,
rows: 0,
}
}

Expand Down
12 changes: 11 additions & 1 deletion tokio-postgres/src/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::copy_out::CopyOutStream;
use crate::query::RowStream;
use crate::query::{Execute, RowStream};
#[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect;
Expand Down Expand Up @@ -172,6 +172,16 @@ impl<'a> Transaction<'a> {
self.client.execute_raw(statement, params).await
}

/// Like `Client::execute_prepared`.
pub fn execute_prepared<P, I>(&self, statement: &Statement, params: I) -> Execute
where
P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator,
{
self.client.execute_prepared(statement, params)
}

/// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried.
///
/// Portals only last for the duration of the transaction in which they are created, and can only be used on the
Expand Down
89 changes: 89 additions & 0 deletions tokio-postgres/tests/test/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use futures_util::{
future, join, pin_mut, stream, try_join, Future, FutureExt, SinkExt, StreamExt, TryStreamExt,
};
use pin_project_lite::pin_project;
use postgres_types::ToSql;
use std::fmt::Write;
use std::future::poll_fn;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
Expand Down Expand Up @@ -951,3 +953,90 @@ async fn deferred_constraint() {
.await
.unwrap_err();
}

#[tokio::test]
async fn execute_prepared() {
let client = connect("user=postgres").await;

client
.batch_execute(
"
CREATE TEMPORARY TABLE foo (
id INT GENERATED ALWAYS AS IDENTITY,
a INT,
b INT
);
",
)
.await
.unwrap();

let statement1 = client
.prepare("INSERT INTO foo (a, b) VALUES ($1, $2)")
.await
.unwrap();
let statement2 = client
.prepare("INSERT INTO foo (a) VALUES ($1)")
.await
.unwrap();

let future1 = client.execute_prepared(&statement1, [&10 as &dyn ToSql, &11]);
let future2 = client.execute_prepared(&statement2, [&12 as &dyn ToSql]);

fn same_type<T>(_: &T, _: &T) {}
same_type(&future1, &future2);

future2.await.unwrap();
future1.await.unwrap();

let mut rows: Vec<(i32, i32, Option<i32>)> = client
.query("SELECT * FROM foo", &[])
.await
.unwrap()
.into_iter()
.map(|row| (row.get("id"), row.get("a"), row.get("b")))
.collect();

rows.sort_by_key(|row| row.0);

assert_eq!(rows, vec![(1, 10, Some(11)), (2, 12, None)]);
}

#[tokio::test]
async fn execute_error() {
let client = connect("user=postgres").await;

client
.batch_execute(
"
CREATE TEMPORARY TABLE foo (
id INT GENERATED ALWAYS AS IDENTITY,
a INT
);
",
)
.await
.unwrap();

let statement = client
.prepare("INSERT INTO foo (a) VALUES ($1)")
.await
.unwrap();

let future = client.execute_prepared(&statement, [&"" as &dyn ToSql]);

future.await.unwrap_err();
}

#[tokio::test]
#[should_panic]
async fn execute_poll_after_completion() {
let client = connect("user=postgres").await;

let statement = client.prepare("SELECT 1").await.unwrap();

let mut future = client.execute_prepared::<&dyn ToSql, _>(&statement, []);

poll_fn(|cx| Pin::new(&mut future).poll(cx)).await.unwrap();
let _ = future.await;
}