Skip to content

Cache prepared statements #258

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 93 additions & 9 deletions DuckDB.NET.Data/DuckDBCommand.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using DuckDB.NET.Native;
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Data.Common;
Expand All @@ -12,6 +13,8 @@ public class DuckDBCommand : DbCommand
{
private DuckDBConnection? connection;
private readonly DuckDBParameterCollection parameters = new();
private bool prepared;
private readonly List<PreparedStatement.PreparedStatement> preparedStatements = new();

protected override DbTransaction? DbTransaction { get; set; }
protected override DbParameterCollection DbParameterCollection => parameters;
Expand All @@ -32,6 +35,8 @@ public class DuckDBCommand : DbCommand
/// </remarks>
public bool UseStreamingMode { get; set; } = false;

internal DuckDBDataReader? DataReader { get; set; }

private string commandText = string.Empty;

#if NET6_0_OR_GREATER
Expand All @@ -43,7 +48,13 @@ public override string CommandText
get => commandText;
set
{
// TODO: We shouldn't be able to change the CommandText when the command is in execution (requires CommandState implementation)
if (DataReader != null)
throw new InvalidOperationException("cannot change CommandText while a reader is open");

if (commandText == value)
return;

DisposePreparedStatements();
commandText = value ?? string.Empty;
}
}
Expand Down Expand Up @@ -80,14 +91,13 @@ public override int ExecuteNonQuery()
{
EnsureConnectionOpen();

var results = PreparedStatement.PreparedStatement.PrepareMultiple(connection!.NativeConnection, CommandText, parameters, UseStreamingMode);

var count = 0;

foreach (var result in results)
foreach (var statement in GetStatements())
{
var current = result;
var current = statement.Execute();
count += (int)NativeMethods.Query.DuckDBRowsChanged(ref current);
current.Dispose();
}

return count;
Expand All @@ -111,15 +121,28 @@ public override int ExecuteNonQuery()
return (DuckDBDataReader)base.ExecuteReader(behavior);
}

/// <summary>
/// Releases the resources held by this command: the reader currently attached
/// to it (only when <paramref name="disposing"/> is true) and every cached
/// prepared statement, before delegating to the base implementation.
/// </summary>
/// <param name="disposing">
/// True when invoked through an explicit dispose call; gates disposal of the open reader.
/// </param>
protected override void Dispose(bool disposing)
{
    if (disposing && DataReader is { } openReader)
    {
        openReader.Dispose();
    }

    // NOTE(review): the cached statements are released on both the disposing and
    // non-disposing paths — confirm this is intended.
    DisposePreparedStatements();

    base.Dispose(disposing);
}

protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior)
{
EnsureConnectionOpen();
if (DataReader != null)
throw new InvalidOperationException("cannot create a new reader while one is open");

var results = PreparedStatement.PreparedStatement.PrepareMultiple(connection!.NativeConnection, CommandText, parameters, UseStreamingMode);
EnsureConnectionOpen();

var reader = new DuckDBDataReader(this, results, behavior);
var closeConnection = behavior.HasFlag(CommandBehavior.CloseConnection);

return reader;
return new DuckDBDataReader(this, GetStatements(), closeConnection);
}

public override void Prepare() { }
Expand All @@ -128,11 +151,72 @@ public override void Prepare() { }

/// <summary>
/// Closes the connection this command is attached to.
/// </summary>
internal void CloseConnection()
{
    Connection!.Close();
}

/// <summary>
/// Disposes every cached prepared statement, empties the cache and marks the
/// command as not prepared.
/// </summary>
private void DisposePreparedStatements()
{
    // Dispose in insertion order, exactly as the statements were cached.
    for (var position = 0; position < preparedStatements.Count; position++)
    {
        preparedStatements[position].Dispose();
    }

    preparedStatements.Clear();
    prepared = false;
}

/// <summary>
/// Verifies that the command has a connection and that it is open.
/// </summary>
/// <param name="operation">
/// Name used in the error message; defaults to the calling member's name.
/// </param>
/// <exception cref="InvalidOperationException">
/// The connection is missing or its state is not <see cref="ConnectionState.Open"/>.
/// </exception>
private void EnsureConnectionOpen([CallerMemberName] string operation = "")
{
    if (Connection is { State: ConnectionState.Open })
    {
        return;
    }

    throw new InvalidOperationException($"{operation} requires an open connection");
}

