Skip to content

Commit 72e8db0

Browse files
committed
Extract SQLExecDirect, SQLExecute, SQLPrepare implementation
Co-Authored-By: alinalibq <[email protected]>
1 parent 42f27ab commit 72e8db0

File tree

5 files changed

+165
-16
lines changed

5 files changed

+165
-16
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -803,22 +803,49 @@ SQLRETURN SQLExecDirect(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_len
803803
ARROW_LOG(DEBUG) << "SQLExecDirectW called with stmt: " << stmt
804804
<< ", query_text: " << static_cast<const void*>(query_text)
805805
<< ", text_length: " << text_length;
806-
// GH-47711 TODO: Implement SQLExecDirect
807-
return SQL_INVALID_HANDLE;
806+
807+
using ODBC::ODBCStatement;
808+
// The driver is built to handle SELECT statements only.
809+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
810+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
811+
std::string query = ODBC::SqlWcharToString(query_text, text_length);
812+
813+
statement->Prepare(query);
814+
statement->ExecutePrepared();
815+
816+
return SQL_SUCCESS;
817+
});
808818
}
809819

810820
SQLRETURN SQLPrepare(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_length) {
811821
ARROW_LOG(DEBUG) << "SQLPrepareW called with stmt: " << stmt
812822
<< ", query_text: " << static_cast<const void*>(query_text)
813823
<< ", text_length: " << text_length;
814-
// GH-47712 TODO: Implement SQLPrepare
815-
return SQL_INVALID_HANDLE;
824+
825+
using ODBC::ODBCStatement;
826+
// The driver is built to handle SELECT statements only.
827+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
828+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
829+
std::string query = ODBC::SqlWcharToString(query_text, text_length);
830+
831+
statement->Prepare(query);
832+
833+
return SQL_SUCCESS;
834+
});
816835
}
817836

818837
SQLRETURN SQLExecute(SQLHSTMT stmt) {
819838
ARROW_LOG(DEBUG) << "SQLExecute called with stmt: " << stmt;
820-
// GH-47712 TODO: Implement SQLExecute
821-
return SQL_INVALID_HANDLE;
839+
840+
using ODBC::ODBCStatement;
841+
// The driver is built to handle SELECT statements only.
842+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
843+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
844+
845+
statement->ExecutePrepared();
846+
847+
return SQL_SUCCESS;
848+
});
822849
}
823850

