Skip to content

Commit 70ba716

Browse files
alinaliBQjusting-bq
andcommitted
Extract implementation of gh-46574
Address more feedback Avoid using "using" in Headers Add `server->Wait` call Co-Authored-By: justing-bq <[email protected]>
1 parent 5aa7dd1 commit 70ba716

26 files changed

+370
-187
lines changed

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,7 @@ SQLRETURN SQLSetConnectAttr(SQLHDBC conn, SQLINTEGER attr, SQLPOINTER value_ptr,
738738
// entries in the properties.
739739
void LoadPropertiesFromDSN(const std::string& dsn,
740740
Connection::ConnPropertyMap& properties) {
741-
arrow::flight::sql::odbc::config::Configuration config;
741+
config::Configuration config;
742742
config.LoadDsn(dsn);
743743
Connection::ConnPropertyMap dsn_properties = config.GetProperties();
744744
for (auto& [key, value] : dsn_properties) {
@@ -796,7 +796,7 @@ SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND window_handle,
796796
// Load the DSN window according to driver_completion
797797
if (driver_completion == SQL_DRIVER_PROMPT) {
798798
// Load DSN window before first attempt to connect
799-
arrow::flight::sql::odbc::config::Configuration config;
799+
config::Configuration config;
800800
if (!DisplayConnectionWindow(window_handle, config, properties)) {
801801
return static_cast<SQLRETURN>(SQL_NO_DATA);
802802
}
@@ -809,7 +809,7 @@ SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND window_handle,
809809
// If first connection fails due to missing attributes, load
810810
// the DSN window and try to connect again
811811
if (!missing_properties.empty()) {
812-
arrow::flight::sql::odbc::config::Configuration config;
812+
config::Configuration config;
813813
missing_properties.clear();
814814

815815
if (!DisplayConnectionWindow(window_handle, config, properties)) {
@@ -855,7 +855,7 @@ SQLRETURN SQLConnect(SQLHDBC conn, SQLWCHAR* dsn_name, SQLSMALLINT dsn_name_len,
855855
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(conn);
856856
std::string dsn = SqlWcharToString(dsn_name, dsn_name_len);
857857

858-
Configuration config;
858+
config::Configuration config;
859859
config.LoadDsn(dsn);
860860

861861
if (user_name) {

cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,11 @@ add_arrow_test(odbc_spi_impl_test
164164
accessors/time_array_accessor_test.cc
165165
accessors/timestamp_array_accessor_test.cc
166166
flight_sql_connection_test.cc
167+
flight_sql_stream_chunk_buffer_test.cc
167168
parse_table_types_test.cc
168169
json_converter_test.cc
169170
record_batch_transformer_test.cc
170171
util_test.cc
171172
EXTRA_LINK_LIBS
172-
arrow_odbc_spi_impl)
173+
arrow_odbc_spi_impl
174+
arrow_flight_testing_shared)

cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,6 @@
3030
// GH-48083 TODO: replace `namespace ODBC` with `namespace arrow::flight::sql::odbc`
3131
namespace ODBC {
3232

33-
using arrow::flight::sql::odbc::Diagnostics;
34-
using arrow::flight::sql::odbc::DriverException;
35-
using arrow::flight::sql::odbc::WcsToUtf8;
36-
3733
template <typename T, typename O>
3834
inline void GetAttribute(T attribute_value, SQLPOINTER output, O output_size,
3935
O* output_len_ptr) {
@@ -70,7 +66,7 @@ inline SQLRETURN GetAttributeUTF8(std::string_view attribute_value, SQLPOINTER o
7066
template <typename O>
7167
inline SQLRETURN GetAttributeUTF8(std::string_view attribute_value, SQLPOINTER output,
7268
O output_size, O* output_len_ptr,
73-
Diagnostics& diagnostics) {
69+
arrow::flight::sql::odbc::Diagnostics& diagnostics) {
7470
SQLRETURN result =
7571
GetAttributeUTF8(attribute_value, output, output_size, output_len_ptr);
7672
if (SQL_SUCCESS_WITH_INFO == result) {
@@ -85,10 +81,11 @@ inline SQLRETURN GetAttributeSQLWCHAR(std::string_view attribute_value,
8581
O output_size, O* output_len_ptr) {
8682
size_t length = ConvertToSqlWChar(
8783
attribute_value, reinterpret_cast<SQLWCHAR*>(output),
88-
is_length_in_bytes ? output_size : output_size * GetSqlWCharSize());
84+
is_length_in_bytes ? output_size
85+
: output_size * arrow::flight::sql::odbc::GetSqlWCharSize());
8986

9087
if (!is_length_in_bytes) {
91-
length = length / GetSqlWCharSize();
88+
length = length / arrow::flight::sql::odbc::GetSqlWCharSize();
9289
}
9390

9491
if (output_len_ptr) {
@@ -97,17 +94,19 @@ inline SQLRETURN GetAttributeSQLWCHAR(std::string_view attribute_value,
9794

9895
if (output &&
9996
output_size <
100-
static_cast<O>(length + (is_length_in_bytes ? GetSqlWCharSize() : 1))) {
97+
static_cast<O>(length + (is_length_in_bytes
98+
? arrow::flight::sql::odbc::GetSqlWCharSize()
99+
: 1))) {
101100
return SQL_SUCCESS_WITH_INFO;
102101
}
103102
return SQL_SUCCESS;
104103
}
105104

106105
template <typename O>
107-
inline SQLRETURN GetAttributeSQLWCHAR(std::string_view attribute_value,
108-
bool is_length_in_bytes, SQLPOINTER output,
109-
O output_size, O* output_len_ptr,
110-
Diagnostics& diagnostics) {
106+
inline SQLRETURN GetAttributeSQLWCHAR(
107+
const std::string& attribute_value, bool is_length_in_bytes, SQLPOINTER output,
108+
O output_size, O* output_len_ptr,
109+
arrow::flight::sql::odbc::Diagnostics& diagnostics) {
111110
SQLRETURN result = GetAttributeSQLWCHAR(attribute_value, is_length_in_bytes, output,
112111
output_size, output_len_ptr);
113112
if (SQL_SUCCESS_WITH_INFO == result) {
@@ -120,7 +119,7 @@ template <typename O>
120119
inline SQLRETURN GetStringAttribute(bool is_unicode, std::string_view attribute_value,
121120
bool is_length_in_bytes, SQLPOINTER output,
122121
O output_size, O* output_len_ptr,
123-
Diagnostics& diagnostics) {
122+
arrow::flight::sql::odbc::Diagnostics& diagnostics) {
124123
SQLRETURN result = SQL_SUCCESS;
125124
if (is_unicode) {
126125
result = GetAttributeSQLWCHAR(attribute_value, is_length_in_bytes, output,
@@ -158,17 +157,20 @@ inline void SetAttributeSQLWCHAR(SQLPOINTER new_value, SQLINTEGER input_length_i
158157
std::string& attribute_to_write) {
159158
thread_local std::vector<uint8_t> utf8_str;
160159
if (input_length_in_bytes == SQL_NTS) {
161-
WcsToUtf8(new_value, &utf8_str);
160+
arrow::flight::sql::odbc::WcsToUtf8(new_value, &utf8_str);
162161
} else {
163-
WcsToUtf8(new_value, input_length_in_bytes / GetSqlWCharSize(), &utf8_str);
162+
arrow::flight::sql::odbc::WcsToUtf8(
163+
new_value, input_length_in_bytes / arrow::flight::sql::odbc::GetSqlWCharSize(),
164+
&utf8_str);
164165
}
165166
attribute_to_write.assign((char*)utf8_str.data());
166167
}
167168

168169
template <typename T>
169170
void CheckIfAttributeIsSetToOnlyValidValue(SQLPOINTER value, T allowed_value) {
170171
if (static_cast<T>(reinterpret_cast<SQLULEN>(value)) != allowed_value) {
171-
throw DriverException("Optional feature not implemented", "HYC00");
172+
throw arrow::flight::sql::odbc::DriverException("Optional feature not implemented",
173+
"HYC00");
172174
}
173175
}
174176
} // namespace ODBC

cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,12 @@
3232

3333
namespace ODBC {
3434

35-
using arrow::flight::sql::odbc::DriverException;
36-
using arrow::flight::sql::odbc::GetSqlWCharSize;
37-
using arrow::flight::sql::odbc::Utf8ToWcs;
38-
using arrow::flight::sql::odbc::WcsToUtf8;
39-
4035
// Return the number of bytes required for the conversion.
4136
template <typename CHAR_TYPE>
4237
inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer,
4338
SQLLEN buffer_size_in_bytes) {
4439
thread_local std::vector<uint8_t> wstr;
45-
Utf8ToWcs<CHAR_TYPE>(str.data(), str.size(), &wstr);
40+
arrow::flight::sql::odbc::Utf8ToWcs<CHAR_TYPE>(str.data(), str.size(), &wstr);
4641
SQLLEN value_length_in_bytes = wstr.size();
4742

4843
if (buffer) {
@@ -51,11 +46,14 @@ inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer,
5146

5247
// Write a NUL terminator
5348
if (buffer_size_in_bytes >=
54-
value_length_in_bytes + static_cast<SQLLEN>(GetSqlWCharSize())) {
55-
reinterpret_cast<CHAR_TYPE*>(buffer)[value_length_in_bytes / GetSqlWCharSize()] =
49+
value_length_in_bytes +
50+
static_cast<SQLLEN>(arrow::flight::sql::odbc::GetSqlWCharSize())) {
51+
reinterpret_cast<CHAR_TYPE*>(
52+
buffer)[value_length_in_bytes / arrow::flight::sql::odbc::GetSqlWCharSize()] =
5653
'\0';
5754
} else {
58-
SQLLEN num_chars_written = buffer_size_in_bytes / GetSqlWCharSize();
55+
SQLLEN num_chars_written =
56+
buffer_size_in_bytes / arrow::flight::sql::odbc::GetSqlWCharSize();
5957
// If we failed to even write one char, the buffer is too small to hold a
6058
// NUL-terminator.
6159
if (num_chars_written > 0) {
@@ -68,15 +66,16 @@ inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer,
6866

6967
inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer,
7068
SQLLEN buffer_size_in_bytes) {
71-
switch (GetSqlWCharSize()) {
69+
switch (arrow::flight::sql::odbc::GetSqlWCharSize()) {
7270
case sizeof(char16_t):
7371
return ConvertToSqlWChar<char16_t>(str, buffer, buffer_size_in_bytes);
7472
case sizeof(char32_t):
7573
return ConvertToSqlWChar<char32_t>(str, buffer, buffer_size_in_bytes);
7674
default:
7775
assert(false);
78-
throw DriverException("Encoding is unsupported, SQLWCHAR size: " +
79-
std::to_string(GetSqlWCharSize()));
76+
throw arrow::flight::sql::odbc::DriverException(
77+
"Encoding is unsupported, SQLWCHAR size: " +
78+
std::to_string(arrow::flight::sql::odbc::GetSqlWCharSize()));
8079
}
8180
}
8281

@@ -92,9 +91,9 @@ inline std::string SqlWcharToString(SQLWCHAR* wchar_msg, SQLINTEGER msg_len = SQ
9291
thread_local std::vector<uint8_t> utf8_str;
9392

9493
if (msg_len == SQL_NTS) {
95-
WcsToUtf8((void*)wchar_msg, &utf8_str);
94+
arrow::flight::sql::odbc::WcsToUtf8((void*)wchar_msg, &utf8_str);
9695
} else {
97-
WcsToUtf8((void*)wchar_msg, msg_len, &utf8_str);
96+
arrow::flight::sql::odbc::WcsToUtf8((void*)wchar_msg, msg_len, &utf8_str);
9897
}
9998

10099
return std::string(utf8_str.begin(), utf8_str.end());

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_auth_method.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ class NoOpClientAuthHandler : public ClientAuthHandler {
4444
NoOpClientAuthHandler() {}
4545

4646
Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override {
47-
// Write a blank string. The server should ignore this and just accept any Handshake
48-
// request.
47+
// The server should ignore this and just accept any Handshake
48+
// request. Some servers do not allow authentication with no handshakes.
4949
return outgoing->Write(std::string());
5050
}
5151

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.cc

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,9 @@ inline std::string GetCerts() { return ""; }
9999

100100
#endif
101101

102-
// Case insensitive comparator that takes string_view
103-
struct CaseInsensitiveComparatorStrView {
104-
bool operator()(std::string_view s1, std::string_view s2) const {
105-
return boost::lexicographical_compare(s1, s2, boost::is_iless());
106-
}
107-
};
108-
109-
const std::set<std::string_view, CaseInsensitiveComparatorStrView> BUILT_IN_PROPERTIES = {
102+
const std::set<std::string_view, CaseInsensitiveComparator> BUILT_IN_PROPERTIES = {
103+
FlightSqlConnection::DRIVER,
104+
FlightSqlConnection::DSN,
110105
FlightSqlConnection::HOST,
111106
FlightSqlConnection::PORT,
112107
FlightSqlConnection::USER,
@@ -160,14 +155,14 @@ void FlightSqlConnection::Connect(const ConnPropertyMap& properties,
160155
auto flight_ssl_configs = LoadFlightSslConfigs(properties);
161156

162157
Location location = BuildLocation(properties, missing_attr, flight_ssl_configs);
163-
FlightClientOptions client_options =
158+
client_options_ =
164159
BuildFlightClientOptions(properties, missing_attr, flight_ssl_configs);
165160

166161
const std::shared_ptr<ClientMiddlewareFactory>& cookie_factory = GetCookieFactory();
167-
client_options.middleware.push_back(cookie_factory);
162+
client_options_.middleware.push_back(cookie_factory);
168163

169164
std::unique_ptr<FlightClient> flight_client;
170-
ThrowIfNotOK(FlightClient::Connect(location, client_options).Value(&flight_client));
165+
ThrowIfNotOK(FlightClient::Connect(location, client_options_).Value(&flight_client));
171166
PopulateMetadataSettings(properties);
172167
PopulateCallOptions(properties);
173168

@@ -370,7 +365,7 @@ void FlightSqlConnection::Close() {
370365

371366
std::shared_ptr<Statement> FlightSqlConnection::CreateStatement() {
372367
return std::shared_ptr<Statement>(new FlightSqlStatement(
373-
diagnostics_, *sql_client_, call_options_, metadata_settings_));
368+
diagnostics_, *sql_client_, client_options_, call_options_, metadata_settings_));
374369
}
375370

376371
bool FlightSqlConnection::SetAttribute(Connection::AttributeId attribute,
@@ -416,7 +411,7 @@ FlightSqlConnection::FlightSqlConnection(OdbcVersion odbc_version,
416411
const std::string& driver_version)
417412
: diagnostics_("Apache Arrow", "Flight SQL", odbc_version),
418413
odbc_version_(odbc_version),
419-
info_(call_options_, sql_client_, driver_version),
414+
info_(client_options_, call_options_, sql_client_, driver_version),
420415
closed_(true) {
421416
attribute_[CONNECTION_DEAD] = static_cast<uint32_t>(SQL_TRUE);
422417
attribute_[LOGIN_TIMEOUT] = static_cast<uint32_t>(0);

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ GetTablesReader::GetTablesReader(std::shared_ptr<RecordBatch> record_batch)
3636

3737
bool GetTablesReader::Next() { return ++current_row_ < record_batch_->num_rows(); }
3838

39-
optional<std::string> GetTablesReader::GetCatalogName() {
39+
std::optional<std::string> GetTablesReader::GetCatalogName() {
4040
const auto& array = checked_pointer_cast<StringArray>(record_batch_->column(0));
4141

4242
if (array->IsNull(current_row_)) return nullopt;
4343

4444
return array->GetString(current_row_);
4545
}
4646

47-
optional<std::string> GetTablesReader::GetDbSchemaName() {
47+
std::optional<std::string> GetTablesReader::GetDbSchemaName() {
4848
const auto& array = checked_pointer_cast<StringArray>(record_batch_->column(1));
4949

5050
if (array->IsNull(current_row_)) return nullopt;

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_tables_reader.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020

2121
namespace arrow::flight::sql::odbc {
2222

23-
using std::optional;
24-
2523
class GetTablesReader {
2624
private:
2725
std::shared_ptr<RecordBatch> record_batch_;
@@ -32,9 +30,9 @@ class GetTablesReader {
3230

3331
bool Next();
3432

35-
optional<std::string> GetCatalogName();
33+
std::optional<std::string> GetCatalogName();
3634

37-
optional<std::string> GetDbSchemaName();
35+
std::optional<std::string> GetDbSchemaName();
3836

3937
std::string GetTableName();
4038

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace arrow::flight::sql::odbc {
2828

2929
using arrow::internal::checked_pointer_cast;
3030
using std::nullopt;
31+
using std::optional;
3132

3233
GetTypeInfoReader::GetTypeInfoReader(std::shared_ptr<RecordBatch> record_batch)
3334
: record_batch_(std::move(record_batch)), current_row_(-1) {}

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_get_type_info_reader.h

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020

2121
namespace arrow::flight::sql::odbc {
2222

23-
using std::optional;
24-
2523
class GetTypeInfoReader {
2624
private:
2725
std::shared_ptr<RecordBatch> record_batch_;
@@ -36,39 +34,39 @@ class GetTypeInfoReader {
3634

3735
int32_t GetDataType();
3836

39-
optional<int32_t> GetColumnSize();
37+
std::optional<int32_t> GetColumnSize();
4038

41-
optional<std::string> GetLiteralPrefix();
39+
std::optional<std::string> GetLiteralPrefix();
4240

43-
optional<std::string> GetLiteralSuffix();
41+
std::optional<std::string> GetLiteralSuffix();
4442

45-
optional<std::vector<std::string>> GetCreateParams();
43+
std::optional<std::vector<std::string>> GetCreateParams();
4644

4745
int32_t GetNullable();
4846

4947
bool GetCaseSensitive();
5048

5149
int32_t GetSearchable();
5250

53-
optional<bool> GetUnsignedAttribute();
51+
std::optional<bool> GetUnsignedAttribute();
5452

5553
bool GetFixedPrecScale();
5654

57-
optional<bool> GetAutoIncrement();
55+
std::optional<bool> GetAutoIncrement();
5856

59-
optional<std::string> GetLocalTypeName();
57+
std::optional<std::string> GetLocalTypeName();
6058

61-
optional<int32_t> GetMinimumScale();
59+
std::optional<int32_t> GetMinimumScale();
6260

63-
optional<int32_t> GetMaximumScale();
61+
std::optional<int32_t> GetMaximumScale();
6462

6563
int32_t GetSqlDataType();
6664

67-
optional<int32_t> GetDatetimeSubcode();
65+
std::optional<int32_t> GetDatetimeSubcode();
6866

69-
optional<int32_t> GetNumPrecRadix();
67+
std::optional<int32_t> GetNumPrecRadix();
7068

71-
optional<int32_t> GetIntervalPrecision();
69+
std::optional<int32_t> GetIntervalPrecision();
7270
};
7371

7472
} // namespace arrow::flight::sql::odbc

0 commit comments

Comments
 (0)