Skip to content

Commit 8625cfa

Browse files
committed
Extract SQLExecDirect, SQLExecute, SQLPrepare implementation
Co-Authored-By: alinalibq <[email protected]>
1 parent 46b033f commit 8625cfa

File tree

5 files changed

+168
-16
lines changed

5 files changed

+168
-16
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -940,22 +940,49 @@ SQLRETURN SQLExecDirect(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_len
940940
ARROW_LOG(DEBUG) << "SQLExecDirectW called with stmt: " << stmt
941941
<< ", query_text: " << static_cast<const void*>(query_text)
942942
<< ", text_length: " << text_length;
943-
// GH-47711 TODO: Implement SQLExecDirect
944-
return SQL_INVALID_HANDLE;
943+
944+
using ODBC::ODBCStatement;
945+
// The driver is built to handle SELECT statements only.
946+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
947+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
948+
std::string query = ODBC::SqlWcharToString(query_text, text_length);
949+
950+
statement->Prepare(query);
951+
statement->ExecutePrepared();
952+
953+
return SQL_SUCCESS;
954+
});
945955
}
946956

947957
SQLRETURN SQLPrepare(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_length) {
948958
ARROW_LOG(DEBUG) << "SQLPrepareW called with stmt: " << stmt
949959
<< ", query_text: " << static_cast<const void*>(query_text)
950960
<< ", text_length: " << text_length;
951-
// GH-47712 TODO: Implement SQLPrepare
952-
return SQL_INVALID_HANDLE;
961+
962+
using ODBC::ODBCStatement;
963+
// The driver is built to handle SELECT statements only.
964+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
965+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
966+
std::string query = ODBC::SqlWcharToString(query_text, text_length);
967+
968+
statement->Prepare(query);
969+
970+
return SQL_SUCCESS;
971+
});
953972
}
954973

955974
SQLRETURN SQLExecute(SQLHSTMT stmt) {
956975
ARROW_LOG(DEBUG) << "SQLExecute called with stmt: " << stmt;
957-
// GH-47712 TODO: Implement SQLExecute
958-
return SQL_INVALID_HANDLE;
976+
977+
using ODBC::ODBCStatement;
978+
// The driver is built to handle SELECT statements only.
979+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
980+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
981+
982+
statement->ExecutePrepared();
983+
984+
return SQL_SUCCESS;
985+
});
959986
}
960987