824851
SQLRETURN SQLFetch(SQLHSTMT stmt) {

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ using util::ThrowIfNotOK;
4141

4242
namespace {
4343

44-
void ClosePreparedStatementIfAny(std::shared_ptr<PreparedStatement>& prepared_statement) {
44+
void ClosePreparedStatementIfAny(std::shared_ptr<PreparedStatement>& prepared_statement,
45+
const FlightCallOptions& options) {
4546
if (prepared_statement != nullptr) {
46-
ThrowIfNotOK(prepared_statement->Close());
47+
ThrowIfNotOK(prepared_statement->Close(options));
4748
prepared_statement.reset();
4849
}
4950
}
@@ -66,6 +67,10 @@ FlightSqlStatement::FlightSqlStatement(const Diagnostics& diagnostics,
6667
call_options_.timeout = TimeoutDuration{-1};
6768
}
6869

70+
FlightSqlStatement::~FlightSqlStatement() {
71+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
72+
}
73+
6974
bool FlightSqlStatement::SetAttribute(StatementAttributeId attribute,
7075
const Attribute& value) {
7176
switch (attribute) {
@@ -97,7 +102,7 @@ boost::optional<Statement::Attribute> FlightSqlStatement::GetAttribute(
97102

98103
boost::optional<std::shared_ptr<ResultSetMetadata>> FlightSqlStatement::Prepare(
99104
const std::string& query) {
100-
ClosePreparedStatementIfAny(prepared_statement_);
105+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
101106

102107
Result<std::shared_ptr<PreparedStatement>> result =
103108
sql_client_.Prepare(call_options_, query);
@@ -113,7 +118,8 @@ boost::optional<std::shared_ptr<ResultSetMetadata>> FlightSqlStatement::Prepare(
113118
bool FlightSqlStatement::ExecutePrepared() {
114119
assert(prepared_statement_.get() != nullptr);
115120

116-
Result<std::shared_ptr<FlightInfo>> result = prepared_statement_->Execute();
121+
Result<std::shared_ptr<FlightInfo>> result =
122+
prepared_statement_->Execute(call_options_);
117123
ThrowIfNotOK(result.status());
118124

119125
current_result_set_ = std::make_shared<FlightSqlResultSet>(
@@ -124,7 +130,7 @@ bool FlightSqlStatement::ExecutePrepared() {
124130
}
125131

126132
bool FlightSqlStatement::Execute(const std::string& query) {
127-
ClosePreparedStatementIfAny(prepared_statement_);
133+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
128134

129135
Result<std::shared_ptr<FlightInfo>> result = sql_client_.Execute(call_options_, query);
130136
ThrowIfNotOK(result.status());
@@ -146,7 +152,7 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetTables(
146152
const std::string* catalog_name, const std::string* schema_name,
147153
const std::string* table_name, const std::string* table_type,
148154
const ColumnNames& column_names) {
149-
ClosePreparedStatementIfAny(prepared_statement_);
155+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
150156

151157
std::vector<std::string> table_types;
152158

@@ -199,7 +205,7 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetTables_V3(
199205
std::shared_ptr<ResultSet> FlightSqlStatement::GetColumns_V2(
200206
const std::string* catalog_name, const std::string* schema_name,
201207
const std::string* table_name, const std::string* column_name) {
202-
ClosePreparedStatementIfAny(prepared_statement_);
208+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
203209

204210
Result<std::shared_ptr<FlightInfo>> result = sql_client_.GetTables(
205211
call_options_, catalog_name, schema_name, table_name, true, nullptr);
@@ -220,7 +226,7 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetColumns_V2(
220226
std::shared_ptr<ResultSet> FlightSqlStatement::GetColumns_V3(
221227
const std::string* catalog_name, const std::string* schema_name,
222228
const std::string* table_name, const std::string* column_name) {
223-
ClosePreparedStatementIfAny(prepared_statement_);
229+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
224230

225231
Result<std::shared_ptr<FlightInfo>> result = sql_client_.GetTables(
226232
call_options_, catalog_name, schema_name, table_name, true, nullptr);
@@ -239,7 +245,7 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetColumns_V3(
239245
}
240246

241247
std::shared_ptr<ResultSet> FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) {
242-
ClosePreparedStatementIfAny(prepared_statement_);
248+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
243249

244250
Result<std::shared_ptr<FlightInfo>> result = sql_client_.GetXdbcTypeInfo(call_options_);
245251
ThrowIfNotOK(result.status());
@@ -257,7 +263,7 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetTypeInfo_V2(int16_t data_type)
257263
}
258264

259265
std::shared_ptr<ResultSet> FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) {
260-
ClosePreparedStatementIfAny(prepared_statement_);
266+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
261267

262268
Result<std::shared_ptr<FlightInfo>> result = sql_client_.GetXdbcTypeInfo(call_options_);
263269
ThrowIfNotOK(result.status());

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class FlightSqlStatement : public Statement {
4848
FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client,
4949
FlightCallOptions call_options,
5050
const MetadataSettings& metadata_settings);
51+
~FlightSqlStatement();
5152

5253
bool SetAttribute(StatementAttributeId attribute, const Attribute& value) override;
5354

cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_arrow_test(flight_sql_odbc_test
3535
odbc_test_suite.cc
3636
odbc_test_suite.h
3737
connection_test.cc
38+
statement_test.cc
3839
# Enable Protobuf cleanup after test execution
3940
# GH-46889: move protobuf_test_util to a more common location
4041
../../../../engine/substrait/protobuf_test_util.cc
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h"
18+
19+
#include "arrow/flight/sql/odbc/odbc_impl/platform.h"
20+
21+
#include <sql.h>
22+
#include <sqltypes.h>
23+
#include <sqlucode.h>
24+
25+
#include <limits>
26+
27+
#include <gmock/gmock.h>
28+
#include <gtest/gtest.h>
29+
30+
namespace arrow::flight::sql::odbc {
31+
32+
template <typename T>
33+
class StatementTest : public T {};
34+
35+
class StatementMockTest : public FlightSQLODBCMockTestBase {};
36+
class StatementRemoteTest : public FlightSQLODBCRemoteTestBase {};
37+
using TestTypes = ::testing::Types<StatementMockTest, StatementRemoteTest>;
38+
TYPED_TEST_SUITE(StatementTest, TestTypes);
39+
40+
TYPED_TEST(StatementTest, TestSQLExecDirectSimpleQuery) {
41+
std::wstring wsql = L"SELECT 1;";
42+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
43+
44+
ASSERT_EQ(SQL_SUCCESS,
45+
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
46+
47+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
48+
49+
SQLINTEGER val;
50+
51+
ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
52+
// Verify 1 is returned
53+
EXPECT_EQ(1, val);
54+
55+
ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt));
56+
57+
ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
58+
// Invalid cursor state
59+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000);
60+
}
61+
62+
TYPED_TEST(StatementTest, TestSQLExecDirectInvalidQuery) {
63+
std::wstring wsql = L"SELECT;";
64+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
65+
66+
ASSERT_EQ(SQL_ERROR,
67+
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
68+
// ODBC provides generic error code HY000 to all statement errors
69+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000);
70+
}
71+
72+
TYPED_TEST(StatementTest, TestSQLExecuteSimpleQuery) {
73+
std::wstring wsql = L"SELECT 1;";
74+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
75+
76+
ASSERT_EQ(SQL_SUCCESS,
77+
SQLPrepare(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
78+
79+
ASSERT_EQ(SQL_SUCCESS, SQLExecute(this->stmt));
80+
81+
// GH-47713 TODO: Uncomment call to SQLFetch SQLGetData after implementation
82+
/*
83+
// Fetch data
84+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
85+
86+
SQLINTEGER val;
87+
ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
88+
89+
// Verify 1 is returned
90+
EXPECT_EQ(1, val);
91+
92+
ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt));
93+
94+
ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
95+
// Invalid cursor state
96+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000);
97+
*/
98+
}
99+
100+
TYPED_TEST(StatementTest, TestSQLPrepareInvalidQuery) {
101+
std::wstring wsql = L"SELECT;";
102+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
103+
104+
ASSERT_EQ(SQL_ERROR,
105+
SQLPrepare(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
106+
// ODBC provides generic error code HY000 to all statement errors
107+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000);
108+
109+
ASSERT_EQ(SQL_ERROR, SQLExecute(this->stmt));
110+
// Verify function sequence error state is returned
111+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010);
112+
}
113+
114+
} // namespace arrow::flight::sql::odbc

0 commit comments

Comments
 (0)