-
Notifications
You must be signed in to change notification settings - Fork 4
Add interval array #60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Alex-PLACET
wants to merge
3
commits into
QuantStack:main
Choose a base branch
from
Alex-PLACET:add_interval_array
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| #pragma once | ||
|
|
||
| #include <optional> | ||
| #include <vector> | ||
|
|
||
| #include <sparrow/arrow_interface/arrow_array_schema_proxy.hpp> | ||
| #include <sparrow/interval_array.hpp> | ||
|
|
||
| #include "Message_generated.h" | ||
| #include "sparrow_ipc/arrow_interface/arrow_array.hpp" | ||
| #include "sparrow_ipc/arrow_interface/arrow_schema.hpp" | ||
| #include "sparrow_ipc/deserialize_utils.hpp" | ||
|
|
||
| namespace sparrow_ipc | ||
| { | ||
| template <typename T> | ||
| [[nodiscard]] sparrow::interval_array<T> deserialize_non_owning_interval_array( | ||
| const org::apache::arrow::flatbuf::RecordBatch& record_batch, | ||
| std::span<const uint8_t> body, | ||
| std::string_view name, | ||
| const std::optional<std::vector<sparrow::metadata_pair>>& metadata, | ||
| bool nullable, | ||
| size_t& buffer_index | ||
| ) | ||
| { | ||
| const std::string_view format = data_type_to_format( | ||
| sparrow::detail::get_data_type_from_array<sparrow::interval_array<T>>::get() | ||
| ); | ||
|
|
||
| // Set up flags based on nullable | ||
| std::optional<std::unordered_set<sparrow::ArrowFlag>> flags; | ||
| if (nullable) | ||
| { | ||
| flags = std::unordered_set<sparrow::ArrowFlag>{sparrow::ArrowFlag::NULLABLE}; | ||
| } | ||
|
|
||
| ArrowSchema schema = make_non_owning_arrow_schema( | ||
| format, | ||
| name.data(), | ||
| metadata, | ||
| flags, | ||
| 0, | ||
| nullptr, | ||
| nullptr | ||
| ); | ||
|
|
||
| const auto compression = record_batch.compression(); | ||
| std::vector<arrow_array_private_data::optionally_owned_buffer> buffers; | ||
|
|
||
| auto validity_buffer_span = utils::get_buffer(record_batch, body, buffer_index); | ||
| auto data_buffer_span = utils::get_buffer(record_batch, body, buffer_index); | ||
|
|
||
| if (compression) | ||
| { | ||
| buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression)); | ||
| buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression)); | ||
| } | ||
| else | ||
| { | ||
| buffers.emplace_back(validity_buffer_span); | ||
| buffers.emplace_back(data_buffer_span); | ||
| } | ||
|
|
||
| // TODO bitmap_ptr is not used anymore... Leave it for now, and remove later if no need confirmed | ||
| const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(validity_buffer_span, record_batch.length()); | ||
|
|
||
| ArrowArray array = make_arrow_array<arrow_array_private_data>( | ||
| record_batch.length(), | ||
| null_count, | ||
| 0, | ||
| 0, | ||
| nullptr, | ||
| nullptr, | ||
| std::move(buffers) | ||
| ); | ||
|
|
||
| sparrow::arrow_proxy ap{std::move(array), std::move(schema)}; | ||
| return sparrow::interval_array<T>{std::move(ap)}; | ||
| } | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| #include <sparrow/types/data_type.hpp> | ||
|
|
||
| #include "sparrow_ipc/deserialize_fixedsizebinary_array.hpp" | ||
| #include "sparrow_ipc/deserialize_interval_array.hpp" | ||
| #include "sparrow_ipc/deserialize_primitive_array.hpp" | ||
| #include "sparrow_ipc/deserialize_variable_size_binary_array.hpp" | ||
| #include "sparrow_ipc/encapsulated_message.hpp" | ||
|
|
@@ -11,11 +12,23 @@ | |
|
|
||
| namespace sparrow_ipc | ||
| { | ||
| namespace | ||
| { | ||
| // Integer bit width constants | ||
| constexpr int32_t BIT_WIDTH_8 = 8; | ||
| constexpr int32_t BIT_WIDTH_16 = 16; | ||
| constexpr int32_t BIT_WIDTH_32 = 32; | ||
| constexpr int32_t BIT_WIDTH_64 = 64; | ||
|
|
||
| // End-of-stream marker size in bytes | ||
| constexpr size_t END_OF_STREAM_MARKER_SIZE = 8; | ||
| } | ||
| const org::apache::arrow::flatbuf::RecordBatch* | ||
| deserialize_record_batch_message(std::span<const uint8_t> data, size_t& current_offset) | ||
| { | ||
| current_offset += sizeof(uint32_t); | ||
| const auto batch_message = org::apache::arrow::flatbuf::GetMessage(data.data() + current_offset); | ||
| const auto message_data = data.subspan(current_offset); | ||
| const auto* batch_message = org::apache::arrow::flatbuf::GetMessage(message_data.data()); | ||
| if (batch_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::RecordBatch) | ||
| { | ||
| throw std::runtime_error("Expected RecordBatch message, but got a different type."); | ||
|
|
@@ -28,21 +41,21 @@ namespace sparrow_ipc | |
| * | ||
| * This function processes each field in the schema and deserializes the corresponding | ||
| * data from the RecordBatch into sparrow::array objects. It handles various Arrow data | ||
| * types including primitive types (bool, integers, floating point), binary data, and | ||
| * string data with their respective size variants. | ||
| * types including primitive types (bool, integers, floating point), binary data, string | ||
| * data, fixed-size binary data, and interval types. | ||
| * | ||
| * @param record_batch The Apache Arrow FlatBuffer RecordBatch containing the serialized data | ||
| * @param schema The Apache Arrow FlatBuffer Schema defining the structure and types of the data | ||
| * @param encapsulated_message The message containing the binary data buffers | ||
| * @param field_metadata Metadata for each field | ||
| * @param field_metadata Metadata associated with each field in the schema | ||
| * | ||
| * @return std::vector<sparrow::array> A vector of deserialized arrays, one for each field in the schema | ||
| * | ||
| * @throws std::runtime_error If an unsupported data type, integer bit width, or floating point precision | ||
| * is encountered | ||
| * @throws std::runtime_error If an unsupported data type, integer bit width, floating point precision, | ||
| * or interval unit is encountered | ||
| * | ||
| * The function maintains a buffer index that is incremented as it processes each field | ||
| * to correctly map data buffers to their corresponding arrays. | ||
| * @note The function maintains a buffer index that is incremented as it processes each field | ||
| * to correctly map data buffers to their corresponding arrays. | ||
| */ | ||
| std::vector<sparrow::array> get_arrays_from_record_batch( | ||
| const org::apache::arrow::flatbuf::RecordBatch& record_batch, | ||
|
|
@@ -64,7 +77,7 @@ namespace sparrow_ipc | |
| const std::string name = field->name() == nullptr ? "" : field->name()->str(); | ||
| const bool nullable = field->nullable(); | ||
| const auto field_type = field->type_type(); | ||
| // TODO rename all the deserialize_non_owning... fcts since this is not correct anymore | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why remove this? This is actually still relevant as we are using now |
||
|
|
||
| const auto deserialize_non_owning_primitive_array_lambda = [&]<typename T>() | ||
| { | ||
| return deserialize_non_owning_primitive_array<T>( | ||
|
|
@@ -85,7 +98,7 @@ namespace sparrow_ipc | |
| break; | ||
| case org::apache::arrow::flatbuf::Type::Int: | ||
| { | ||
| const auto int_type = field->type_as_Int(); | ||
| const auto* int_type = field->type_as_Int(); | ||
| const auto bit_width = int_type->bitWidth(); | ||
| const bool is_signed = int_type->is_signed(); | ||
|
|
||
|
|
@@ -94,11 +107,11 @@ namespace sparrow_ipc | |
| switch (bit_width) | ||
| { | ||
| // clang-format off | ||
| case 8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int8_t>()); break; | ||
| case 16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int16_t>()); break; | ||
| case 32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int32_t>()); break; | ||
| case 64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int64_t>()); break; | ||
| default: throw std::runtime_error("Unsupported integer bit width."); | ||
| case BIT_WIDTH_8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int8_t>()); break; | ||
| case BIT_WIDTH_16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int16_t>()); break; | ||
| case BIT_WIDTH_32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int32_t>()); break; | ||
| case BIT_WIDTH_64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int64_t>()); break; | ||
| default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width)); | ||
| // clang-format on | ||
| } | ||
| } | ||
|
|
@@ -107,19 +120,19 @@ namespace sparrow_ipc | |
| switch (bit_width) | ||
| { | ||
| // clang-format off | ||
| case 8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint8_t>()); break; | ||
| case 16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint16_t>()); break; | ||
| case 32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint32_t>()); break; | ||
| case 64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint64_t>()); break; | ||
| default: throw std::runtime_error("Unsupported integer bit width."); | ||
| case BIT_WIDTH_8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint8_t>()); break; | ||
| case BIT_WIDTH_16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint16_t>()); break; | ||
| case BIT_WIDTH_32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint32_t>()); break; | ||
| case BIT_WIDTH_64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint64_t>()); break; | ||
| default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width)); | ||
| // clang-format on | ||
| } | ||
| } | ||
| } | ||
| break; | ||
| case org::apache::arrow::flatbuf::Type::FloatingPoint: | ||
| { | ||
| const auto float_type = field->type_as_FloatingPoint(); | ||
| const auto* float_type = field->type_as_FloatingPoint(); | ||
| switch (float_type->precision()) | ||
| { | ||
| // clang-format off | ||
|
|
@@ -133,14 +146,17 @@ namespace sparrow_ipc | |
| arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<double>()); | ||
| break; | ||
| default: | ||
| throw std::runtime_error("Unsupported floating point precision."); | ||
| throw std::runtime_error( | ||
| "Unsupported floating point precision: " | ||
| + std::to_string(static_cast<int>(float_type->precision())) | ||
| ); | ||
| // clang-format on | ||
| } | ||
| break; | ||
| } | ||
| case org::apache::arrow::flatbuf::Type::FixedSizeBinary: | ||
| { | ||
| const auto fixed_size_binary_field = field->type_as_FixedSizeBinary(); | ||
| const auto* fixed_size_binary_field = field->type_as_FixedSizeBinary(); | ||
| arrays.emplace_back(deserialize_non_owning_fixedwidthbinary( | ||
| record_batch, | ||
| encapsulated_message.body(), | ||
|
|
@@ -200,8 +216,61 @@ namespace sparrow_ipc | |
| ) | ||
| ); | ||
| break; | ||
| case org::apache::arrow::flatbuf::Type::Interval: | ||
| { | ||
| const auto* interval_type = field->type_as_Interval(); | ||
| const org::apache::arrow::flatbuf::IntervalUnit interval_unit = interval_type->unit(); | ||
| switch (interval_unit) | ||
| { | ||
| case org::apache::arrow::flatbuf::IntervalUnit::YEAR_MONTH: | ||
| arrays.emplace_back( | ||
| deserialize_non_owning_interval_array<sparrow::chrono::months>( | ||
| record_batch, | ||
| encapsulated_message.body(), | ||
| name, | ||
| metadata, | ||
| nullable, | ||
| buffer_index | ||
| ) | ||
| ); | ||
| break; | ||
| case org::apache::arrow::flatbuf::IntervalUnit::DAY_TIME: | ||
| arrays.emplace_back( | ||
| deserialize_non_owning_interval_array<sparrow::days_time_interval>( | ||
| record_batch, | ||
| encapsulated_message.body(), | ||
| name, | ||
| metadata, | ||
| nullable, | ||
| buffer_index | ||
| ) | ||
| ); | ||
| break; | ||
| case org::apache::arrow::flatbuf::IntervalUnit::MONTH_DAY_NANO: | ||
| arrays.emplace_back( | ||
| deserialize_non_owning_interval_array<sparrow::month_day_nanoseconds_interval>( | ||
| record_batch, | ||
| encapsulated_message.body(), | ||
| name, | ||
| metadata, | ||
| nullable, | ||
| buffer_index | ||
| ) | ||
| ); | ||
| break; | ||
| default: | ||
| throw std::runtime_error( | ||
| "Unsupported interval unit: " | ||
| + std::to_string(static_cast<int>(interval_unit)) | ||
| ); | ||
| } | ||
| } | ||
| break; | ||
| default: | ||
| throw std::runtime_error("Unsupported type."); | ||
| throw std::runtime_error( | ||
| "Unsupported field type: " + std::to_string(static_cast<int>(field_type)) | ||
| + " for field '" + name + "'" | ||
| ); | ||
| } | ||
| } | ||
| return arrays; | ||
|
|
@@ -215,10 +284,12 @@ namespace sparrow_ipc | |
| std::vector<bool> fields_nullable; | ||
| std::vector<sparrow::data_type> field_types; | ||
| std::vector<std::optional<std::vector<sparrow::metadata_pair>>> fields_metadata; | ||
| do | ||
|
|
||
| while (!data.empty()) | ||
| { | ||
| // Check for end-of-stream marker here as data could contain only that (if no record batches present/written) | ||
| if (data.size() >= 8 && is_end_of_stream(data.subspan(0, 8))) | ||
| // Check for end-of-stream marker | ||
| if (data.size() >= END_OF_STREAM_MARKER_SIZE | ||
| && is_end_of_stream(data.subspan(0, END_OF_STREAM_MARKER_SIZE))) | ||
| { | ||
| break; | ||
| } | ||
|
|
@@ -243,11 +314,12 @@ namespace sparrow_ipc | |
|
|
||
| for (const auto field : *(schema->fields())) | ||
| { | ||
| if(field != nullptr && field->name() != nullptr) | ||
| if (field != nullptr && field->name() != nullptr) | ||
| { | ||
| field_names.emplace_back(field->name()->str()); | ||
| field_names.emplace_back(field->name()->str()); | ||
| } | ||
| else { | ||
| else | ||
| { | ||
| field_names.emplace_back("_unnamed_"); | ||
| } | ||
| fields_nullable.push_back(field->nullable()); | ||
|
|
@@ -265,33 +337,33 @@ namespace sparrow_ipc | |
| { | ||
| if (schema == nullptr) | ||
| { | ||
| throw std::runtime_error("Schema message is missing."); | ||
| throw std::runtime_error("RecordBatch encountered before Schema message."); | ||
| } | ||
| const auto record_batch = message->header_as_RecordBatch(); | ||
| const auto* record_batch = message->header_as_RecordBatch(); | ||
| if (record_batch == nullptr) | ||
| { | ||
| throw std::runtime_error("RecordBatch message is missing."); | ||
| throw std::runtime_error("RecordBatch message header is null."); | ||
| } | ||
| std::vector<sparrow::array> arrays = get_arrays_from_record_batch( | ||
| *record_batch, | ||
| *schema, | ||
| encapsulated_message, | ||
| fields_metadata | ||
| ); | ||
| auto names_copy = field_names; // TODO: Remove when issue with the to_vector of record_batch is fixed | ||
| auto names_copy = field_names; | ||
| sparrow::record_batch sp_record_batch(std::move(names_copy), std::move(arrays)); | ||
| record_batches.emplace_back(std::move(sp_record_batch)); | ||
| } | ||
| break; | ||
| case org::apache::arrow::flatbuf::MessageHeader::Tensor: | ||
| case org::apache::arrow::flatbuf::MessageHeader::DictionaryBatch: | ||
| case org::apache::arrow::flatbuf::MessageHeader::SparseTensor: | ||
| throw std::runtime_error("Not supported"); | ||
| throw std::runtime_error("Unsupported message type: Tensor, DictionaryBatch, or SparseTensor"); | ||
| default: | ||
| throw std::runtime_error("Unknown message header type."); | ||
| } | ||
| data = rest; | ||
| } while (!data.empty()); | ||
| } | ||
| return record_batches; | ||
| } | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems this function is exactly the same (except from data type) as the one used in
primitive_array, so we should probably refactor and use some generic function.