Skip to content

Commit 4d01e30

Browse files
committed
Implement SimpleQuery + Tests
1 parent 9f84290 commit 4d01e30

File tree

9 files changed

+316
-31
lines changed

9 files changed

+316
-31
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

+41
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,47 @@ extension PostgresConnection {
438438
}
439439
}
440440

441+
/// Run a simple text-only query on the Postgres server the connection is connected to.
442+
/// WARNING: This functions is not yet API and is incomplete.
443+
/// The return type will change to another stream.
444+
///
445+
/// - Parameters:
446+
/// - query: The simple query to run
447+
/// - logger: The `Logger` to log into for the query
448+
/// - file: The file, the query was started in. Used for better error reporting.
449+
/// - line: The line, the query was started in. Used for better error reporting.
450+
/// - Returns: A ``PostgresRowSequence`` containing the rows the server sent as the query result.
451+
/// The sequence be discarded.
452+
@discardableResult
453+
public func __simpleQuery(
454+
_ query: String,
455+
logger: Logger,
456+
file: String = #fileID,
457+
line: Int = #line
458+
) async throws -> PostgresRowSequence {
459+
var logger = logger
460+
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
461+
462+
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
463+
let context = ExtendedQueryContext(
464+
simpleQuery: query,
465+
logger: logger,
466+
promise: promise
467+
)
468+
469+
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
470+
471+
do {
472+
return try await promise.futureResult.map({ $0.asyncSequence() }).get()
473+
} catch var error as PSQLError {
474+
error.file = file
475+
error.line = line
476+
// FIXME: just pass the string as a simple query, instead of acting like this is a PostgresQuery.
477+
error.query = PostgresQuery(unsafeSQL: query)
478+
throw error // rethrow with more metadata
479+
}
480+
}
481+
441482
/// Start listening for a channel
442483
public func listen(_ channel: String) async throws -> PostgresNotificationSequence {
443484
let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed)

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

+7-3
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ struct ConnectionStateMachine {
8787
// --- general actions
8888
case sendParseDescribeBindExecuteSync(PostgresQuery)
8989
case sendBindExecuteSync(PSQLExecuteStatement)
90+
case sendQuery(String)
9091
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
9192
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
9293

@@ -537,7 +538,7 @@ struct ConnectionStateMachine {
537538

538539
self.state = .readyForQuery(connectionContext)
539540
return self.executeNextQueryFromQueue()
540-
541+
541542
default:
542543
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState)))
543544
}
@@ -585,7 +586,7 @@ struct ConnectionStateMachine {
585586
switch task {
586587
case .extendedQuery(let queryContext):
587588
switch queryContext.query {
588-
case .executeStatement(_, let promise), .unnamed(_, let promise):
589+
case .executeStatement(_, let promise), .unnamed(_, let promise), .simpleQuery(_, let promise):
589590
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
590591
case .prepareStatement(_, _, _, let promise):
591592
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
@@ -745,7 +746,7 @@ struct ConnectionStateMachine {
745746
guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else {
746747
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.commandComplete(commandTag)))
747748
}
748-
749+
749750
self.state = .modifying // avoid CoW
750751
let action = queryState.commandCompletedReceived(commandTag)
751752
self.state = .extendedQuery(queryState, connectionContext)
@@ -855,6 +856,7 @@ struct ConnectionStateMachine {
855856
case .sendParseDescribeBindExecuteSync,
856857
.sendParseDescribeSync,
857858
.sendBindExecuteSync,
859+
.sendQuery,
858860
.succeedQuery,
859861
.succeedPreparedStatementCreation,
860862
.forwardRows,
@@ -1035,6 +1037,8 @@ extension ConnectionStateMachine {
10351037
return .sendParseDescribeBindExecuteSync(query)
10361038
case .sendBindExecuteSync(let executeStatement):
10371039
return .sendBindExecuteSync(executeStatement)
1040+
case .sendQuery(let query):
1041+
return .sendQuery(query)
10381042
case .failQuery(let requestContext, with: let error):
10391043
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
10401044
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)

Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

+99-25
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ struct ExtendedQueryStateMachine {
2929
case sendParseDescribeBindExecuteSync(PostgresQuery)
3030
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
3131
case sendBindExecuteSync(PSQLExecuteStatement)
32-
32+
case sendQuery(String)
33+
3334
// --- general actions
3435
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError)
3536
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
@@ -85,6 +86,12 @@ struct ExtendedQueryStateMachine {
8586
state = .messagesSent(queryContext)
8687
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
8788
}
89+
90+
case .simpleQuery(let query, _):
91+
return self.avoidingStateMachineCoW { state -> Action in
92+
state = .messagesSent(queryContext)
93+
return .sendQuery(query)
94+
}
8895
}
8996
}
9097

@@ -105,7 +112,7 @@ struct ExtendedQueryStateMachine {
105112

106113
self.isCancelled = true
107114
switch queryContext.query {
108-
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
115+
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise), .simpleQuery(_, let eventLoopPromise):
109116
return .failQuery(eventLoopPromise, with: .queryCancelled)
110117

111118
case .prepareStatement(_, _, _, let eventLoopPromise):
@@ -171,11 +178,19 @@ struct ExtendedQueryStateMachine {
171178
state = .noDataMessageReceived(queryContext)
172179
return .succeedPreparedStatementCreation(promise, with: nil)
173180
}
181+
182+
case .simpleQuery:
183+
return self.setAndFireError(.unexpectedBackendMessage(.noData))
174184
}
175185
}
176186