/// <summary>
/// Lazily yields the statements to execute for the current command text.
/// The cached statements are reused when the command is already prepared;
/// otherwise they are prepared (and cached) while being enumerated.
/// Each statement gets the current parameters bound and the current
/// streaming-mode setting applied right before it is yielded.
/// </summary>
private IEnumerable<PreparedStatement.PreparedStatement> GetStatements()
{
    // Evaluated on the first MoveNext(), like the rest of this iterator.
    IEnumerable<PreparedStatement.PreparedStatement> source;

    if (prepared)
    {
        source = preparedStatements;
    }
    else
    {
        source = PrepareAndEnumerateStatements();
    }

    foreach (var ready in source)
    {
        ready.BindParameters(Parameters);
        ready.UseStreamingMode = UseStreamingMode;
        yield return ready;
    }
}

/// <summary>
/// Splits the command text into individual statements, prepares them one by one
/// and yields each prepared statement as soon as it is ready. Every statement
/// that is prepared successfully is added to the cache; the command is flagged
/// as prepared only after the whole sequence has been enumerated.
/// </summary>
/// <exception cref="DuckDBException">
/// Statement extraction yields no statements, or preparing one of the statements fails.
/// </exception>
private IEnumerable<PreparedStatement.PreparedStatement> PrepareAndEnumerateStatements()
{
    // Drop whatever is left from a previous (possibly partial) preparation pass.
    DisposePreparedStatements();

    using var unmanagedQuery = CommandText.ToUnmanagedString();

    var statementCount = NativeMethods.ExtractStatements.DuckDBExtractStatements(connection!.NativeConnection, unmanagedQuery, out var extractedStatements);

    using (extractedStatements)
    {
        if (statementCount <= 0)
        {
            var error = NativeMethods.ExtractStatements.DuckDBExtractStatementsError(extractedStatements);
            throw new DuckDBException(error.ToManagedString(false));
        }

        for (int index = 0; index < statementCount; index++)
        {
            var status = NativeMethods.ExtractStatements.DuckDBPrepareExtractedStatement(connection!.NativeConnection, extractedStatements, index, out var unmanagedStatement);

            if (!status.IsSuccess())
            {
                // The DuckDB C API requires the prepared-statement handle to be destroyed
                // even when preparation fails. Read the error first, then release the
                // handle instead of leaving it to finalization.
                using (unmanagedStatement)
                {
                    var errorMessage = NativeMethods.PreparedStatements.DuckDBPrepareError(unmanagedStatement).ToManagedString(false);

                    throw new DuckDBException(string.IsNullOrEmpty(errorMessage) ? "DuckDBQuery failed" : errorMessage);
                }
            }

            var statement = new PreparedStatement.PreparedStatement(unmanagedStatement);
            preparedStatements.Add(statement);
            yield return statement;
        }
    }

    // Reached only when the caller enumerated every statement without an error.
    prepared = true;
}
}
81 changes: 59 additions & 22 deletions DuckDB.NET.Data/DuckDBDataReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ namespace DuckDB.NET.Data;
public class DuckDBDataReader : DbDataReader
{
private readonly DuckDBCommand command;
private readonly CommandBehavior behavior;
private readonly bool closeConnection;

private DuckDBResult currentResult;
private DuckDBResult? currentResult;
private DuckDBDataChunk? currentChunk;

private int fieldCount;
Expand All @@ -27,35 +27,51 @@ public class DuckDBDataReader : DbDataReader
private bool streamingResult;
private long currentChunkIndex;

private readonly IEnumerator<DuckDBResult> resultEnumerator;
private readonly IEnumerator<PreparedStatement.PreparedStatement> statementEnumerator;
private VectorDataReaderBase[] vectorReaders = [];
private Dictionary<string, int> columnMapping = [];

internal DuckDBDataReader(DuckDBCommand command, IEnumerable<DuckDBResult> queryResults, CommandBehavior behavior)
internal DuckDBDataReader(DuckDBCommand command, IEnumerable<PreparedStatement.PreparedStatement> statements, bool closeConnection)
{
this.command = command;
this.behavior = behavior;
resultEnumerator = queryResults.GetEnumerator();
this.closeConnection = closeConnection;
statementEnumerator = statements.GetEnumerator();

InitNextReader();

// Do not modify the command's state when InitNextReader() throws an exception.
command.DataReader = this;
}

private bool InitNextReader()
{
while (resultEnumerator.MoveNext())
while (statementEnumerator.MoveNext())
{
if (NativeMethods.Query.DuckDBResultReturnType(resultEnumerator.Current) == DuckDBResultType.QueryResult)
currentResult?.Dispose();
currentResult = null; // Prevent double disposal.

try
{
currentChunkIndex = 0;
currentResult = resultEnumerator.Current;
var current = statementEnumerator.Current.Execute();
currentResult = current;

columnMapping = [];
fieldCount = (int)NativeMethods.Query.DuckDBColumnCount(ref currentResult);
streamingResult = NativeMethods.Types.DuckDBResultIsStreaming(currentResult) > 0;
if (NativeMethods.Query.DuckDBResultReturnType(current) == DuckDBResultType.QueryResult)
{
currentChunkIndex = 0;

hasRows = InitChunkData();
columnMapping = [];
fieldCount = (int)NativeMethods.Query.DuckDBColumnCount(ref current);
streamingResult = NativeMethods.Types.DuckDBResultIsStreaming(current) > 0;

return true;
hasRows = InitChunkData();

return true;
}
}
catch
{
Dispose();
throw;
}
}

Expand All @@ -69,8 +85,10 @@ private bool InitChunkData()
reader.Dispose();
}

var current = currentResult!.Value;

currentChunk?.Dispose();
currentChunk = streamingResult ? NativeMethods.StreamingResult.DuckDBStreamFetchChunk(currentResult) : NativeMethods.Types.DuckDBResultGetChunk(currentResult, currentChunkIndex);
currentChunk = streamingResult ? NativeMethods.StreamingResult.DuckDBStreamFetchChunk(current) : NativeMethods.Types.DuckDBResultGetChunk(current, currentChunkIndex);

rowsReadFromCurrentChunk = 0;

Expand All @@ -85,9 +103,9 @@ private bool InitChunkData()
{
var vector = NativeMethods.DataChunks.DuckDBDataChunkGetVector(currentChunk, index);

using var logicalType = NativeMethods.Query.DuckDBColumnLogicalType(ref currentResult, index);
using var logicalType = NativeMethods.Query.DuckDBColumnLogicalType(ref current, index);

var columnName = vectorReaders[index]?.ColumnName ?? NativeMethods.Query.DuckDBColumnName(ref currentResult, index).ToManagedString(false);
var columnName = vectorReaders[index]?.ColumnName ?? NativeMethods.Query.DuckDBColumnName(ref current, index).ToManagedString(false);
vectorReaders[index] = VectorDataReaderFactory.CreateReader(vector, logicalType, columnName);
}

Expand Down Expand Up @@ -291,7 +309,7 @@ public override bool Read()

public override IEnumerator GetEnumerator()
{
return new DbEnumerator(this, behavior == CommandBehavior.CloseConnection);
return new DbEnumerator(this, closeConnection);
}

public override DataTable GetSchemaTable()
Expand Down Expand Up @@ -340,20 +358,39 @@ public override void Close()
{
if (closed) return;

command.DataReader = null;

foreach (var reader in vectorReaders)
{
reader.Dispose();
}

currentResult?.Dispose();
currentResult = null; // Prevent double disposal.

currentChunk?.Dispose();

if (behavior == CommandBehavior.CloseConnection)
try
{
command.CloseConnection();
// Try to consume the enumerator to ensure that all statements are prepared.
while (statementEnumerator.MoveNext())
{
// No-op.
}
}
catch
{
// Dispose() must not throw exceptions.
}

statementEnumerator.Dispose();

closed = true;
resultEnumerator.Dispose();

if (closeConnection)
{
command.CloseConnection();
}
}

private void CheckRowRead()
Expand Down
Loading
Loading