Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cpp/src/arrow/flight/sql/odbc/odbc_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ SQLRETURN SQLSetConnectAttr(SQLHDBC conn, SQLINTEGER attr, SQLPOINTER value_ptr,
// entries in the properties.
void LoadPropertiesFromDSN(const std::string& dsn,
Connection::ConnPropertyMap& properties) {
arrow::flight::sql::odbc::config::Configuration config;
config::Configuration config;
config.LoadDsn(dsn);
Connection::ConnPropertyMap dsn_properties = config.GetProperties();
for (auto& [key, value] : dsn_properties) {
Expand Down Expand Up @@ -796,7 +796,7 @@ SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND window_handle,
// Load the DSN window according to driver_completion
if (driver_completion == SQL_DRIVER_PROMPT) {
// Load DSN window before first attempt to connect
arrow::flight::sql::odbc::config::Configuration config;
config::Configuration config;
if (!DisplayConnectionWindow(window_handle, config, properties)) {
return static_cast<SQLRETURN>(SQL_NO_DATA);
}
Expand All @@ -809,7 +809,7 @@ SQLRETURN SQLDriverConnect(SQLHDBC conn, SQLHWND window_handle,
// If first connection fails due to missing attributes, load
// the DSN window and try to connect again
if (!missing_properties.empty()) {
arrow::flight::sql::odbc::config::Configuration config;
config::Configuration config;
missing_properties.clear();

if (!DisplayConnectionWindow(window_handle, config, properties)) {
Expand Down Expand Up @@ -855,7 +855,7 @@ SQLRETURN SQLConnect(SQLHDBC conn, SQLWCHAR* dsn_name, SQLSMALLINT dsn_name_len,
ODBCConnection* connection = reinterpret_cast<ODBCConnection*>(conn);
std::string dsn = SqlWcharToString(dsn_name, dsn_name_len);

Configuration config;
config::Configuration config;
config.LoadDsn(dsn);

if (user_name) {
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/arrow/flight/sql/odbc/odbc_impl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,11 @@ add_arrow_test(odbc_spi_impl_test
accessors/time_array_accessor_test.cc
accessors/timestamp_array_accessor_test.cc
flight_sql_connection_test.cc
flight_sql_stream_chunk_buffer_test.cc
parse_table_types_test.cc
json_converter_test.cc
record_batch_transformer_test.cc
util_test.cc
EXTRA_LINK_LIBS
arrow_odbc_spi_impl)
arrow_odbc_spi_impl
arrow_flight_testing_shared)
34 changes: 18 additions & 16 deletions cpp/src/arrow/flight/sql/odbc/odbc_impl/attribute_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@
// GH-48083 TODO: replace `namespace ODBC` with `namespace arrow::flight::sql::odbc`
namespace ODBC {

using arrow::flight::sql::odbc::Diagnostics;
using arrow::flight::sql::odbc::DriverException;
using arrow::flight::sql::odbc::WcsToUtf8;

template <typename T, typename O>
inline void GetAttribute(T attribute_value, SQLPOINTER output, O output_size,
O* output_len_ptr) {
Expand Down Expand Up @@ -70,7 +66,7 @@ inline SQLRETURN GetAttributeUTF8(std::string_view attribute_value, SQLPOINTER o
template <typename O>
inline SQLRETURN GetAttributeUTF8(std::string_view attribute_value, SQLPOINTER output,
O output_size, O* output_len_ptr,
Diagnostics& diagnostics) {
arrow::flight::sql::odbc::Diagnostics& diagnostics) {
SQLRETURN result =
GetAttributeUTF8(attribute_value, output, output_size, output_len_ptr);
if (SQL_SUCCESS_WITH_INFO == result) {
Expand All @@ -85,10 +81,11 @@ inline SQLRETURN GetAttributeSQLWCHAR(std::string_view attribute_value,
O output_size, O* output_len_ptr) {
size_t length = ConvertToSqlWChar(
attribute_value, reinterpret_cast<SQLWCHAR*>(output),
is_length_in_bytes ? output_size : output_size * GetSqlWCharSize());
is_length_in_bytes ? output_size
: output_size * arrow::flight::sql::odbc::GetSqlWCharSize());

if (!is_length_in_bytes) {
length = length / GetSqlWCharSize();
length = length / arrow::flight::sql::odbc::GetSqlWCharSize();
}

if (output_len_ptr) {
Expand All @@ -97,17 +94,19 @@ inline SQLRETURN GetAttributeSQLWCHAR(std::string_view attribute_value,

if (output &&
output_size <
static_cast<O>(length + (is_length_in_bytes ? GetSqlWCharSize() : 1))) {
static_cast<O>(length + (is_length_in_bytes
? arrow::flight::sql::odbc::GetSqlWCharSize()
: 1))) {
return SQL_SUCCESS_WITH_INFO;
}
return SQL_SUCCESS;
}

template <typename O>
inline SQLRETURN GetAttributeSQLWCHAR(std::string_view attribute_value,
bool is_length_in_bytes, SQLPOINTER output,
O output_size, O* output_len_ptr,
Diagnostics& diagnostics) {
inline SQLRETURN GetAttributeSQLWCHAR(
const std::string& attribute_value, bool is_length_in_bytes, SQLPOINTER output,
O output_size, O* output_len_ptr,
arrow::flight::sql::odbc::Diagnostics& diagnostics) {
SQLRETURN result = GetAttributeSQLWCHAR(attribute_value, is_length_in_bytes, output,
output_size, output_len_ptr);
if (SQL_SUCCESS_WITH_INFO == result) {
Expand All @@ -120,7 +119,7 @@ template <typename O>
inline SQLRETURN GetStringAttribute(bool is_unicode, std::string_view attribute_value,
bool is_length_in_bytes, SQLPOINTER output,
O output_size, O* output_len_ptr,
Diagnostics& diagnostics) {
arrow::flight::sql::odbc::Diagnostics& diagnostics) {
SQLRETURN result = SQL_SUCCESS;
if (is_unicode) {
result = GetAttributeSQLWCHAR(attribute_value, is_length_in_bytes, output,
Expand Down Expand Up @@ -158,17 +157,20 @@ inline void SetAttributeSQLWCHAR(SQLPOINTER new_value, SQLINTEGER input_length_i
std::string& attribute_to_write) {
thread_local std::vector<uint8_t> utf8_str;
if (input_length_in_bytes == SQL_NTS) {
WcsToUtf8(new_value, &utf8_str);
arrow::flight::sql::odbc::WcsToUtf8(new_value, &utf8_str);
} else {
WcsToUtf8(new_value, input_length_in_bytes / GetSqlWCharSize(), &utf8_str);
arrow::flight::sql::odbc::WcsToUtf8(
new_value, input_length_in_bytes / arrow::flight::sql::odbc::GetSqlWCharSize(),
&utf8_str);
}
attribute_to_write.assign((char*)utf8_str.data());
}

template <typename T>
void CheckIfAttributeIsSetToOnlyValidValue(SQLPOINTER value, T allowed_value) {
if (static_cast<T>(reinterpret_cast<SQLULEN>(value)) != allowed_value) {
throw DriverException("Optional feature not implemented", "HYC00");
throw arrow::flight::sql::odbc::DriverException("Optional feature not implemented",
"HYC00");
}
}
} // namespace ODBC
27 changes: 13 additions & 14 deletions cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,12 @@

namespace ODBC {

using arrow::flight::sql::odbc::DriverException;
using arrow::flight::sql::odbc::GetSqlWCharSize;
using arrow::flight::sql::odbc::Utf8ToWcs;
using arrow::flight::sql::odbc::WcsToUtf8;

// Return the number of bytes required for the conversion.
template <typename CHAR_TYPE>
inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer,
SQLLEN buffer_size_in_bytes) {
thread_local std::vector<uint8_t> wstr;
Utf8ToWcs<CHAR_TYPE>(str.data(), str.size(), &wstr);
arrow::flight::sql::odbc::Utf8ToWcs<CHAR_TYPE>(str.data(), str.size(), &wstr);
SQLLEN value_length_in_bytes = wstr.size();

if (buffer) {
Expand All @@ -51,11 +46,14 @@ inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer,

// Write a NUL terminator
if (buffer_size_in_bytes >=
value_length_in_bytes + static_cast<SQLLEN>(GetSqlWCharSize())) {
reinterpret_cast<CHAR_TYPE*>(buffer)[value_length_in_bytes / GetSqlWCharSize()] =
value_length_in_bytes +
static_cast<SQLLEN>(arrow::flight::sql::odbc::GetSqlWCharSize())) {
reinterpret_cast<CHAR_TYPE*>(
buffer)[value_length_in_bytes / arrow::flight::sql::odbc::GetSqlWCharSize()] =
'\0';
} else {
SQLLEN num_chars_written = buffer_size_in_bytes / GetSqlWCharSize();
SQLLEN num_chars_written =
buffer_size_in_bytes / arrow::flight::sql::odbc::GetSqlWCharSize();
// If we failed to even write one char, the buffer is too small to hold a
// NUL-terminator.
if (num_chars_written > 0) {
Expand All @@ -68,15 +66,16 @@ inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer,

inline size_t ConvertToSqlWChar(std::string_view str, SQLWCHAR* buffer,
SQLLEN buffer_size_in_bytes) {
switch (GetSqlWCharSize()) {
switch (arrow::flight::sql::odbc::GetSqlWCharSize()) {
case sizeof(char16_t):
return ConvertToSqlWChar<char16_t>(str, buffer, buffer_size_in_bytes);
case sizeof(char32_t):
return ConvertToSqlWChar<char32_t>(str, buffer, buffer_size_in_bytes);
default:
assert(false);
throw DriverException("Encoding is unsupported, SQLWCHAR size: " +
std::to_string(GetSqlWCharSize()));
throw arrow::flight::sql::odbc::DriverException(
"Encoding is unsupported, SQLWCHAR size: " +
std::to_string(arrow::flight::sql::odbc::GetSqlWCharSize()));
}
}

Expand All @@ -92,9 +91,9 @@ inline std::string SqlWcharToString(SQLWCHAR* wchar_msg, SQLINTEGER msg_len = SQ
thread_local std::vector<uint8_t> utf8_str;

if (msg_len == SQL_NTS) {
WcsToUtf8((void*)wchar_msg, &utf8_str);
arrow::flight::sql::odbc::WcsToUtf8((void*)wchar_msg, &utf8_str);
} else {
WcsToUtf8((void*)wchar_msg, msg_len, &utf8_str);
arrow::flight::sql::odbc::WcsToUtf8((void*)wchar_msg, msg_len, &utf8_str);
}

return std::string(utf8_str.begin(), utf8_str.end());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class NoOpClientAuthHandler : public ClientAuthHandler {
NoOpClientAuthHandler() {}

Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override {
// Write a blank string. The server should ignore this and just accept any Handshake
// request.
// The server should ignore this and just accept any Handshake
// request. Some servers do not allow authentication with no handshakes.
return outgoing->Write(std::string());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,9 @@ inline std::string GetCerts() { return ""; }

#endif

// Case insensitive comparator that takes string_view
struct CaseInsensitiveComparatorStrView {
bool operator()(std::string_view s1, std::string_view s2) const {
return boost::lexicographical_compare(s1, s2, boost::is_iless());
}
};

const std::set<std::string_view, CaseInsensitiveComparatorStrView> BUILT_IN_PROPERTIES = {
const std::set<std::string_view, CaseInsensitiveComparator> BUILT_IN_PROPERTIES = {
FlightSqlConnection::DRIVER,
FlightSqlConnection::DSN,
FlightSqlConnection::HOST,
FlightSqlConnection::PORT,
FlightSqlConnection::USER,
Expand Down Expand Up @@ -160,14 +155,14 @@ void FlightSqlConnection::Connect(const ConnPropertyMap& properties,
auto flight_ssl_configs = LoadFlightSslConfigs(properties);

Location location = BuildLocation(properties, missing_attr, flight_ssl_configs);
FlightClientOptions client_options =
client_options_ =
BuildFlightClientOptions(properties, missing_attr, flight_ssl_configs);

const std::shared_ptr<ClientMiddlewareFactory>& cookie_factory = GetCookieFactory();
client_options.middleware.push_back(cookie_factory);
client_options_.middleware.push_back(cookie_factory);

std::unique_ptr<FlightClient> flight_client;
ThrowIfNotOK(FlightClient::Connect(location, client_options).Value(&flight_client));
ThrowIfNotOK(FlightClient::Connect(location, client_options_).Value(&flight_client));
PopulateMetadataSettings(properties);
PopulateCallOptions(properties);

Expand Down Expand Up @@ -370,7 +365,7 @@ void FlightSqlConnection::Close() {

std::shared_ptr<Statement> FlightSqlConnection::CreateStatement() {
return std::shared_ptr<Statement>(new FlightSqlStatement(
diagnostics_, *sql_client_, call_options_, metadata_settings_));
diagnostics_, *sql_client_, client_options_, call_options_, metadata_settings_));
}

bool FlightSqlConnection::SetAttribute(Connection::AttributeId attribute,
Expand Down Expand Up @@ -416,7 +411,7 @@ FlightSqlConnection::FlightSqlConnection(OdbcVersion odbc_version,
const std::string& driver_version)
: diagnostics_("Apache Arrow", "Flight SQL", odbc_version),
odbc_version_(odbc_version),
info_(call_options_, sql_client_, driver_version),
info_(client_options_, call_options_, sql_client_, driver_version),
closed_(true) {
attribute_[CONNECTION_DEAD] = static_cast<uint32_t>(SQL_TRUE);
attribute_[LOGIN_TIMEOUT] = static_cast<uint32_t>(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ GetTablesReader::GetTablesReader(std::shared_ptr<RecordBatch> record_batch)

bool GetTablesReader::Next() { return ++current_row_ < record_batch_->num_rows(); }

optional<std::string> GetTablesReader::GetCatalogName() {
std::optional<std::string> GetTablesReader::GetCatalogName() {
const auto& array = checked_pointer_cast<StringArray>(record_batch_->column(0));

if (array->IsNull(current_row_)) return nullopt;

return array->GetString(current_row_);
}

optional<std::string> GetTablesReader::GetDbSchemaName() {
std::optional<std::string> GetTablesReader::GetDbSchemaName() {
const auto& array = checked_pointer_cast<StringArray>(record_batch_->column(1));

if (array->IsNull(current_row_)) return nullopt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

namespace arrow::flight::sql::odbc {

using std::optional;

class GetTablesReader {
private:
std::shared_ptr<RecordBatch> record_batch_;
Expand All @@ -32,9 +30,9 @@ class GetTablesReader {

bool Next();

optional<std::string> GetCatalogName();
std::optional<std::string> GetCatalogName();

optional<std::string> GetDbSchemaName();
std::optional<std::string> GetDbSchemaName();

std::string GetTableName();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace arrow::flight::sql::odbc {

using arrow::internal::checked_pointer_cast;
using std::nullopt;
using std::optional;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not necessarily a blocker here but generally I think we avoid using types from std.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


GetTypeInfoReader::GetTypeInfoReader(std::shared_ptr<RecordBatch> record_batch)
: record_batch_(std::move(record_batch)), current_row_(-1) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

namespace arrow::flight::sql::odbc {

using std::optional;

class GetTypeInfoReader {
private:
std::shared_ptr<RecordBatch> record_batch_;
Expand All @@ -36,39 +34,39 @@ class GetTypeInfoReader {

int32_t GetDataType();

optional<int32_t> GetColumnSize();
std::optional<int32_t> GetColumnSize();

optional<std::string> GetLiteralPrefix();
std::optional<std::string> GetLiteralPrefix();

optional<std::string> GetLiteralSuffix();
std::optional<std::string> GetLiteralSuffix();

optional<std::vector<std::string>> GetCreateParams();
std::optional<std::vector<std::string>> GetCreateParams();

int32_t GetNullable();

bool GetCaseSensitive();

int32_t GetSearchable();

optional<bool> GetUnsignedAttribute();
std::optional<bool> GetUnsignedAttribute();

bool GetFixedPrecScale();

optional<bool> GetAutoIncrement();
std::optional<bool> GetAutoIncrement();

optional<std::string> GetLocalTypeName();
std::optional<std::string> GetLocalTypeName();

optional<int32_t> GetMinimumScale();
std::optional<int32_t> GetMinimumScale();

optional<int32_t> GetMaximumScale();
std::optional<int32_t> GetMaximumScale();

int32_t GetSqlDataType();

optional<int32_t> GetDatetimeSubcode();
std::optional<int32_t> GetDatetimeSubcode();

optional<int32_t> GetNumPrecRadix();
std::optional<int32_t> GetNumPrecRadix();

optional<int32_t> GetIntervalPrecision();
std::optional<int32_t> GetIntervalPrecision();
};

} // namespace arrow::flight::sql::odbc
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@
namespace arrow::flight::sql::odbc {

FlightSqlResultSet::FlightSqlResultSet(
FlightSqlClient& flight_sql_client, const FlightCallOptions& call_options,
const std::shared_ptr<FlightInfo>& flight_info,
FlightSqlClient& flight_sql_client, const FlightClientOptions& client_options,
const FlightCallOptions& call_options, const std::shared_ptr<FlightInfo>& flight_info,
const std::shared_ptr<RecordBatchTransformer>& transformer, Diagnostics& diagnostics,
const MetadataSettings& metadata_settings)
: metadata_settings_(metadata_settings),
chunk_buffer_(flight_sql_client, call_options, flight_info,
chunk_buffer_(flight_sql_client, client_options, call_options, flight_info,
metadata_settings_.chunk_buffer_capacity),
transformer_(transformer),
metadata_(transformer
Expand Down
Loading
Loading