177187
mutating func rowDescriptionReceived(_ rowDescription: RowDescription) -> Action {
178-
guard case .parameterDescriptionReceived(let queryContext) = self.state else {
188+
let queryContext: ExtendedQueryContext
189+
switch self.state {
190+
case .messagesSent(let extendedQueryContext),
191+
.parameterDescriptionReceived(let extendedQueryContext):
192+
queryContext = extendedQueryContext
193+
default:
179194
return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription)))
180195
}
181196

@@ -198,7 +213,7 @@ struct ExtendedQueryStateMachine {
198213
}
199214

200215
switch queryContext.query {
201-
case .unnamed, .executeStatement:
216+
case .unnamed, .executeStatement, .simpleQuery:
202217
return .wait
203218

204219
case .prepareStatement(_, _, _, let eventLoopPromise):
@@ -219,6 +234,9 @@ struct ExtendedQueryStateMachine {
219234

220235
case .prepareStatement:
221236
return .evaluateErrorAtConnectionLevel(.unexpectedBackendMessage(.bindComplete))
237+
238+
case .simpleQuery:
239+
return self.setAndFireError(.unexpectedBackendMessage(.bindComplete))
222240
}
223241

224242
case .noDataMessageReceived(let queryContext):
@@ -258,20 +276,40 @@ struct ExtendedQueryStateMachine {
258276
return .wait
259277
}
260278

279+
case .rowDescriptionReceived(let queryContext, let columns):
280+
switch queryContext.query {
281+
case .simpleQuery(_, let eventLoopPromise):
282+
// When receiving a data row, we must ensure that the data row column count
283+
// matches the previously received row description column count.
284+
guard dataRow.columnCount == columns.count else {
285+
return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow)))
286+
}
287+
288+
return self.avoidingStateMachineCoW { state -> Action in
289+
var demandStateMachine = RowStreamStateMachine()
290+
demandStateMachine.receivedRow(dataRow)
291+
state = .streaming(columns, demandStateMachine)
292+
let result = QueryResult(value: .rowDescription(columns), logger: queryContext.logger)
293+
return .succeedQuery(eventLoopPromise, with: result)
294+
}
295+
296+
case .unnamed, .executeStatement, .prepareStatement:
297+
return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow)))
298+
}
299+
261300
case .drain(let columns):
262301
guard dataRow.columnCount == columns.count else {
263302
return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow)))
264303
}
265304
// we ignore all rows and wait for readyForQuery
266305
return .wait
267-
306+
268307
case .initialized,
269308
.messagesSent,
270309
.parseCompleteReceived,
271310
.parameterDescriptionReceived,
272311
.noDataMessageReceived,
273312
.emptyQueryResponseReceived,
274-
.rowDescriptionReceived,
275313
.bindCompleteReceived,
276314
.commandComplete,
277315
.error:
@@ -292,10 +330,36 @@ struct ExtendedQueryStateMachine {
292330
return .succeedQuery(eventLoopPromise, with: result)
293331
}
294332

295-
case .prepareStatement:
333+
case .prepareStatement, .simpleQuery:
296334
preconditionFailure("Invalid state: \(self.state)")
297335
}
298-
336+
337+
case .messagesSent(let context):
338+
switch context.query {
339+
case .simpleQuery(_, let eventLoopGroup):
340+
return self.avoidingStateMachineCoW { state -> Action in
341+
state = .commandComplete(commandTag: commandTag)
342+
let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger)
343+
return .succeedQuery(eventLoopGroup, with: result)
344+
}
345+
346+
case .unnamed, .executeStatement, .prepareStatement:
347+
return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag)))
348+
}
349+
350+
case .rowDescriptionReceived(let context, _):
351+
switch context.query {
352+
case .simpleQuery(_, let eventLoopPromise):
353+
return self.avoidingStateMachineCoW { state -> Action in
354+
state = .commandComplete(commandTag: commandTag)
355+
let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger)
356+
return .succeedQuery(eventLoopPromise, with: result)
357+
}
358+
359+
case .unnamed, .executeStatement, .prepareStatement:
360+
return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag)))
361+
}
362+
299363
case .streaming(_, var demandStateMachine):
300364
return self.avoidingStateMachineCoW { state -> Action in
301365
state = .commandComplete(commandTag: commandTag)
@@ -306,14 +370,12 @@ struct ExtendedQueryStateMachine {
306370
precondition(self.isCancelled)
307371
self.state = .commandComplete(commandTag: commandTag)
308372
return .wait
309-
373+
310374
case .initialized,
311-
.messagesSent,
312375
.parseCompleteReceived,
313376
.parameterDescriptionReceived,
314377
.noDataMessageReceived,
315378
.emptyQueryResponseReceived,
316-
.rowDescriptionReceived,
317379
.commandComplete,
318380
.error:
319381
return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag)))
@@ -323,20 +385,32 @@ struct ExtendedQueryStateMachine {
323385
}
324386

