diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 4891e09b..20ada727 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -6,6 +6,7 @@ add_library(
   postgres_ext_library OBJECT
   postgres_attach.cpp
   postgres_binary_copy.cpp
+  postgres_binary_reader.cpp
   postgres_connection.cpp
   postgres_copy_from.cpp
   postgres_copy_to.cpp
@@ -15,6 +16,7 @@ add_library(
   postgres_query.cpp
   postgres_scanner.cpp
   postgres_storage.cpp
+  postgres_text_reader.cpp
   postgres_utils.cpp)
 set(ALL_OBJECT_FILES
     ${ALL_OBJECT_FILES} $<TARGET_OBJECTS:postgres_ext_library>
diff --git a/src/include/postgres_binary_reader.hpp b/src/include/postgres_binary_reader.hpp
index 1988662a..aadc8d0a 100644
--- a/src/include/postgres_binary_reader.hpp
+++ b/src/include/postgres_binary_reader.hpp
@@ -8,88 +8,30 @@
 
 #pragma once
 
-#include "duckdb.hpp"
-#include "duckdb/common/types/interval.hpp"
-#include "postgres_conversion.hpp"
+#include "postgres_result_reader.hpp"
+#include "postgres_connection.hpp"
 
 namespace duckdb {
 
-struct PostgresBinaryReader {
-	explicit PostgresBinaryReader(PostgresConnection &con_p) : con(con_p) {
-	}
-	~PostgresBinaryReader() {
-		Reset();
-	}
-	PostgresConnection &GetConn() {
-		return con;
-	}
+struct PostgresBinaryReader : public PostgresResultReader {
+	explicit PostgresBinaryReader(PostgresConnection &con, const vector<column_t> &column_ids,
+	                              const PostgresBindData &bind_data);
+	~PostgresBinaryReader() override;
 
-	bool Next() {
-		Reset();
-		char *out_buffer;
-		int len = PQgetCopyData(con.GetConn(), &out_buffer, 0);
-		auto new_buffer = data_ptr_cast(out_buffer);
-
-		// len -1 signals end
-		if (len == -1) {
-			auto final_result = PQgetResult(con.GetConn());
-			if (!final_result || PQresultStatus(final_result) != PGRES_COMMAND_OK) {
-				throw IOException("Failed to fetch header for COPY: %s", string(PQresultErrorMessage(final_result)));
-			}
-			return false;
-		}
-
-		// len -2 is error
-		// we expect at least 2 bytes in each message for the tuple count
-		if (!new_buffer || len < sizeof(int16_t)) {
-			throw IOException("Unable to read binary COPY data from Postgres: %s",
-			                  string(PQerrorMessage(con.GetConn())));
-		}
-		buffer = new_buffer;
-		buffer_ptr = buffer;
-		end = buffer + len;
-		return true;
-	}
+public:
+	void BeginCopy(const string &sql) override;
+	PostgresReadResult Read(DataChunk &result) override;
 
-	void CheckResult() {
-		auto result = PQgetResult(con.GetConn());
-		if (!result || PQresultStatus(result) != PGRES_COMMAND_OK) {
-			throw std::runtime_error("Failed to execute COPY: " + string(PQresultErrorMessage(result)));
-		}
-	}
+protected:
+	bool Next();
+	void CheckResult();
 
-	void Reset() {
-		if (buffer) {
-			PQfreemem(buffer);
-		}
-		buffer = nullptr;
-		buffer_ptr = nullptr;
-		end = nullptr;
-	}
-	bool Ready() {
-		return buffer_ptr != nullptr;
-	}
+	void Reset();
+	bool Ready();
 
-	void CheckHeader() {
-		auto magic_len = PostgresConversion::COPY_HEADER_LENGTH;
-		auto flags_len = 8;
-		auto header_len = magic_len + flags_len;
+	void CheckHeader();
 
-		if (!buffer_ptr) {
-			throw IOException("buffer_ptr not set in CheckHeader");
-		}
-		if (buffer_ptr + header_len >= end) {
-			throw IOException("Unable to read binary COPY data from Postgres, invalid header");
-		}
-		if (memcmp(buffer_ptr, PostgresConversion::COPY_HEADER, magic_len) != 0) {
-			throw IOException("Expected Postgres binary COPY header, got something else");
-		}
-		buffer_ptr += header_len;
-		// as far as i can tell the "Flags field" and the "Header
-		// extension area length" do not contain anything interesting
-	}
-
-public:
+protected:
 	template <class T>
 	inline T ReadIntegerUnchecked() {
 		T val = Load<T>(buffer_ptr);
@@ -192,21 +134,7 @@ struct PostgresBinaryReader {
 		return result;
 	}
 
-	PostgresDecimalConfig ReadDecimalConfig() {
-		PostgresDecimalConfig config;
-		config.ndigits = ReadInteger<uint16_t>();
-		config.weight = ReadInteger<int16_t>();
-		auto sign = ReadInteger<uint16_t>();
-
-		if (!(sign == NUMERIC_POS || sign == NUMERIC_NAN || sign == NUMERIC_PINF || sign == NUMERIC_NINF ||
-		      sign == NUMERIC_NEG)) {
-			throw NotImplementedException("Postgres numeric NA/Inf");
-		}
-		config.is_negative = sign == NUMERIC_NEG;
-		config.scale = ReadInteger<uint16_t>();
-
-		return config;
-	}
+	PostgresDecimalConfig ReadDecimalConfig();
 
 	template <class T, class OP = DecimalConversionInteger>
 	T ReadDecimal() {
@@ -270,307 +198,17 @@ struct PostgresBinaryReader {
 		return (config.is_negative ? -base_res : base_res);
 	}
 
-	void ReadGeometry(const LogicalType &type, const PostgresType &postgres_type, Vector &out_vec,
-	                  idx_t output_offset) {
-		idx_t element_count = 0;
-		switch (postgres_type.info) {
-		case PostgresTypeAnnotation::GEOM_LINE:
-		case PostgresTypeAnnotation::GEOM_CIRCLE:
-			element_count = 3;
-			break;
-		case PostgresTypeAnnotation::GEOM_LINE_SEGMENT:
-		case PostgresTypeAnnotation::GEOM_BOX:
-			element_count = 4;
-			break;
-		case PostgresTypeAnnotation::GEOM_PATH: {
-			// variable number of elements
-			auto path_is_closed = ReadBoolean(); // ignored for now
-			element_count = 2 * ReadInteger<uint32_t>();
-			break;
-		}
-		case PostgresTypeAnnotation::GEOM_POLYGON:
-			// variable number of elements
-			element_count = 2 * ReadInteger<uint32_t>();
-			break;
-		default:
-			throw InternalException("Unsupported type for ReadGeometry");
-		}
-		auto list_entries = FlatVector::GetData<list_entry_t>(out_vec);
-		auto child_offset = ListVector::GetListSize(out_vec);
-		ListVector::Reserve(out_vec, child_offset + element_count);
-		list_entries[output_offset].offset = child_offset;
-		list_entries[output_offset].length = element_count;
-		auto &child_vector = ListVector::GetEntry(out_vec);
-		auto child_data = FlatVector::GetData<double>(child_vector);
-		for (idx_t i = 0; i < element_count; i++) {
-			child_data[child_offset + i] = ReadDouble();
-		}
-		ListVector::SetListSize(out_vec, child_offset + element_count);
-	}
+	void ReadGeometry(const LogicalType &type, const PostgresType &postgres_type, Vector &out_vec, idx_t output_offset);
 
 	void ReadArray(const LogicalType &type, const PostgresType &postgres_type, Vector &out_vec, idx_t output_offset,
-	               uint32_t current_count, uint32_t dimensions[], uint32_t ndim) {
-		auto list_entries = FlatVector::GetData<list_entry_t>(out_vec);
-		auto child_offset = ListVector::GetListSize(out_vec);
-		auto child_dimension = dimensions[0];
-		auto child_count = current_count * child_dimension;
-		// set up the list entries for this dimension
-		auto current_offset = child_offset;
-		for (idx_t c = 0; c < current_count; c++) {
-			auto &list_entry = list_entries[output_offset + c];
-			list_entry.offset = current_offset;
-			list_entry.length = child_dimension;
-			current_offset += child_dimension;
-		}
-		ListVector::Reserve(out_vec, child_offset + child_count);
-		auto &child_vec = ListVector::GetEntry(out_vec);
-		auto &child_type = ListType::GetChildType(type);
-		auto &child_pg_type = postgres_type.children[0];
-		if (ndim > 1) {
-			// there are more dimensions to read - recurse into child list
-			ReadArray(child_type, child_pg_type, child_vec, child_offset, child_count, dimensions + 1, ndim - 1);
-		} else {
-			// this is the last level - read the actual values
-			for (idx_t child_idx = 0; child_idx < child_count; child_idx++) {
-				ReadValue(child_type, child_pg_type, child_vec, child_offset + child_idx);
-			}
-		}
-		ListVector::SetListSize(out_vec, child_offset + child_count);
-	}
+	               uint32_t current_count, uint32_t dimensions[], uint32_t ndim);
 
-	void ReadValue(const LogicalType &type, const PostgresType &postgres_type, Vector &out_vec, idx_t output_offset) {
-		auto value_len = ReadInteger<int32_t>();
-		if (value_len == -1) { // NULL
-			FlatVector::SetNull(out_vec, output_offset, true);
-			return;
-		}
-		switch (type.id()) {
-		case LogicalTypeId::SMALLINT:
-			D_ASSERT(value_len == sizeof(int16_t));
-			FlatVector::GetData<int16_t>(out_vec)[output_offset] = ReadInteger<int16_t>();
-			break;
-		case LogicalTypeId::INTEGER:
-			D_ASSERT(value_len == sizeof(int32_t));
-			FlatVector::GetData<int32_t>(out_vec)[output_offset] = ReadInteger<int32_t>();
-			break;
-		case LogicalTypeId::UINTEGER:
-			D_ASSERT(value_len == sizeof(uint32_t));
-			FlatVector::GetData<uint32_t>(out_vec)[output_offset] = ReadInteger<uint32_t>();
-			break;
-		case LogicalTypeId::BIGINT:
-			if (postgres_type.info == PostgresTypeAnnotation::CTID) {
-				D_ASSERT(value_len == 6);
-				int64_t page_index = ReadInteger<int32_t>();
-				int64_t row_in_page = ReadInteger<int16_t>();
-				FlatVector::GetData<int64_t>(out_vec)[output_offset] = (page_index << 16LL) + row_in_page;
-				return;
-			}
-			D_ASSERT(value_len == sizeof(int64_t));
-			FlatVector::GetData<int64_t>(out_vec)[output_offset] = ReadInteger<int64_t>();
-			break;
-		case LogicalTypeId::FLOAT:
-			D_ASSERT(value_len == sizeof(float));
-			FlatVector::GetData<float>(out_vec)[output_offset] = ReadFloat();
-			break;
-		case LogicalTypeId::DOUBLE: {
-			// this was an unbounded decimal, read params from value and cast to double
-			if (postgres_type.info == PostgresTypeAnnotation::NUMERIC_AS_DOUBLE) {
-				FlatVector::GetData<double>(out_vec)[output_offset] = ReadDecimal<double, DecimalConversionDouble>();
-				break;
-			}
-			D_ASSERT(value_len == sizeof(double));
-			FlatVector::GetData<double>(out_vec)[output_offset] = ReadDouble();
-			break;
-		}
-
-		case LogicalTypeId::BLOB:
-		case LogicalTypeId::VARCHAR: {
-			if (postgres_type.info == PostgresTypeAnnotation::JSONB) {
-				auto version = ReadInteger<uint8_t>();
-				value_len--;
-				if (version != 1) {
-					throw NotImplementedException("JSONB version number mismatch, expected 1, got %d", version);
-				}
-			}
-			auto str = ReadString(value_len);
-			if (postgres_type.info == PostgresTypeAnnotation::FIXED_LENGTH_CHAR) {
-				// CHAR column - remove trailing spaces
-				while (value_len > 0 && str[value_len - 1] == ' ') {
-					value_len--;
-				}
-			}
-			FlatVector::GetData<string_t>(out_vec)[output_offset] =
-			    StringVector::AddStringOrBlob(out_vec, str, value_len);
-			break;
-		}
-		case LogicalTypeId::BOOLEAN:
-			D_ASSERT(value_len == sizeof(bool));
-			FlatVector::GetData<bool>(out_vec)[output_offset] = ReadBoolean();
-			break;
-		case LogicalTypeId::DECIMAL: {
-			if (value_len < sizeof(uint16_t) * 4) {
-				throw InvalidInputException("Need at least 8 bytes to read a Postgres decimal. Got %d", value_len);
-			}
-			switch (type.InternalType()) {
-			case PhysicalType::INT16:
-				FlatVector::GetData<int16_t>(out_vec)[output_offset] = ReadDecimal<int16_t>();
-				break;
-			case PhysicalType::INT32:
-				FlatVector::GetData<int32_t>(out_vec)[output_offset] = ReadDecimal<int32_t>();
-				break;
-			case PhysicalType::INT64:
-				FlatVector::GetData<int64_t>(out_vec)[output_offset] = ReadDecimal<int64_t>();
-				break;
-			case PhysicalType::INT128:
-				FlatVector::GetData<hugeint_t>(out_vec)[output_offset] =
-				    ReadDecimal<hugeint_t, DecimalConversionHugeint>();
-				break;
-			default:
-				throw InvalidInputException("Unsupported decimal storage type");
-			}
-			break;
-		}
-
-		case LogicalTypeId::DATE: {
-			D_ASSERT(value_len == sizeof(int32_t));
-			auto out_ptr = FlatVector::GetData<date_t>(out_vec);
-			out_ptr[output_offset] = ReadDate();
-			break;
-		}
-		case LogicalTypeId::TIME: {
-			D_ASSERT(value_len == sizeof(int64_t));
-			FlatVector::GetData<dtime_t>(out_vec)[output_offset] = ReadTime();
-			break;
-		}
-		case LogicalTypeId::TIME_TZ: {
-			D_ASSERT(value_len == sizeof(int64_t) + sizeof(int32_t));
-			FlatVector::GetData<dtime_tz_t>(out_vec)[output_offset] = ReadTimeTZ();
-			break;
-		}
-		case LogicalTypeId::TIMESTAMP_TZ:
-		case LogicalTypeId::TIMESTAMP: {
-			D_ASSERT(value_len == sizeof(int64_t));
-			FlatVector::GetData<timestamp_t>(out_vec)[output_offset] = ReadTimestamp();
-			break;
-		}
-		case LogicalTypeId::ENUM: {
-			auto enum_val = string(ReadString(value_len), value_len);
-			auto offset = EnumType::GetPos(type, enum_val);
-			if (offset < 0) {
-				throw IOException("Could not map ENUM value %s", enum_val);
-			}
-			switch (type.InternalType()) {
-			case PhysicalType::UINT8:
-				FlatVector::GetData<uint8_t>(out_vec)[output_offset] = (uint8_t)offset;
-				break;
-			case PhysicalType::UINT16:
-				FlatVector::GetData<uint16_t>(out_vec)[output_offset] = (uint16_t)offset;
-				break;
-
-			case PhysicalType::UINT32:
-				FlatVector::GetData<uint32_t>(out_vec)[output_offset] = (uint32_t)offset;
-				break;
-
-			default:
-				throw InternalException("ENUM can only have unsigned integers (except "
-				                        "UINT64) as physical types, got %s",
-				                        TypeIdToString(type.InternalType()));
-			}
-			break;
-		}
-		case LogicalTypeId::INTERVAL: {
-			FlatVector::GetData<interval_t>(out_vec)[output_offset] = ReadInterval();
-			break;
-		}
-		case LogicalTypeId::UUID: {
-			D_ASSERT(value_len == 2 * sizeof(int64_t));
-			FlatVector::GetData<hugeint_t>(out_vec)[output_offset] = ReadUUID();
-			break;
-		}
-		case LogicalTypeId::LIST: {
-			auto &list_entry = FlatVector::GetData<list_entry_t>(out_vec)[output_offset];
-			auto child_offset = ListVector::GetListSize(out_vec);
-
-			if (value_len < 1) {
-				list_entry.offset = child_offset;
-				list_entry.length = 0;
-				break;
-			}
-			switch (postgres_type.info) {
-			case PostgresTypeAnnotation::GEOM_LINE:
-			case PostgresTypeAnnotation::GEOM_LINE_SEGMENT:
-			case PostgresTypeAnnotation::GEOM_BOX:
-			case PostgresTypeAnnotation::GEOM_PATH:
-			case PostgresTypeAnnotation::GEOM_POLYGON:
-			case PostgresTypeAnnotation::GEOM_CIRCLE:
-				ReadGeometry(type, postgres_type, out_vec, output_offset);
-				return;
-			default:
-				break;
-			}
-			D_ASSERT(value_len >= 3 * sizeof(uint32_t));
-			auto array_dim = ReadInteger<uint32_t>();
-			auto array_has_null = ReadInteger<uint32_t>(); // whether or not the array has nulls - ignore
-			auto value_oid = ReadInteger<uint32_t>();      // value_oid - not necessary
-			if (array_dim == 0) {
-				list_entry.offset = child_offset;
-				list_entry.length = 0;
-				return;
-			}
-			// verify the number of dimensions matches the expected number of dimensions
-			idx_t expected_dimensions = 0;
-			const_reference<LogicalType> current_type = type;
-			while (current_type.get().id() == LogicalTypeId::LIST) {
-				current_type = ListType::GetChildType(current_type.get());
-				expected_dimensions++;
-			}
-			if (expected_dimensions != array_dim) {
-				throw InvalidInputException(
-				    "Expected an array with %llu dimensions, but this array has %llu dimensions. The array stored in "
-				    "Postgres does not match the schema. Postgres does not enforce that arrays match the provided "
-				    "schema but DuckDB requires this.\nSet pg_array_as_varchar=true to read the array as a varchar "
-				    "instead.",
-				    expected_dimensions, array_dim);
-			}
-			auto dimensions = unique_ptr<uint32_t[]>(new uint32_t[array_dim]);
-			for (idx_t d = 0; d < array_dim; d++) {
-				dimensions[d] = ReadInteger<uint32_t>();
-				auto lb = ReadInteger<uint32_t>(); // index lower bounds for each dimension -- we don't need them
-			}
-			// read the arrays recursively
-			ReadArray(type, postgres_type, out_vec, output_offset, 1, dimensions.get(), array_dim);
-			break;
-		}
-		case LogicalTypeId::STRUCT: {
-			auto &child_entries = StructVector::GetEntries(out_vec);
-			if (postgres_type.info == PostgresTypeAnnotation::GEOM_POINT) {
-				D_ASSERT(value_len == sizeof(double) * 2);
-				FlatVector::GetData<double>(*child_entries[0])[output_offset] = ReadDouble();
-				FlatVector::GetData<double>(*child_entries[1])[output_offset] = ReadDouble();
-				break;
-			}
-			auto entry_count = ReadInteger<uint32_t>();
-			if (entry_count != child_entries.size()) {
-				throw InternalException("Mismatch in entry count: expected %d but got %d", child_entries.size(),
-				                        entry_count);
-			}
-			for (idx_t c = 0; c < entry_count; c++) {
-				auto &child = *child_entries[c];
-				auto value_oid = ReadInteger<uint32_t>();
-				ReadValue(child.GetType(), postgres_type.children[c], child, output_offset);
-			}
-			break;
-		}
-		default:
-			throw InternalException("Unsupported Type %s", type.ToString());
-		}
-	}
+	void ReadValue(const LogicalType &type, const PostgresType &postgres_type, Vector &out_vec, idx_t output_offset);
 
 private:
 	data_ptr_t buffer = nullptr;
 	data_ptr_t buffer_ptr = nullptr;
 	data_ptr_t end = nullptr;
-	PostgresConnection &con;
 };
 
 } // namespace duckdb
diff --git a/src/include/postgres_connection.hpp b/src/include/postgres_connection.hpp
index d13b834c..48035a11 100644
--- a/src/include/postgres_connection.hpp
+++ b/src/include/postgres_connection.hpp
@@ -61,7 +61,7 @@ class PostgresConnection {
 	void CopyChunk(ClientContext &context, PostgresCopyState &state, DataChunk &chunk, DataChunk &varchar_chunk);
 	void FinishCopyTo(PostgresCopyState &state);
 
-	void BeginCopyFrom(PostgresBinaryReader &reader, const string &query);
+	void BeginCopyFrom(const string &query, ExecStatusType expected_result);
 
 	bool IsOpen();
 	void Close();
diff --git a/src/include/postgres_result.hpp b/src/include/postgres_result.hpp
index db5efac5..262adfef 100644
--- a/src/include/postgres_result.hpp
+++ b/src/include/postgres_result.hpp
@@ -28,6 +28,9 @@ class PostgresResult {
 		D_ASSERT(res);
 		return string(GetValueInternal(row, col));
 	}
+	string_t GetStringRef(idx_t row, idx_t col) {
+		return string_t(GetValueInternal(row, col), PQgetlength(res, row, col));
+	}
 
 	int32_t GetInt32(idx_t row, idx_t col) {
 		return atoi(GetValueInternal(row, col));
diff --git a/src/include/postgres_result_reader.hpp b/src/include/postgres_result_reader.hpp
new file mode 100644
index 00000000..2b452ec4
--- /dev/null
+++ b/src/include/postgres_result_reader.hpp
@@ -0,0 +1,43 @@
+//===----------------------------------------------------------------------===//
+//                         DuckDB
+//
+// postgres_result_reader.hpp
+//
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "duckdb.hpp"
+#include "duckdb/common/types/interval.hpp"
+#include "postgres_conversion.hpp"
+#include "postgres_utils.hpp"
+
+namespace duckdb {
+class PostgresConnection;
+struct PostgresBindData;
+
+enum class PostgresReadResult { FINISHED, HAVE_MORE_TUPLES };
+
+struct PostgresResultReader {
+	explicit PostgresResultReader(PostgresConnection &con_p, const vector<column_t> &column_ids,
+	                              const PostgresBindData &bind_data)
+	    : con(con_p), column_ids(column_ids), bind_data(bind_data) {
+	}
+	virtual ~PostgresResultReader() = default;
+
+	PostgresConnection &GetConn() {
+		return con;
+	}
+
+public:
+	virtual void BeginCopy(const string &sql) = 0;
+	virtual PostgresReadResult Read(DataChunk &result) = 0;
+
+protected:
+	PostgresConnection &con;
+	const vector<column_t> &column_ids;
+	const PostgresBindData &bind_data;
+};
+
+} // namespace duckdb
diff --git a/src/include/postgres_scanner.hpp b/src/include/postgres_scanner.hpp
index 2df90d32..7ff4878d 100644
--- a/src/include/postgres_scanner.hpp
+++ b/src/include/postgres_scanner.hpp
@@ -41,6 +41,7 @@ struct PostgresBindData : public FunctionData {
 	bool read_only = true;
 	bool emit_ctid = false;
 	bool use_transaction = true;
+	bool use_text_protocol = false;
 	idx_t max_threads = 1;
 
 public:
diff --git a/src/include/postgres_text_reader.hpp b/src/include/postgres_text_reader.hpp
new file mode 100644
index 00000000..46aabb22
--- /dev/null
+++ b/src/include/postgres_text_reader.hpp
@@ -0,0 +1,41 @@
+//===----------------------------------------------------------------------===//
+//                         DuckDB
+//
+// postgres_text_reader.hpp
+//
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "postgres_result_reader.hpp"
+#include "postgres_connection.hpp"
+#include "postgres_result.hpp"
+
+namespace duckdb {
+
+struct PostgresTextReader : public PostgresResultReader {
+	explicit PostgresTextReader(ClientContext &context, PostgresConnection &con, const vector<column_t> &column_ids,
+	                            const PostgresBindData &bind_data);
+	~PostgresTextReader() override;
+
+public:
+	void BeginCopy(const string &sql) override;
+	PostgresReadResult Read(DataChunk &result) override;
+
+private:
+	void Reset();
+	void ConvertVector(Vector &source, Vector &target, const PostgresType &postgres_type, idx_t count);
+	void ConvertList(Vector &source, Vector &target, const PostgresType &postgres_type, idx_t count);
+	void ConvertStruct(Vector &source, Vector &target, const PostgresType &postgres_type, idx_t count);
+	void ConvertCTID(Vector &source, Vector &target, idx_t count);
+	void ConvertBlob(Vector &source, Vector &target, idx_t count);
+
+private:
+	ClientContext &context;
+	DataChunk scan_chunk;
+	unique_ptr<PostgresResult> result;
+	idx_t row_offset = 0;
+};
+
+} // namespace duckdb
diff --git a/src/postgres_binary_reader.cpp b/src/postgres_binary_reader.cpp
new file mode 100644
index 00000000..e204325b
--- /dev/null
+++ b/src/postgres_binary_reader.cpp
@@ -0,0 +1,441 @@
+#include "postgres_binary_reader.hpp"
+#include "postgres_scanner.hpp"
+
+namespace duckdb {
+
+PostgresBinaryReader::PostgresBinaryReader(PostgresConnection &con_p, const vector<column_t> &column_ids,
+                                           const PostgresBindData &bind_data)
+    : PostgresResultReader(con_p, column_ids, bind_data) {
+}
+
+PostgresBinaryReader::~PostgresBinaryReader() {
+	Reset();
+}
+
+void PostgresBinaryReader::BeginCopy(const string &sql) {
+	con.BeginCopyFrom(sql, PGRES_COPY_OUT);
+	if (!Next()) {
+		throw IOException("Failed to fetch header for COPY \"%s\"", sql);
+	}
+	CheckHeader();
+}
+
+PostgresReadResult PostgresBinaryReader::Read(DataChunk &output) {
+	while (output.size() < STANDARD_VECTOR_SIZE) {
+		while (!Ready()) {
+			if (!Next()) {
+				// finished this batch
+				CheckResult();
+				return PostgresReadResult::FINISHED;
+			}
+		}
+
+		// read a row
+		auto tuple_count = ReadInteger<int16_t>();
+		if (tuple_count <= 0) { // done here, lets try to get more
+			Reset();
+			return PostgresReadResult::FINISHED;
+		}
+
+		D_ASSERT(tuple_count == column_ids.size());
+
+		idx_t output_offset = output.size();
+		for (idx_t output_idx = 0; output_idx < output.ColumnCount(); output_idx++) {
+			auto col_idx = column_ids[output_idx];
+			auto &out_vec = output.data[output_idx];
+			if (col_idx == COLUMN_IDENTIFIER_ROW_ID) {
+				// row id
+				// ctid in postgres are a composite type of (page_index, tuple_in_page)
+				// the page index is a 4-byte integer, the tuple_in_page a 2-byte integer
+				PostgresType ctid_type;
+				ctid_type.info = PostgresTypeAnnotation::CTID;
+				ReadValue(LogicalType::BIGINT, ctid_type, out_vec, output_offset);
+			} else {
+				ReadValue(bind_data.types[col_idx], bind_data.postgres_types[col_idx], out_vec, output_offset);
+			}
+		}
+		Reset();
+		output.SetCardinality(output_offset + 1);
+	}
+	// we filled a chunk
+	return PostgresReadResult::HAVE_MORE_TUPLES;
+}
+
+bool PostgresBinaryReader::Next() {
+	Reset();
+	char *out_buffer;
+	int len = PQgetCopyData(con.GetConn(), &out_buffer, 0);
+	auto new_buffer = data_ptr_cast(out_buffer);
+
+	// len -1 signals end
+	if (len == -1) {
+		auto final_result = PQgetResult(con.GetConn());
+		if (!final_result || PQresultStatus(final_result) != PGRES_COMMAND_OK) {
+			throw IOException("Failed to fetch header for COPY: %s", string(PQresultErrorMessage(final_result)));
+		}
+		return false;
+	}
+
+	// len -2 is error
+	// we expect at least 2 bytes in each message for the tuple count
+	if (!new_buffer || len < sizeof(int16_t)) {
+		throw IOException("Unable to read binary COPY data from Postgres: %s", string(PQerrorMessage(con.GetConn())));
+	}
+	buffer = new_buffer;
+	buffer_ptr = buffer;
+	end = buffer + len;
+	return true;
+}
+
+void PostgresBinaryReader::CheckResult() {
+	auto result = PQgetResult(con.GetConn());
+	if (!result || PQresultStatus(result) != PGRES_COMMAND_OK) {
+		throw std::runtime_error("Failed to execute COPY: " + string(PQresultErrorMessage(result)));
+	}
+}
+
+void PostgresBinaryReader::Reset() {
+	if (buffer) {
+		PQfreemem(buffer);
+	}
+	buffer = nullptr;
+	buffer_ptr = nullptr;
+	end = nullptr;
+}
+
+bool PostgresBinaryReader::Ready() {
+	return buffer_ptr != nullptr;
+}
+
+void PostgresBinaryReader::CheckHeader() {
+	auto magic_len = PostgresConversion::COPY_HEADER_LENGTH;
+	auto flags_len = 8;
+	auto header_len = magic_len + flags_len;
+
+	if (!buffer_ptr) {
+		throw IOException("buffer_ptr not set in CheckHeader");
+	}
+	if (buffer_ptr + header_len >= end) {
+		throw IOException("Unable to read binary COPY data from Postgres, invalid header");
+	}
+	if (memcmp(buffer_ptr, PostgresConversion::COPY_HEADER, magic_len) != 0) {
+		throw IOException("Expected Postgres binary COPY header, got something else");
+	}
+	buffer_ptr += header_len;
+	// as far as i can tell the "Flags field" and the "Header
+	// extension area length" do not contain anything interesting
+}
+
+PostgresDecimalConfig PostgresBinaryReader::ReadDecimalConfig() {
+	PostgresDecimalConfig config;
+	config.ndigits = ReadInteger<uint16_t>();
+	config.weight = ReadInteger<int16_t>();
+	auto sign = ReadInteger<uint16_t>();
+
+	if (!(sign == NUMERIC_POS || sign == NUMERIC_NAN || sign == NUMERIC_PINF || sign == NUMERIC_NINF ||
+	      sign == NUMERIC_NEG)) {
+		throw NotImplementedException("Postgres numeric NA/Inf");
+	}
+	config.is_negative = sign == NUMERIC_NEG;
+	config.scale = ReadInteger<uint16_t>();
+
+	return config;
+}
+
+void PostgresBinaryReader::ReadGeometry(const LogicalType &type, const PostgresType &postgres_type, Vector &out_vec,
+                                        idx_t output_offset) {
+	idx_t element_count = 0;
+	switch (postgres_type.info) {
+	case PostgresTypeAnnotation::GEOM_LINE:
+	case PostgresTypeAnnotation::GEOM_CIRCLE:
+		element_count = 3;
+		break;
+	case PostgresTypeAnnotation::GEOM_LINE_SEGMENT:
+	case PostgresTypeAnnotation::GEOM_BOX:
+		element_count = 4;
+		break;
+	case PostgresTypeAnnotation::GEOM_PATH: {
+		// variable number of elements
+		auto path_is_closed = ReadBoolean(); // ignored for now
+		element_count = 2 * ReadInteger<uint32_t>();
+		break;
+	}
+	case PostgresTypeAnnotation::GEOM_POLYGON:
+		// variable number of elements
+		element_count = 2 * ReadInteger<uint32_t>();
+		break;
+	default:
+		throw InternalException("Unsupported type for ReadGeometry");
+	}
+	auto list_entries = FlatVector::GetData<list_entry_t>(out_vec);
+	auto child_offset = ListVector::GetListSize(out_vec);
+	ListVector::Reserve(out_vec, child_offset + element_count);
+	list_entries[output_offset].offset = child_offset;
+	list_entries[output_offset].length = element_count;
+	auto &child_vector = ListVector::GetEntry(out_vec);
+	auto child_data = FlatVector::GetData<double>(child_vector);
+	for (idx_t i = 0; i < element_count; i++) {
+		child_data[child_offset + i] = ReadDouble();
+	}
+	ListVector::SetListSize(out_vec, child_offset + element_count);
+}
+
+void PostgresBinaryReader::ReadArray(const LogicalType &type, const PostgresType &postgres_type, Vector &out_vec,
+                                     idx_t output_offset, uint32_t current_count, uint32_t dimensions[],
+                                     uint32_t ndim) {
+	auto list_entries = FlatVector::GetData<list_entry_t>(out_vec);
+	auto child_offset = ListVector::GetListSize(out_vec);
+	auto child_dimension = dimensions[0];
+	auto child_count = current_count * child_dimension;
+	// set up the list entries for this dimension
+	auto current_offset = child_offset;
+	for (idx_t c = 0; c < current_count; c++) {
+		auto &list_entry = list_entries[output_offset + c];
+		list_entry.offset = current_offset;
+		list_entry.length = child_dimension;
+		current_offset += child_dimension;
+	}
+	ListVector::Reserve(out_vec, child_offset + child_count);
+	auto &child_vec = ListVector::GetEntry(out_vec);
+	auto &child_type = ListType::GetChildType(type);
+	auto &child_pg_type = postgres_type.children[0];
+	if (ndim > 1) {
+		// there are more dimensions to read - recurse into child list
+		ReadArray(child_type, child_pg_type, child_vec, child_offset, child_count, dimensions + 1, ndim - 1);
+	} else {
+		// this is the last level - read the actual values
+		for (idx_t child_idx = 0; child_idx < child_count; child_idx++) {
+			ReadValue(child_type, child_pg_type, child_vec, child_offset + child_idx);
+		}
+	}
+	ListVector::SetListSize(out_vec, child_offset + child_count);
+}
+
+void PostgresBinaryReader::ReadValue(const LogicalType &type, const PostgresType &postgres_type, Vector &out_vec,
+                                     idx_t output_offset) {
+	auto value_len = ReadInteger<int32_t>();
+	if (value_len == -1) { // NULL
+		FlatVector::SetNull(out_vec, output_offset, true);
+		return;
+	}
+	switch (type.id()) {
+	case LogicalTypeId::SMALLINT:
+		D_ASSERT(value_len == sizeof(int16_t));
+		FlatVector::GetData<int16_t>(out_vec)[output_offset] = ReadInteger<int16_t>();
+		break;
+	case LogicalTypeId::INTEGER:
+		D_ASSERT(value_len == sizeof(int32_t));
+		FlatVector::GetData<int32_t>(out_vec)[output_offset] = ReadInteger<int32_t>();
+		break;
+	case LogicalTypeId::UINTEGER:
+		D_ASSERT(value_len == sizeof(uint32_t));
+		FlatVector::GetData<uint32_t>(out_vec)[output_offset] = ReadInteger<uint32_t>();
+		break;
+	case LogicalTypeId::BIGINT:
+		if (postgres_type.info == PostgresTypeAnnotation::CTID) {
+			D_ASSERT(value_len == 6);
+			int64_t page_index = ReadInteger<int32_t>();
+			int64_t row_in_page = ReadInteger<int16_t>();
+			FlatVector::GetData<int64_t>(out_vec)[output_offset] = (page_index << 16LL) + row_in_page;
+			return;
+		}
+		D_ASSERT(value_len == sizeof(int64_t));
+		FlatVector::GetData<int64_t>(out_vec)[output_offset] = ReadInteger<int64_t>();
+		break;
+	case LogicalTypeId::FLOAT:
+		D_ASSERT(value_len == sizeof(float));
+		FlatVector::GetData<float>(out_vec)[output_offset] = ReadFloat();
+		break;
+	case LogicalTypeId::DOUBLE: {
+		// this was an unbounded decimal, read params from value and cast to double
+		if (postgres_type.info == PostgresTypeAnnotation::NUMERIC_AS_DOUBLE) {
+			FlatVector::GetData<double>(out_vec)[output_offset] = ReadDecimal<double, DecimalConversionDouble>();
+			break;
+		}
+		D_ASSERT(value_len == sizeof(double));
+		FlatVector::GetData<double>(out_vec)[output_offset] = ReadDouble();
+		break;
+	}
+
+	case LogicalTypeId::BLOB:
+	case LogicalTypeId::VARCHAR: {
+		if (postgres_type.info == PostgresTypeAnnotation::JSONB) {
+			auto version = ReadInteger<uint8_t>();
+			value_len--;
+			if (version != 1) {
+				throw NotImplementedException("JSONB version number mismatch, expected 1, got %d", version);
+			}
+		}
+		auto str = ReadString(value_len);
+		if (postgres_type.info == PostgresTypeAnnotation::FIXED_LENGTH_CHAR) {
+			// CHAR column - remove trailing spaces
+			while (value_len > 0 && str[value_len - 1] == ' ') {
+				value_len--;
+			}
+		}
+		FlatVector::GetData<string_t>(out_vec)[output_offset] = StringVector::AddStringOrBlob(out_vec, str, value_len);
+		break;
+	}
+	case LogicalTypeId::BOOLEAN:
+		D_ASSERT(value_len == sizeof(bool));
+		FlatVector::GetData<bool>(out_vec)[output_offset] = ReadBoolean();
+		break;
+	case LogicalTypeId::DECIMAL: {
+		if (value_len < sizeof(uint16_t) * 4) {
+			throw InvalidInputException("Need at least 8 bytes to read a Postgres decimal. Got %d", value_len);
+		}
+		switch (type.InternalType()) {
+		case PhysicalType::INT16:
+			FlatVector::GetData<int16_t>(out_vec)[output_offset] = ReadDecimal<int16_t>();
+			break;
+		case PhysicalType::INT32:
+			FlatVector::GetData<int32_t>(out_vec)[output_offset] = ReadDecimal<int32_t>();
+			break;
+		case PhysicalType::INT64:
+			FlatVector::GetData<int64_t>(out_vec)[output_offset] = ReadDecimal<int64_t>();
+			break;
+		case PhysicalType::INT128:
+			FlatVector::GetData<hugeint_t>(out_vec)[output_offset] = ReadDecimal<hugeint_t, DecimalConversionHugeint>();
+			break;
+		default:
+			throw InvalidInputException("Unsupported decimal storage type");
+		}
+		break;
+	}
+
+	case LogicalTypeId::DATE: {
+		D_ASSERT(value_len == sizeof(int32_t));
+		auto out_ptr = FlatVector::GetData<date_t>(out_vec);
+		out_ptr[output_offset] = ReadDate();
+		break;
+	}
+	case LogicalTypeId::TIME: {
+		D_ASSERT(value_len == sizeof(int64_t));
+		FlatVector::GetData<dtime_t>(out_vec)[output_offset] = ReadTime();
+		break;
+	}
+	case LogicalTypeId::TIME_TZ: {
+		D_ASSERT(value_len == sizeof(int64_t) + sizeof(int32_t));
+		FlatVector::GetData<dtime_tz_t>(out_vec)[output_offset] = ReadTimeTZ();
+		break;
+	}
+	case LogicalTypeId::TIMESTAMP_TZ:
+	case LogicalTypeId::TIMESTAMP: {
+		D_ASSERT(value_len == sizeof(int64_t));
+		FlatVector::GetData<timestamp_t>(out_vec)[output_offset] = ReadTimestamp();
+		break;
+	}
+	case LogicalTypeId::ENUM: {
+		auto enum_val = string(ReadString(value_len), value_len);
+		auto offset = EnumType::GetPos(type, enum_val);
+		if (offset < 0) {
+			throw IOException("Could not map ENUM value %s", enum_val);
+		}
+		switch (type.InternalType()) {
+		case PhysicalType::UINT8:
+			FlatVector::GetData<uint8_t>(out_vec)[output_offset] = (uint8_t)offset;
+			break;
+		case PhysicalType::UINT16:
+			FlatVector::GetData<uint16_t>(out_vec)[output_offset] = (uint16_t)offset;
+			break;
+
+		case PhysicalType::UINT32:
+			FlatVector::GetData<uint32_t>(out_vec)[output_offset] = (uint32_t)offset;
+			break;
+
+		default:
+			throw InternalException("ENUM can only have unsigned integers (except "
+			                        "UINT64) as physical types, got %s",
+			                        TypeIdToString(type.InternalType()));
+		}
+		break;
+	}
+	case LogicalTypeId::INTERVAL: {
+		FlatVector::GetData<interval_t>(out_vec)[output_offset] = ReadInterval();
+		break;
+	}
+	case LogicalTypeId::UUID: {
+		D_ASSERT(value_len == 2 * sizeof(int64_t));
+		FlatVector::GetData<hugeint_t>(out_vec)[output_offset] = ReadUUID();
+		break;
+	}
+	case LogicalTypeId::LIST: {
+		auto &list_entry = FlatVector::GetData<list_entry_t>(out_vec)[output_offset];
+		auto child_offset = ListVector::GetListSize(out_vec);
+
+		if (value_len < 1) {
+			list_entry.offset = child_offset;
+			list_entry.length = 0;
+			break;
+		}
+		switch (postgres_type.info) {
+		case PostgresTypeAnnotation::GEOM_LINE:
+		case PostgresTypeAnnotation::GEOM_LINE_SEGMENT:
+		case PostgresTypeAnnotation::GEOM_BOX:
+		case PostgresTypeAnnotation::GEOM_PATH:
+		case PostgresTypeAnnotation::GEOM_POLYGON:
+		case PostgresTypeAnnotation::GEOM_CIRCLE:
+			ReadGeometry(type, postgres_type, out_vec, output_offset);
+			return;
+		default:
+			break;
+		}
+		D_ASSERT(value_len >= 3 * sizeof(uint32_t));
+		auto array_dim = ReadInteger<uint32_t>();
+		auto array_has_null = ReadInteger<uint32_t>(); // whether or not the array has nulls - ignore
+		auto value_oid = ReadInteger<uint32_t>();      // value_oid - not necessary
+		if (array_dim == 0) {
+			list_entry.offset = child_offset;
+			list_entry.length = 0;
+			return;
+		}
+		// verify the number of dimensions matches the expected number of dimensions
+		idx_t expected_dimensions = 0;
+		const_reference<LogicalType> current_type = type;
+		while (current_type.get().id() == LogicalTypeId::LIST) {
+			current_type = ListType::GetChildType(current_type.get());
+			expected_dimensions++;
+		}
+		if (expected_dimensions != array_dim) {
+			throw InvalidInputException(
+			    "Expected an array with %llu dimensions, but this array has %llu dimensions. The array stored in "
+			    "Postgres does not match the schema. Postgres does not enforce that arrays match the provided "
+			    "schema but DuckDB requires this.\nSet pg_array_as_varchar=true to read the array as a varchar "
+			    "instead.",
+			    expected_dimensions, array_dim);
+		}
+		auto dimensions = unique_ptr<uint32_t[]>(new uint32_t[array_dim]);
+		for (idx_t d = 0; d < array_dim; d++) {
+			dimensions[d] = ReadInteger<uint32_t>();
+			auto lb = ReadInteger<uint32_t>(); // index lower bounds for each dimension -- we don't need them
+		}
+		// read the arrays recursively
+		ReadArray(type, postgres_type, out_vec, output_offset, 1, dimensions.get(), array_dim);
+		break;
+	}
+	case LogicalTypeId::STRUCT: {
+		auto &child_entries = StructVector::GetEntries(out_vec);
+		if (postgres_type.info == PostgresTypeAnnotation::GEOM_POINT) {
+			D_ASSERT(value_len == sizeof(double) * 2);
+			FlatVector::GetData<double>(*child_entries[0])[output_offset] = ReadDouble();
+			FlatVector::GetData<double>(*child_entries[1])[output_offset] = ReadDouble();
+			break;
+		}
+		auto entry_count = ReadInteger<uint32_t>();
+		if (entry_count != child_entries.size()) {
+			throw InternalException("Mismatch in entry count: expected %d but got %d", child_entries.size(),
+			                        entry_count);
+		}
+		for (idx_t c = 0; c < entry_count; c++) {
+			auto &child = *child_entries[c];
+			auto value_oid = ReadInteger<uint32_t>();
+			ReadValue(child.GetType(), postgres_type.children[c], child, output_offset);
+		}
+		break;
+	}
+	default:
+		throw InternalException("Unsupported Type %s", type.ToString());
+	}
+}
+
+} // namespace duckdb
diff --git a/src/postgres_copy_from.cpp b/src/postgres_copy_from.cpp
index 2fdbc215..fa2a7193 100644
--- a/src/postgres_copy_from.cpp
+++ b/src/postgres_copy_from.cpp
@@ -3,15 +3,11 @@
 
 namespace duckdb {
 
-void PostgresConnection::BeginCopyFrom(PostgresBinaryReader &reader, const string &query) {
+void PostgresConnection::BeginCopyFrom(const string &query, ExecStatusType expected_result) {
 	auto result = PQExecute(query.c_str());
-	if (!result || PQresultStatus(result) != PGRES_COPY_OUT) {
+	if (!result || PQresultStatus(result) != expected_result) {
 		throw std::runtime_error("Failed to prepare COPY \"" + query + "\": " + string(PQresultErrorMessage(result)));
 	}
-	if (!reader.Next()) {
-		throw IOException("Failed to fetch header for COPY \"%s\"", query);
-	}
-	reader.CheckHeader();
 }
 
 } // namespace duckdb
diff --git a/src/postgres_extension.cpp b/src/postgres_extension.cpp
index 5e54b69e..62d24ed1 100644
--- a/src/postgres_extension.cpp
+++ b/src/postgres_extension.cpp
@@ -180,6 +180,10 @@ static void LoadInternal(DatabaseInstance &db) {
 	                          LogicalType::VARCHAR, Value(), SetPostgresNullByteReplacement);
 	config.AddExtensionOption("pg_debug_show_queries", "DEBUG SETTING: print all queries sent to Postgres to stdout",
 	                          LogicalType::BOOLEAN, Value::BOOLEAN(false), SetPostgresDebugQueryPrint);
+	config.AddExtensionOption("pg_use_text_protocol",
+	                          "Whether or not to use TEXT protocol to read data. This is slower, but provides better "
+	                          "compatibility with non-Postgres systems",
+	                          LogicalType::BOOLEAN, Value::BOOLEAN(false));
 
 	OptimizerExtension postgres_optimizer;
 	postgres_optimizer.optimize_function = PostgresOptimizer::Optimize;
diff --git a/src/postgres_scanner.cpp b/src/postgres_scanner.cpp
index 4c4d4805..db3007b0 100644
--- a/src/postgres_scanner.cpp
+++ b/src/postgres_scanner.cpp
@@ -10,6 +10,7 @@
 #include "postgres_scanner.hpp"
 #include "postgres_result.hpp"
 #include "postgres_binary_reader.hpp"
+#include "postgres_text_reader.hpp"
 #include "storage/postgres_catalog.hpp"
 #include "storage/postgres_transaction.hpp"
 #include "storage/postgres_table_set.hpp"
@@ -31,6 +32,7 @@ struct PostgresLocalState : public LocalTableFunctionState {
 	PostgresConnection connection;
 	idx_t batch_idx = 0;
 	PostgresPoolConnection pool_connection;
+	unique_ptr<PostgresResultReader> reader;
 
 	void ScanChunk(ClientContext &context, const PostgresBindData &bind_data, PostgresGlobalState &gstate,
 	               DataChunk &output);
@@ -113,6 +115,14 @@ void PostgresScanFunction::PrepareBind(PostgresVersion version, ClientContext &c
 	if (context.TryGetCurrentSetting("pg_use_ctid_scan", pg_use_ctid_scan)) {
 		use_ctid_scan = BooleanValue::Get(pg_use_ctid_scan);
 	}
+	Value use_text_protocol;
+	if (context.TryGetCurrentSetting("pg_use_text_protocol", use_text_protocol)) {
+		if (BooleanValue::Get(use_text_protocol)) {
+			bind_data.use_text_protocol = true;
+			use_ctid_scan = false;
+		}
+	}
+
 	if (version.major_v < 14) {
 		// Disable parallel CTID scan on older Postgres versions since it is not efficient
 		// see https://github.com/duckdb/postgres_scanner/issues/186
@@ -127,7 +137,7 @@ void PostgresScanFunction::PrepareBind(PostgresVersion version, ClientContext &c
 
 void PostgresBindData::SetTablePages(idx_t approx_num_pages) {
 	this->pages_approx = approx_num_pages;
-	if (!read_only) {
+	if (!read_only || use_text_protocol) {
 		max_threads = 1;
 	} else {
 		max_threads = MaxValue<idx_t>(pages_approx / pages_per_task, 1);
@@ -238,6 +248,9 @@ static void PostgresInitInternal(ClientContext &context, const PostgresBindData
 	    PostgresFilterPushdown::TransformFilters(lstate.column_ids, lstate.filters, bind_data->names);
 
 	string filter;
+
+	lstate.exec = false;
+	lstate.done = false;
 	if (bind_data->pages_approx > 0) {
 		filter = StringUtil::Format("WHERE ctid BETWEEN '(%d,0)'::tid AND '(%d,0)'::tid", task_min, task_max);
 	}
@@ -249,20 +262,23 @@ static void PostgresInitInternal(ClientContext &context, const PostgresBindData
 		}
 		filter += filter_string;
 	}
+	string query;
 	if (bind_data->table_name.empty()) {
 		D_ASSERT(!bind_data->sql.empty());
-		lstate.sql =
-		    StringUtil::Format(R"(COPY (SELECT %s FROM (%s) AS __unnamed_subquery %s%s) TO STDOUT (FORMAT "binary");)",
-		                       col_names, bind_data->sql, filter, bind_data->limit);
+		query = StringUtil::Format(R"(SELECT %s FROM (%s) AS __unnamed_subquery %s%s)", col_names, bind_data->sql,
+		                           filter, bind_data->limit);
 	} else {
-		lstate.sql =
-		    StringUtil::Format(R"(COPY (SELECT %s FROM %s.%s %s%s) TO STDOUT (FORMAT "binary");)", col_names,
-		                       KeywordHelper::WriteQuoted(bind_data->schema_name, '"'),
-		                       KeywordHelper::WriteQuoted(bind_data->table_name, '"'), filter, bind_data->limit);
+		query = StringUtil::Format(R"(SELECT %s FROM %s.%s %s%s)", col_names,
+		                           KeywordHelper::WriteQuoted(bind_data->schema_name, '"'),
+		                           KeywordHelper::WriteQuoted(bind_data->table_name, '"'), filter, bind_data->limit);
 	}
-	lstate.exec = false;
-	lstate.done = false;
+	if (!bind_data->use_text_protocol) {
+		query = StringUtil::Format(R"(COPY (%s) TO STDOUT (FORMAT "binary");)", query);
+	} else {
+		query += ";";
+	}
+	lstate.sql = std::move(query);
 }
 
 static idx_t PostgresMaxThreads(ClientContext &context, const FunctionData *bind_data_p) {
@@ -420,55 +436,29 @@ static unique_ptr<LocalTableFunctionState> PostgresInitLocalState(ExecutionConte
 void PostgresLocalState::ScanChunk(ClientContext &context, const PostgresBindData &bind_data,
                                    PostgresGlobalState &gstate, DataChunk &output) {
 	idx_t output_offset = 0;
-	PostgresBinaryReader reader(connection);
+	if (!reader) {
+		if (bind_data.use_text_protocol) {
+			reader = make_uniq<PostgresTextReader>(context, connection, column_ids, bind_data);
+		} else {
+			reader = make_uniq<PostgresBinaryReader>(connection, column_ids, bind_data);
+		}
+	}
 	while (true) {
 		if (done && !PostgresParallelStateNext(context, &bind_data, *this, gstate)) {
 			return;
 		}
 		if (!exec) {
-			connection.BeginCopyFrom(reader, sql);
+			reader->BeginCopy(sql);
 			exec = true;
 		}
-
-		output.SetCardinality(output_offset);
-		if (output_offset == STANDARD_VECTOR_SIZE) {
-			return;
-		}
-
-		while (!reader.Ready()) {
-			if (!reader.Next()) {
-				// finished this batch
-				reader.CheckResult();
-				done = true;
-				continue;
-			}
-		}
-
-		auto tuple_count = reader.ReadInteger<int16_t>();
-		if (tuple_count <= 0) { // done here, lets try to get more
-			reader.Reset();
+		auto read_result = reader->Read(output);
+		if (read_result == PostgresReadResult::FINISHED) {
 			done = true;
 			continue;
 		}
-
-		D_ASSERT(tuple_count == column_ids.size());
-
-		for (idx_t output_idx = 0; output_idx < output.ColumnCount(); output_idx++) {
-			auto col_idx = column_ids[output_idx];
-			auto &out_vec = output.data[output_idx];
-			if (col_idx == COLUMN_IDENTIFIER_ROW_ID) {
-				// row id
-				// ctid in postgres are a composite type of (page_index, tuple_in_page)
-				// the page index is a 4-byte integer, the tuple_in_page a 2-byte integer
-				PostgresType ctid_type;
-				ctid_type.info = PostgresTypeAnnotation::CTID;
-				reader.ReadValue(LogicalType::BIGINT, ctid_type, out_vec, output_offset);
-			} else {
-				reader.ReadValue(bind_data.types[col_idx], bind_data.postgres_types[col_idx], out_vec, output_offset);
-			}
+		if (output.size() == STANDARD_VECTOR_SIZE) {
+			return;
 		}
-		reader.Reset();
-		output_offset++;
 	}
 }
diff --git a/src/postgres_text_reader.cpp b/src/postgres_text_reader.cpp
new file mode 100644
index 00000000..96154cc1
--- /dev/null
+++ b/src/postgres_text_reader.cpp
@@ -0,0 +1,376 @@
+#include "postgres_text_reader.hpp"
+#include "postgres_scanner.hpp"
+#include "duckdb/common/types/blob.hpp"
+
+namespace duckdb {
+
+PostgresTextReader::PostgresTextReader(ClientContext &context, PostgresConnection &con_p,
+                                       const vector<column_t> &column_ids, const PostgresBindData &bind_data)
+    : PostgresResultReader(con_p, column_ids, bind_data), context(context) {
+}
+
+PostgresTextReader::~PostgresTextReader() {
+	Reset();
+}
+
+void PostgresTextReader::BeginCopy(const string &sql) {
+	result = con.Query(sql);
+	row_offset = 0;
+}
+
+struct PostgresListParser {
+	PostgresListParser() : capacity(STANDARD_VECTOR_SIZE), size(0), vector(LogicalType::VARCHAR, capacity) {
+	}
+
+	void Initialize() {
+	}
+
+	void AddString(const string &str, bool &quoted) {
+		if (size >= capacity) {
+			vector.Resize(capacity, capacity * 2);
+			capacity *= 2;
+		}
+		if (!quoted && str == "NULL") {
+			FlatVector::SetNull(vector, size, true);
+		} else {
+			FlatVector::GetData<string_t>(vector)[size] = StringVector::AddStringOrBlob(vector, str);
+		}
+		size++;
+		quoted = false;
+	}
+
+	void Finish() {
+	}
+
+	idx_t capacity;
+	idx_t size;
+	Vector vector;
+};
+
+struct PostgresStructParser {
+	PostgresStructParser(ClientContext &context, idx_t child_count, idx_t row_count) {
+		vector<LogicalType> child_varchar_types;
+		for (idx_t c = 0; c < child_count; c++) {
+			child_varchar_types.push_back(LogicalType::VARCHAR);
+		}
+		data.Initialize(context, child_varchar_types, row_count);
+	}
+
+	void Initialize() {
+		column_offset = 0;
+	}
+
+	void AddString(const string &str, bool &quoted) {
+		if (column_offset >= data.ColumnCount()) {
+			throw InvalidInputException("Too many columns in data for parsing struct - string %s - expected %d", str,
+			                            data.ColumnCount());
+		}
+		auto &col = data.data[column_offset];
+		if (!quoted && str == "NULL") {
+			FlatVector::SetNull(col, row_offset, true);
+		} else {
+			FlatVector::GetData<string_t>(col)[row_offset] = StringVector::AddStringOrBlob(col, str);
+		}
+		column_offset++;
+	}
+
+	void Finish() {
+		if (column_offset != data.ColumnCount()) {
+			throw InvalidInputException("Missing columns in data for parsing struct - expected %d but got %d",
+			                            data.ColumnCount(), column_offset);
+		}
+		row_offset++;
+	}
+
+	DataChunk data;
+	idx_t column_offset = 0;
+	idx_t row_offset = 0;
+};
+
+struct PostgresCTIDParser {
+	PostgresCTIDParser() {
+	}
+
+	void Initialize() {
+	}
+
+	void AddString(const string &str, bool &quoted) {
+		values.push_back(StringUtil::ToUnsigned(str));
+	}
+
+	void Finish() {
+		if (values.size() != 2) {
+			throw InvalidInputException("CTID mismatch - expected (page_index, row_in_page)");
+		}
+	}
+
+	vector<uint64_t> values;
+};
+
+template <class T>
+void ParsePostgresNested(T &parser, string_t list, char start, char end) {
+	auto str = list.GetData();
+	auto size = list.GetSize();
+	if (size == 0 || str[0] != start || str[size - 1] != end) {
+		throw InvalidInputException("Invalid Postgres list - expected %s...%s - got %s", string(1, start),
+		                            string(1, end), list.GetString());
+	}
+	parser.Initialize();
+	bool quoted = false;
+	bool was_quoted = false;
+	vector<char> delims;
+	string current_string;
+	for (idx_t i = 1; i < size - 1; i++) {
+		auto c = str[i];
+		if (quoted) {
+			switch (c) {
+			case '"':
+				quoted = false;
+				break;
+			case '\\':
+				// escape - directly add the next character to the string
+				if (i + 1 < size) {
+					current_string += str[i + 1];
+				}
+				// skip the next character
+				i++;
+				break;
+			default:
+				current_string += c;
+				break;
+			}
+			continue;
+		}
+		switch (c) {
+		case '{':
+			delims.push_back('}');
+			current_string += c;
+			break;
+		case '(':
+			delims.push_back(')');
+			current_string += c;
+			break;
+		case '}':
+		case ')':
+			if (delims.empty() || delims.back() != c) {
+				throw InvalidInputException("Failed to convert list %s - mismatch in brackets", list.GetString());
+			}
+			delims.pop_back();
+			current_string += c;
+			break;
+		case '"':
+			quoted = true;
+			was_quoted = true;
+			break;
+		case ',':
+			if (!delims.empty()) {
+				// in a nested struct
+				current_string += c;
+				break;
+			}
+			// next element
+			if (!current_string.empty() || was_quoted) {
+				parser.AddString(current_string, was_quoted);
+			}
+			current_string = string();
+			break;
+		default:
+			current_string += c;
+		}
+	}
+	if (!current_string.empty() || was_quoted) {
+		parser.AddString(current_string, was_quoted);
+	}
+	parser.Finish();
+}
+
+void ParsePostgresList(PostgresListParser &list_parser, string_t list) {
+	ParsePostgresNested(list_parser, list, '{', '}');
+}
+
+void ParsePostgresStruct(PostgresStructParser &struct_parser, string_t list) {
+	ParsePostgresNested(struct_parser, list, '(', ')');
+}
+
+void ParsePostgresCTID(PostgresCTIDParser &ctid_parser, string_t list) {
+	ParsePostgresNested(ctid_parser, list, '(', ')');
+}
+
+void PostgresTextReader::ConvertList(Vector &source, Vector &target, const PostgresType &postgres_type, idx_t count) {
+	// lists have the format {1, 2, 3}
+	UnifiedVectorFormat vdata;
+	source.ToUnifiedFormat(count, vdata);
+
+	auto strings = UnifiedVectorFormat::GetData<string_t>(vdata);
+	auto list_data = FlatVector::GetData<list_entry_t>(target);
+
+	PostgresListParser list_parser;
+	for (idx_t i = 0; i < count; i++) {
+		if (!vdata.validity.RowIsValid(i)) {
+			// NULL value - skip
+			FlatVector::SetNull(target, i, true);
+			continue;
+		}
+		list_data[i].offset = list_parser.size;
+		ParsePostgresList(list_parser, strings[i]);
+		list_data[i].length = list_parser.size - list_data[i].offset;
+	}
+	if (list_parser.size > 0) {
+		auto &target_child = ListVector::GetEntry(target);
+		ListVector::Reserve(target, list_parser.size);
+		ConvertVector(list_parser.vector, target_child,
+		              postgres_type.children.empty() ? PostgresType() : postgres_type.children[0], list_parser.size);
+	}
+	ListVector::SetListSize(target, list_parser.size);
+}
+
+void PostgresTextReader::ConvertStruct(Vector &source, Vector &target, const PostgresType &postgres_type, idx_t count) {
+	// structs have the format (1, 2, 3)
+	UnifiedVectorFormat vdata;
+	source.ToUnifiedFormat(count, vdata);
+	auto strings = UnifiedVectorFormat::GetData<string_t>(vdata);
+	auto &children = StructVector::GetEntries(target);
+
+	PostgresStructParser struct_parser(context, children.size(), count);
+	for (idx_t i = 0; i < count; i++) {
+		if (!vdata.validity.RowIsValid(i)) {
+			// NULL value - skip
+			FlatVector::SetNull(target, i, true);
+			for (idx_t c = 0; c < children.size(); c++) {
+				FlatVector::SetNull(struct_parser.data.data[c], i, true);
+			}
+			continue;
+		}
+		ParsePostgresStruct(struct_parser, strings[i]);
+	}
+	for (idx_t c = 0; c < children.size(); c++) {
+		ConvertVector(struct_parser.data.data[c], *children[c],
+		              c >= postgres_type.children.size() ? PostgresType() : postgres_type.children[c], count);
+	}
+}
+
+void PostgresTextReader::ConvertCTID(Vector &source, Vector &target, idx_t count) {
+	// ctids have the format (page_index, row_in_page)
+	UnifiedVectorFormat vdata;
+	source.ToUnifiedFormat(count, vdata);
+	auto strings = UnifiedVectorFormat::GetData<string_t>(vdata);
+	auto result = FlatVector::GetData<int64_t>(target);
+
+	for (idx_t i = 0; i < count; i++) {
+		if (!vdata.validity.RowIsValid(i)) {
+			// NULL value - skip
+			FlatVector::SetNull(target, i, true);
+			continue;
+		}
+		PostgresCTIDParser ctid_parser;
+		ParsePostgresCTID(ctid_parser, strings[i]);
+		auto page_index = ctid_parser.values[0];
+		auto row_in_page = ctid_parser.values[1];
+		result[i] = NumericCast<int64_t>((page_index << 16LL) + row_in_page);
+	}
+}
+
+void PostgresTextReader::ConvertBlob(Vector &source, Vector &target, idx_t count) {
+	// blobs are rendered as hex strings with a \x prefix (e.g. \x00AA)
+	UnifiedVectorFormat vdata;
+	source.ToUnifiedFormat(count, vdata);
+	auto strings = UnifiedVectorFormat::GetData<string_t>(vdata);
+	auto result = FlatVector::GetData<string_t>(target);
+
+	for (idx_t i = 0; i < count; i++) {
+		if (!vdata.validity.RowIsValid(i)) {
+			// NULL value - skip
+			FlatVector::SetNull(target, i, true);
+			continue;
+		}
+		auto blob_str = strings[i];
+		auto str = blob_str.GetData();
+		auto size = blob_str.GetSize();
+		if (size < 2 || str[0] != '\\' || str[1] != 'x') {
+			throw InvalidInputException("Incorrect blob format - expected \\x... for blob");
+		}
+		if (size % 2 != 0) {
+			throw InvalidInputException("Blob size must be modulo 2 (\\xAA)");
+		}
+		string result_blob;
+		for (idx_t i = 2; i < size; i += 2) {
+			int byte_a = Blob::HEX_MAP[static_cast<uint8_t>(str[i])];
+			int byte_b = Blob::HEX_MAP[static_cast<uint8_t>(str[i + 1])];
+			result_blob += UnsafeNumericCast<char>((byte_a << 4) + byte_b);
+		}
+		result[i] = StringVector::AddStringOrBlob(target, result_blob);
+	}
+}
+
+void PostgresTextReader::ConvertVector(Vector &source, Vector &target, const PostgresType &postgres_type, idx_t count) {
+	if (source.GetType().id() != LogicalTypeId::VARCHAR) {
+		throw InternalException("Source needs to be VARCHAR");
+	}
+	if (postgres_type.info == PostgresTypeAnnotation::CTID) {
+		ConvertCTID(source, target, count);
+		return;
+	}
+	switch (target.GetType().id()) {
+	case LogicalTypeId::LIST:
+		ConvertList(source, target, postgres_type, count);
+		break;
+	case LogicalTypeId::STRUCT:
+		ConvertStruct(source, target, postgres_type, count);
+		break;
+	case LogicalTypeId::BLOB:
+		ConvertBlob(source, target, count);
+		break;
+	default:
+		VectorOperations::Cast(context, source, target, count);
+	}
+}
+
+PostgresReadResult PostgresTextReader::Read(DataChunk &output) {
+	if (!result) {
+		return PostgresReadResult::FINISHED;
+	}
+	if (scan_chunk.data.empty()) {
+		// initialize the scan chunk
+		vector<LogicalType> types;
+		for (idx_t i = 0; i < output.ColumnCount(); i++) {
+			types.push_back(LogicalType::VARCHAR);
+		}
+		scan_chunk.Initialize(context, types);
+	}
+	scan_chunk.Reset();
+	for (; scan_chunk.size() < STANDARD_VECTOR_SIZE && row_offset < result->Count(); row_offset++) {
+		idx_t output_offset = scan_chunk.size();
+		for (idx_t output_idx = 0; output_idx < output.ColumnCount(); output_idx++) {
+			auto col_idx = column_ids[output_idx];
+			auto &out_vec = scan_chunk.data[output_idx];
+			if (result->IsNull(row_offset, output_idx)) {
+				FlatVector::SetNull(out_vec, output_offset, true);
+				continue;
+			}
+			auto col_data = FlatVector::GetData<string_t>(out_vec);
+			col_data[output_offset] =
+			    StringVector::AddStringOrBlob(out_vec, result->GetStringRef(row_offset, output_idx));
+		}
+		scan_chunk.SetCardinality(scan_chunk.size() + 1);
+	}
+	for (idx_t c = 0; c < output.ColumnCount(); c++) {
+		auto col_idx = column_ids[c];
+		if (col_idx == COLUMN_IDENTIFIER_ROW_ID) {
+			PostgresType ctid_type;
+			ctid_type.info = PostgresTypeAnnotation::CTID;
+			ConvertVector(scan_chunk.data[c], output.data[c], ctid_type, scan_chunk.size());
+		} else {
+			ConvertVector(scan_chunk.data[c], output.data[c], bind_data.postgres_types[c], scan_chunk.size());
+		}
+	}
+	output.SetCardinality(scan_chunk.size());
+	return row_offset < result->Count() ? PostgresReadResult::HAVE_MORE_TUPLES : PostgresReadResult::FINISHED;
+}
+
+void PostgresTextReader::Reset() {
+	result.reset();
+	row_offset = 0;
+}
+
+} // namespace duckdb
diff --git a/test/sql/misc/postgres_binary.test b/test/sql/misc/postgres_binary.test
index d0141588..98926a29 100644
--- a/test/sql/misc/postgres_binary.test
+++ b/test/sql/misc/postgres_binary.test
@@ -52,6 +52,7 @@
 query I nosort all_types
 FROM all_types_tbl
 ----
+
 query I nosort all_types
 SELECT * FROM s.binary_copy_test
 ----
diff --git a/test/sql/storage/attach_types.test b/test/sql/storage/attach_types.test
index 7819c04f..7be6677e 100644
--- a/test/sql/storage/attach_types.test
+++ b/test/sql/storage/attach_types.test
@@ -82,3 +82,14 @@ SELECT ANY_VALUE(${column_name})=getvariable('minimum_value') FROM s.all_types W
 true
 
 endloop
+
+# text protocol
+statement ok
+SET pg_use_text_protocol=true;
+
+query IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII
+SELECT COLUMNS(*)::VARCHAR FROM s.all_types
+----
+false	-128	-32768	-2147483648	-9223372036854775808	0	0	0	2000-01-01	00:00:00	2000-01-01 01:02:03	2000-01-01 01:02:03	2000-01-01 01:02:03	2000-01-01 01:02:03	00:00:00+15:00	2000-01-01 01:02:03	-999.9	-99999.9999	-999999999999.999999	-9999999999999999999999999999.9999999999	00000000-0000-0000-0000-000000000000	00:00:00	🦆🦆🦆🦆🦆🦆	thisisalongblob\x00withnullbytes	0010001001011100010101011010111	DUCK_DUCK_ENUM	enum_0	enum_0	[]	[]	[]	[]	[]	[🦆🦆🦆🦆🦆🦆, goose, NULL]
+true	127	32767	2147483647	9223372036854775807	255	65535	4294967295	2000-01-01	24:00:00	2000-01-01 01:02:03	2000-01-01 01:02:03	2000-01-01 01:02:03	2000-01-01 01:02:03	00:00:00+15:00	2000-01-01 01:02:03	999.9	99999.9999	999999999999.999999	9999999999999999999999999999.9999999999	ffffffff-ffff-ffff-ffff-ffffffffffff	83 years 3 months 999 days 00:16:39.999999	goo se	\x00\x00\x00a	10101	GOOSE	enum_299	enum_69999	[42, 999, NULL, NULL, -42]	[42.0, nan, inf, -inf, NULL, -42.0]	[1970-01-01, infinity, -infinity, NULL, 2022-05-12]	['1970-01-01 00:00:00', infinity, -infinity, NULL, '2022-05-12 16:23:45']	['1970-01-01 00:00:00+00', infinity, -infinity, NULL, '2022-05-12 23:23:45+00']	[🦆🦆🦆🦆🦆🦆, goose, NULL]
+NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL
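
A minimal usage sketch for the new setting, not part of the patch itself; the attached database alias `s` and table name below are placeholders:

    -- route scans through PostgresTextReader (plain SELECT, text-format results)
    SET pg_use_text_protocol = true;
    SELECT * FROM s.some_table;
    -- revert to the default binary COPY path (PostgresBinaryReader)
    SET pg_use_text_protocol = false;

Note that the setting is read at bind time in PrepareBind, which also disables the CTID-based scan, and SetTablePages pins max_threads to 1 when it is enabled, so text-protocol reads are single-threaded by design.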