961988
SQLRETURN SQLFetch(SQLHSTMT stmt) {

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ using util::ThrowIfNotOK;
4141

4242
namespace {
4343

44-
void ClosePreparedStatementIfAny(std::shared_ptr<PreparedStatement>& prepared_statement) {
44+
void ClosePreparedStatementIfAny(std::shared_ptr<PreparedStatement>& prepared_statement,
45+
const FlightCallOptions& options) {
4546
if (prepared_statement != nullptr) {
46-
ThrowIfNotOK(prepared_statement->Close());
47+
ThrowIfNotOK(prepared_statement->Close(options));
4748
prepared_statement.reset();
4849
}
4950
}
@@ -66,6 +67,10 @@ FlightSqlStatement::FlightSqlStatement(const Diagnostics& diagnostics,
6667
call_options_.timeout = TimeoutDuration{-1};
6768
}
6869

70+
FlightSqlStatement::~FlightSqlStatement() {
71+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
72+
}
73+
6974
bool FlightSqlStatement::SetAttribute(StatementAttributeId attribute,
7075
const Attribute& value) {
7176
switch (attribute) {
@@ -97,7 +102,7 @@ boost::optional<Statement::Attribute> FlightSqlStatement::GetAttribute(
97102

98103
boost::optional<std::shared_ptr<ResultSetMetadata>> FlightSqlStatement::Prepare(
99104
const std::string& query) {
100-
ClosePreparedStatementIfAny(prepared_statement_);
105+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
101106

102107
Result<std::shared_ptr<PreparedStatement>> result =
103108
sql_client_.Prepare(call_options_, query);
@@ -113,7 +118,8 @@ boost::optional<std::shared_ptr<ResultSetMetadata>> FlightSqlStatement::Prepare(
113118
bool FlightSqlStatement::ExecutePrepared() {
114119
assert(prepared_statement_.get() != nullptr);
115120

116-
Result<std::shared_ptr<FlightInfo>> result = prepared_statement_->Execute();
121+
Result<std::shared_ptr<FlightInfo>> result =
122+
prepared_statement_->Execute(call_options_);
117123
ThrowIfNotOK(result.status());
118124

119125
current_result_set_ = std::make_shared<FlightSqlResultSet>(
@@ -124,7 +130,7 @@ bool FlightSqlStatement::ExecutePrepared() {
124130
}
125131

126132
bool FlightSqlStatement::Execute(const std::string& query) {
127-
ClosePreparedStatementIfAny(prepared_statement_);
133+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
128134

129135
Result<std::shared_ptr<FlightInfo>> result = sql_client_.Execute(call_options_, query);
130136
ThrowIfNotOK(result.status());
@@ -146,7 +152,7 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetTables(
146152
const std::string* catalog_name, const std::string* schema_name,
147153
const std::string* table_name, const std::string* table_type,
148154
const ColumnNames& column_names) {
149-
ClosePreparedStatementIfAny(prepared_statement_);
155+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
150156

151157
std::vector<std::string> table_types;
152158

@@ -199,7 +205,7 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetTables_V3(
199205
std::shared_ptr<ResultSet> FlightSqlStatement::GetColumns_V2(
200206
const std::string* catalog_name, const std::string* schema_name,
201207
const std::string* table_name, const std::string* column_name) {
202-
ClosePreparedStatementIfAny(prepared_statement_);
208+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
203209

204210
Result<std::shared_ptr<FlightInfo>> result = sql_client_.GetTables(
205211
call_options_, catalog_name, schema_name, table_name, true, nullptr);
@@ -220,7 +226,7 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetColumns_V2(
220226
std::shared_ptr<ResultSet> FlightSqlStatement::GetColumns_V3(
221227
const std::string* catalog_name, const std::string* schema_name,
222228
const std::string* table_name, const std::string* column_name) {
223-
ClosePreparedStatementIfAny(prepared_statement_);
229+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
224230

225231
Result<std::shared_ptr<FlightInfo>> result = sql_client_.GetTables(
226232
call_options_, catalog_name, schema_name, table_name, true, nullptr);
@@ -239,7 +245,7 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetColumns_V3(
239245
}
240246

241247
std::shared_ptr<ResultSet> FlightSqlStatement::GetTypeInfo_V2(int16_t data_type) {
242-
ClosePreparedStatementIfAny(prepared_statement_);
248+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
243249

244250
Result<std::shared_ptr<FlightInfo>> result = sql_client_.GetXdbcTypeInfo(call_options_);
245251
ThrowIfNotOK(result.status());
@@ -257,7 +263,7 @@ std::shared_ptr<ResultSet> FlightSqlStatement::GetTypeInfo_V2(int16_t data_type)
257263
}
258264

259265
std::shared_ptr<ResultSet> FlightSqlStatement::GetTypeInfo_V3(int16_t data_type) {
260-
ClosePreparedStatementIfAny(prepared_statement_);
266+
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
261267

262268
Result<std::shared_ptr<FlightInfo>> result = sql_client_.GetXdbcTypeInfo(call_options_);
263269
ThrowIfNotOK(result.status());

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class FlightSqlStatement : public Statement {
4848
FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client,
4949
FlightCallOptions call_options,
5050
const MetadataSettings& metadata_settings);
51+
~FlightSqlStatement();
5152

5253
bool SetAttribute(StatementAttributeId attribute, const Attribute& value) override;
5354

cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_arrow_test(flight_sql_odbc_test
3535
odbc_test_suite.cc
3636
odbc_test_suite.h
3737
connection_test.cc
38+
statement_test.cc
3839
# Enable Protobuf cleanup after test execution
3940
# GH-46889: move protobuf_test_util to a more common location
4041
../../../../engine/substrait/protobuf_test_util.cc
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h"
18+
19+
#include "arrow/flight/sql/odbc/odbc_impl/platform.h"
20+
21+
#include <sql.h>
22+
#include <sqltypes.h>
23+
#include <sqlucode.h>
24+
25+
#include <limits>
26+
27+
#include <gmock/gmock.h>
28+
#include <gtest/gtest.h>
29+
30+
namespace arrow::flight::sql::odbc {
31+
32+
template <typename T>
33+
class StatementTest : public T {};
34+
35+
class StatementMockTest : public FlightSQLODBCMockTestBase {};
36+
class StatementRemoteTest : public FlightSQLODBCRemoteTestBase {};
37+
using TestTypes = ::testing::Types<StatementMockTest, StatementRemoteTest>;
38+
TYPED_TEST_SUITE(StatementTest, TestTypes);
39+
40+
TYPED_TEST(StatementTest, TestSQLExecDirectSimpleQuery) {
41+
std::wstring wsql = L"SELECT 1;";
42+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
43+
44+
ASSERT_EQ(SQL_SUCCESS,
45+
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
46+
47+
// GH-47713 TODO: Uncomment call to SQLFetch SQLGetData after implementation
48+
/*
49+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
50+
51+
SQLINTEGER val;
52+
53+
ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
54+
// Verify 1 is returned
55+
EXPECT_EQ(1, val);
56+
57+
ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt));
58+
59+
ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
60+
// Invalid cursor state
61+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000);
62+
*/
63+
}
64+
65+
TYPED_TEST(StatementTest, TestSQLExecDirectInvalidQuery) {
66+
std::wstring wsql = L"SELECT;";
67+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
68+
69+
ASSERT_EQ(SQL_ERROR,
70+
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
71+
// ODBC provides generic error code HY000 to all statement errors
72+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000);
73+
}
74+
75+
TYPED_TEST(StatementTest, TestSQLExecuteSimpleQuery) {
76+
std::wstring wsql = L"SELECT 1;";
77+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
78+
79+
ASSERT_EQ(SQL_SUCCESS,
80+
SQLPrepare(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
81+
82+
ASSERT_EQ(SQL_SUCCESS, SQLExecute(this->stmt));
83+
84+
// GH-47713 TODO: Uncomment call to SQLFetch SQLGetData after implementation
85+
/*
86+
// Fetch data
87+
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
88+
89+
SQLINTEGER val;
90+
ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
91+
92+
// Verify 1 is returned
93+
EXPECT_EQ(1, val);
94+
95+
ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt));
96+
97+
ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
98+
// Invalid cursor state
99+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000);
100+
*/
101+
}
102+
103+
TYPED_TEST(StatementTest, TestSQLPrepareInvalidQuery) {
104+
std::wstring wsql = L"SELECT;";
105+
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
106+
107+
ASSERT_EQ(SQL_ERROR,
108+
SQLPrepare(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
109+
// ODBC provides generic error code HY000 to all statement errors
110+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000);
111+
112+
ASSERT_EQ(SQL_ERROR, SQLExecute(this->stmt));
113+
// Verify function sequence error state is returned
114+
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010);
115+
}
116+
117+
} // namespace arrow::flight::sql::odbc

0 commit comments

Comments
 (0)