325387
mutating func emptyQueryResponseReceived() -> Action {
326-
guard case .bindCompleteReceived(let queryContext) = self.state else {
327-
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
328-
}
388+
switch self.state {
389+
case .bindCompleteReceived(let queryContext):
390+
switch queryContext.query {
391+
case .unnamed(_, let eventLoopPromise),
392+
.executeStatement(_, let eventLoopPromise):
393+
return self.avoidingStateMachineCoW { state -> Action in
394+
state = .emptyQueryResponseReceived
395+
let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger)
396+
return .succeedQuery(eventLoopPromise, with: result)
397+
}
329398

330-
switch queryContext.query {
331-
case .unnamed(_, let eventLoopPromise),
332-
.executeStatement(_, let eventLoopPromise):
333-
return self.avoidingStateMachineCoW { state -> Action in
334-
state = .emptyQueryResponseReceived
335-
let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger)
336-
return .succeedQuery(eventLoopPromise, with: result)
399+
case .prepareStatement, .simpleQuery:
400+
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
337401
}
338-
339-
case .prepareStatement(_, _, _, _):
402+
case .messagesSent(let queryContext):
403+
switch queryContext.query {
404+
case .simpleQuery(_, let eventLoopPromise):
405+
return self.avoidingStateMachineCoW { state -> Action in
406+
state = .emptyQueryResponseReceived
407+
let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger)
408+
return .succeedQuery(eventLoopPromise, with: result)
409+
}
410+
case .unnamed, .executeStatement, .prepareStatement:
411+
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
412+
}
413+
default:
340414
return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse))
341415
}
342416
}
@@ -497,7 +571,7 @@ struct ExtendedQueryStateMachine {
497571
return .evaluateErrorAtConnectionLevel(error)
498572
} else {
499573
switch context.query {
500-
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
574+
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise), .simpleQuery(_, let eventLoopPromise):
501575
return .failQuery(eventLoopPromise, with: error)
502576
case .prepareStatement(_, _, _, let eventLoopPromise):
503577
return .failPreparedStatementCreation(eventLoopPromise, with: error)
@@ -536,7 +610,7 @@ struct ExtendedQueryStateMachine {
536610
switch context.query {
537611
case .prepareStatement:
538612
return true
539-
case .unnamed, .executeStatement:
613+
case .unnamed, .executeStatement, .simpleQuery:
540614
return false
541615
}
542616

Sources/PostgresNIO/New/PSQLTask.swift

+14-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ enum PSQLTask {
2323
eventLoopPromise.fail(error)
2424
case .prepareStatement(_, _, _, let eventLoopPromise):
2525
eventLoopPromise.fail(error)
26+
case .simpleQuery(_, let eventLoopPromise):
27+
eventLoopPromise.fail(error)
2628
}
2729

2830
case .closeCommand(let closeCommandContext):
@@ -31,16 +33,18 @@ enum PSQLTask {
3133
}
3234
}
3335

36+
// FIXME: Either rename all these `ExtendedQuery`s to just like `Query` or pull out `simpleQuery`
3437
final class ExtendedQueryContext {
3538
enum Query {
3639
case unnamed(PostgresQuery, EventLoopPromise<PSQLRowStream>)
3740
case executeStatement(PSQLExecuteStatement, EventLoopPromise<PSQLRowStream>)
3841
case prepareStatement(name: String, query: String, bindingDataTypes: [PostgresDataType], EventLoopPromise<RowDescription?>)
42+
case simpleQuery(String, EventLoopPromise<PSQLRowStream>)
3943
}
4044

4145
let query: Query
4246
let logger: Logger
43-
47+
4448
init(
4549
query: PostgresQuery,
4650
logger: Logger,
@@ -69,6 +73,15 @@ final class ExtendedQueryContext {
6973
self.query = .prepareStatement(name: name, query: query, bindingDataTypes: bindingDataTypes, promise)
7074
self.logger = logger
7175
}
76+
77+
init(
78+
simpleQuery: String,
79+
logger: Logger,
80+
promise: EventLoopPromise<PSQLRowStream>
81+
) {
82+
self.query = .simpleQuery(simpleQuery, promise)
83+
self.logger = logger
84+
}
7285
}
7386

7487
final class PreparedStatementContext: Sendable {

0 commit comments

Comments
 (0)