Skip to content

Commit f388f5e

Browse files
committed
Implement PostgresConnection.query and .execute with metadata
1 parent 5d817be commit f388f5e

File tree

9 files changed

+569
-25
lines changed

9 files changed

+569
-25
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
.DS_Store
22
/.build
3+
/.index-build
34
/Packages
45
/*.xcodeproj
56
DerivedData

Sources/PostgresNIO/Connection/PostgresConnection.swift

+96
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,56 @@ extension PostgresConnection {
438438
}
439439
}
440440

441+
/// Run a query on the Postgres server the connection is connected to, returning the metadata.
442+
///
443+
/// - Parameters:
444+
/// - query: The ``PostgresQuery`` to run
445+
/// - logger: The `Logger` to log into for the query
446+
/// - file: The file, the query was started in. Used for better error reporting.
447+
/// - line: The line, the query was started in. Used for better error reporting.
448+
/// - consume: The closure to consume the ``PostgresRowSequence``.
449+
/// DO NOT escape the row-sequence out of the closure.
450+
/// - Returns: The result of the `consume` closure as well as the query metadata.
451+
public func query<Result>(
452+
_ query: PostgresQuery,
453+
logger: Logger,
454+
file: String = #fileID,
455+
line: Int = #line,
456+
_ consume: (PostgresRowSequence) async throws -> Result
457+
) async throws -> (Result, PostgresQueryMetadata) {
458+
var logger = logger
459+
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
460+
461+
guard query.binds.count <= Int(UInt16.max) else {
462+
throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
463+
}
464+
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
465+
let context = ExtendedQueryContext(
466+
query: query,
467+
logger: logger,
468+
promise: promise
469+
)
470+
471+
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
472+
473+
do {
474+
let (rowStream, rowSequence) = try await promise.futureResult.map { rowStream in
475+
(rowStream, rowStream.asyncSequence())
476+
}.get()
477+
let result = try await consume(rowSequence)
478+
try await rowStream.drain().get()
479+
guard let metadata = PostgresQueryMetadata(string: rowStream.commandTag) else {
480+
throw PSQLError.invalidCommandTag(rowStream.commandTag)
481+
}
482+
return (result, metadata)
483+
} catch var error as PSQLError {
484+
error.file = file
485+
error.line = line
486+
error.query = query
487+
throw error // rethrow with more metadata
488+
}
489+
}
490+
441491
/// Start listening for a channel
442492
public func listen(_ channel: String) async throws -> PostgresNotificationSequence {
443493
let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed)
@@ -531,6 +581,52 @@ extension PostgresConnection {
531581
}
532582
}
533583

584+
/// Execute a statement on the Postgres server the connection is connected to,
585+
/// returning the metadata.
586+
///
587+
/// - Parameters:
588+
/// - query: The ``PostgresQuery`` to run
589+
/// - logger: The `Logger` to log into for the query
590+
/// - file: The file, the query was started in. Used for better error reporting.
591+
/// - line: The line, the query was started in. Used for better error reporting.
592+
/// - Returns: The query metadata.
593+
@discardableResult
594+
public func execute(
595+
_ query: PostgresQuery,
596+
logger: Logger,
597+
file: String = #fileID,
598+
line: Int = #line
599+
) async throws -> PostgresQueryMetadata {
600+
var logger = logger
601+
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
602+
603+
guard query.binds.count <= Int(UInt16.max) else {
604+
throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
605+
}
606+
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
607+
let context = ExtendedQueryContext(
608+
query: query,
609+
logger: logger,
610+
promise: promise
611+
)
612+
613+
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
614+
615+
do {
616+
let rowStream = try await promise.futureResult.get()
617+
try await rowStream.drain().get()
618+
guard let metadata = PostgresQueryMetadata(string: rowStream.commandTag) else {
619+
throw PSQLError.invalidCommandTag(rowStream.commandTag)
620+
}
621+
return metadata
622+
} catch var error as PSQLError {
623+
error.file = file
624+
error.line = line
625+
error.query = query
626+
throw error // rethrow with more metadata
627+
}
628+
}
629+
534630
#if compiler(>=6.0)
535631
/// Puts the connection into an open transaction state, for the provided `closure`'s lifetime.
536632
///

Sources/PostgresNIO/New/PSQLRowStream.swift

+64-1
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,70 @@ final class PSQLRowStream: @unchecked Sendable {
276276
return self.eventLoop.makeFailedFuture(error)
277277
}
278278
}
279-
279+
280+
// MARK: Drain on EventLoop
281+
282+
func drain() -> EventLoopFuture<Void> {
283+
if self.eventLoop.inEventLoop {
284+
return self.drain0()
285+
} else {
286+
return self.eventLoop.flatSubmit {
287+
self.drain0()
288+
}
289+
}
290+
}
291+
292+
private func drain0() -> EventLoopFuture<Void> {
293+
self.eventLoop.preconditionInEventLoop()
294+
295+
switch self.downstreamState {
296+
case .waitingForConsumer(let bufferState):
297+
switch bufferState {
298+
case .streaming(var buffer, let dataSource):
299+
let promise = self.eventLoop.makePromise(of: [PostgresRow].self)
300+
301+
buffer.removeAll()
302+
self.downstreamState = .waitingForAll([], promise, dataSource)
303+
// immediately request more
304+
dataSource.request(for: self)
305+
306+
return promise.futureResult.map { _ in }
307+
308+
case .finished(_, let summary):
309+
self.downstreamState = .consumed(.success(summary))
310+
return self.eventLoop.makeSucceededVoidFuture()
311+
312+
case .failure(let error):
313+
self.downstreamState = .consumed(.failure(error))
314+
return self.eventLoop.makeFailedFuture(error)
315+
}
316+
case .asyncSequence(let consumer, let dataSource, let onFinish):
317+
consumer.finish()
318+
onFinish()
319+
320+
let promise = self.eventLoop.makePromise(of: [PostgresRow].self)
321+
322+
self.downstreamState = .waitingForAll([], promise, dataSource)
323+
// immediately request more
324+
dataSource.request(for: self)
325+
326+
return promise.futureResult.map { _ in }
327+
case .consumed(.success):
328+
// already drained
329+
return self.eventLoop.makeSucceededVoidFuture()
330+
case .consumed(let .failure(error)):
331+
return self.eventLoop.makeFailedFuture(error)
332+
case .waitingForAll(let rows, let promise, let dataSource):
333+
self.downstreamState = .waitingForAll(rows, promise, dataSource)
334+
// immediately request more
335+
dataSource.request(for: self)
336+
337+
return promise.futureResult.map { _ in }
338+
default:
339+
preconditionFailure("Invalid state: \(self.downstreamState)")
340+
}
341+
}
342+
280343
internal func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) {
281344
self.logger.debug("Notice Received", metadata: [
282345
.notice: "\(notice)"

Sources/PostgresNIO/New/PostgresRowSequence.swift

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ extension PostgresRowSequence {
6060
extension PostgresRowSequence.AsyncIterator: Sendable {}
6161

6262
extension PostgresRowSequence {
63+
/// Collects all rows into an array.
64+
/// - Returns: The rows.
6365
public func collect() async throws -> [PostgresRow] {
6466
var result = [PostgresRow]()
6567
for try await row in self {

Sources/PostgresNIO/Pool/PostgresClient.swift

+55
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,61 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service {
435435
}
436436
}
437437

438+
/// Run a query on the Postgres server the connection is connected to, returning the metadata.
439+
///
440+
/// - Parameters:
441+
/// - query: The ``PostgresQuery`` to run
442+
/// - logger: The `Logger` to log into for the query
443+
/// - file: The file, the query was started in. Used for better error reporting.
444+
/// - line: The line, the query was started in. Used for better error reporting.
445+
/// - consume: The closure to consume the ``PostgresRowSequence``.
446+
/// DO NOT escape the row-sequence out of the closure.
447+
/// - Returns: The result of the `consume` closure as well as the query metadata.
448+
public func query<Result>(
449+
_ query: PostgresQuery,
450+
logger: Logger? = nil,
451+
file: String = #fileID,
452+
line: Int = #line,
453+
_ consume: (PostgresRowSequence) async throws -> Result
454+
) async throws -> (Result, PostgresQueryMetadata) {
455+
let logger = logger ?? Self.loggingDisabled
456+
457+
do {
458+
guard query.binds.count <= Int(UInt16.max) else {
459+
throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line)
460+
}
461+
462+
let connection = try await self.leaseConnection()
463+
464+
var logger = logger
465+
logger[postgresMetadataKey: .connectionID] = "\(connection.id)"
466+
467+
let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self)
468+
let context = ExtendedQueryContext(
469+
query: query,
470+
logger: logger,
471+
promise: promise
472+
)
473+
474+
connection.channel.write(HandlerTask.extendedQuery(context), promise: nil)
475+
476+
let (rowStream, rowSequence) = try await promise.futureResult.map { rowStream in
477+
(rowStream, rowStream.asyncSequence(onFinish: { self.pool.releaseConnection(connection) }))
478+
}.get()
479+
let result = try await consume(rowSequence)
480+
try await rowStream.drain().get()
481+
guard let metadata = PostgresQueryMetadata(string: rowStream.commandTag) else {
482+
throw PSQLError.invalidCommandTag(rowStream.commandTag)
483+
}
484+
return (result, metadata)
485+
} catch var error as PSQLError {
486+
error.file = file
487+
error.line = line
488+
error.query = query
489+
throw error // rethrow with more metadata
490+
}
491+
}
492+
438493
/// Execute a prepared statement, taking care of the preparation when necessary
439494
public func execute<Statement: PostgresPreparedStatement, Row>(
440495
_ preparedStatement: Statement,

Tests/IntegrationTests/AsyncTests.swift

+95-8
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,98 @@ final class AsyncPostgresConnectionTests: XCTestCase {
4646
}
4747
}
4848

49+
func testSelect10kRowsWithMetadata() async throws {
50+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
51+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
52+
let eventLoop = eventLoopGroup.next()
53+
54+
let start = 1
55+
let end = 10000
56+
57+
try await withTestConnection(on: eventLoop) { connection in
58+
let (result, metadata) = try await connection.query(
59+
"SELECT generate_series(\(start), \(end));",
60+
logger: .psqlTest
61+
) { rows in
62+
var counter = 0
63+
for try await row in rows {
64+
let element = try row.decode(Int.self)
65+
XCTAssertEqual(element, counter + 1)
66+
counter += 1
67+
}
68+
return counter
69+
}
70+
71+
XCTAssertEqual(metadata.command, "SELECT")
72+
XCTAssertEqual(metadata.oid, nil)
73+
XCTAssertEqual(metadata.rows, end)
74+
75+
XCTAssertEqual(result, end)
76+
}
77+
}
78+
79+
func testSelectRowsWithMetadataNotConsumedAtAll() async throws {
80+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
81+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
82+
let eventLoop = eventLoopGroup.next()
83+
84+
let start = 1
85+
let end = 10000
86+
87+
try await withTestConnection(on: eventLoop) { connection in
88+
let (_, metadata) = try await connection.query(
89+
"SELECT generate_series(\(start), \(end));",
90+
logger: .psqlTest
91+
) { _ in }
92+
93+
XCTAssertEqual(metadata.command, "SELECT")
94+
XCTAssertEqual(metadata.oid, nil)
95+
XCTAssertEqual(metadata.rows, end)
96+
}
97+
}
98+
99+
func testSelectRowsWithMetadataNotFullyConsumed() async throws {
100+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
101+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
102+
let eventLoop = eventLoopGroup.next()
103+
104+
try await withTestConnection(on: eventLoop) { connection in
105+
do {
106+
_ = try await connection.query(
107+
"SELECT generate_series(1, 10000);",
108+
logger: .psqlTest
109+
) { rows in
110+
for try await _ in rows { break }
111+
}
112+
// This path is also fine
113+
} catch is CancellationError {
114+
// Expected
115+
} catch {
116+
XCTFail("Expected 'CancellationError', got: \(String(reflecting: error))")
117+
}
118+
}
119+
}
120+
121+
func testExecuteRowsWithMetadata() async throws {
122+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
123+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
124+
let eventLoop = eventLoopGroup.next()
125+
126+
let start = 1
127+
let end = 10000
128+
129+
try await withTestConnection(on: eventLoop) { connection in
130+
let metadata = try await connection.execute(
131+
"SELECT generate_series(\(start), \(end));",
132+
logger: .psqlTest
133+
)
134+
135+
XCTAssertEqual(metadata.command, "SELECT")
136+
XCTAssertEqual(metadata.oid, nil)
137+
XCTAssertEqual(metadata.rows, end)
138+
}
139+
}
140+
49141
func testSelectActiveConnection() async throws {
50142
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
51143
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
@@ -207,7 +299,7 @@ final class AsyncPostgresConnectionTests: XCTestCase {
207299

208300
try await withTestConnection(on: eventLoop) { connection in
209301
// Max binds limit is UInt16.max which is 65535 which is 3 * 5 * 17 * 257
210-
// Max columns limit is 1664, so we will only make 5 * 257 columns which is less
302+
// Max columns limit appears to be ~1600, so we will only make 5 * 257 columns which is less
211303
// Then we will insert 3 * 17 rows
212304
// In the insertion, there will be a total of 3 * 17 * 5 * 257 == UInt16.max bindings
213305
// If the test is successful, it means Postgres supports UInt16.max bindings
@@ -241,13 +333,8 @@ final class AsyncPostgresConnectionTests: XCTestCase {
241333
unsafeSQL: "INSERT INTO table1 VALUES \(insertionValues)",
242334
binds: binds
243335
)
244-
try await connection.query(insertionQuery, logger: .psqlTest)
245-
246-
let countQuery = PostgresQuery(unsafeSQL: "SELECT COUNT(*) FROM table1")
247-
let countRows = try await connection.query(countQuery, logger: .psqlTest)
248-
var countIterator = countRows.makeAsyncIterator()
249-
let insertedRowsCount = try await countIterator.next()?.decode(Int.self, context: .default)
250-
XCTAssertEqual(rowsCount, insertedRowsCount)
336+
let metadata = try await connection.execute(insertionQuery, logger: .psqlTest)
337+
XCTAssertEqual(metadata.rows, rowsCount)
251338

252339
let dropQuery = PostgresQuery(unsafeSQL: "DROP TABLE table1")
253340
try await connection.query(dropQuery, logger: .psqlTest)

0 commit comments

Comments
 (0)