Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def configure(self):
self.options.rm_safe("fPIC")

def requirements(self):
self.requires("sparrow/1.0.0")
self.requires("sparrow/1.2.0", options={"json_reader": True})
self.requires(f"flatbuffers/{self._flatbuffers_version}")
self.requires("lz4/1.9.4")
self.requires("zstd/1.5.7")
Expand Down
86 changes: 86 additions & 0 deletions include/sparrow_ipc/deserialize_decimal_array.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#pragma once

#include <span>

#include <sparrow/arrow_interface/arrow_array_schema_proxy.hpp>
#include <sparrow/decimal_array.hpp>

#include "Message_generated.h"
#include "sparrow_ipc/arrow_interface/arrow_array.hpp"
#include "sparrow_ipc/arrow_interface/arrow_schema.hpp"
#include "sparrow_ipc/deserialize_utils.hpp"

namespace sparrow_ipc
{
template <sparrow::decimal_type T>
[[nodiscard]] sparrow::decimal_array<T> deserialize_non_owning_decimal(
const org::apache::arrow::flatbuf::RecordBatch& record_batch,
std::span<const uint8_t> body,
std::string_view name,
const std::optional<std::vector<sparrow::metadata_pair>>& metadata,
bool nullable,
size_t& buffer_index,
int32_t scale,
int32_t precision
)
{
constexpr std::size_t sizeof_decimal = sizeof(typename T::integer_type);
std::string format_str = "d:" + std::to_string(precision) + "," + std::to_string(scale);
if constexpr (sizeof_decimal != 16) // We don't need to specify the size for 128-bit
// decimals
{
format_str += "," + std::to_string(sizeof_decimal * 8);
}

// Set up flags based on nullable
std::optional<std::unordered_set<sparrow::ArrowFlag>> flags;
if (nullable)
{
flags = std::unordered_set<sparrow::ArrowFlag>{sparrow::ArrowFlag::NULLABLE};
}

ArrowSchema schema = make_non_owning_arrow_schema(
format_str,
name.data(),
metadata,
flags,
0,
nullptr,
nullptr
);

const auto compression = record_batch.compression();
std::vector<arrow_array_private_data::optionally_owned_buffer> buffers;

auto validity_buffer_span = utils::get_buffer(record_batch, body, buffer_index);
auto data_buffer_span = utils::get_buffer(record_batch, body, buffer_index);

if (compression)
{
buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression));
buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression));
}
else
{
buffers.emplace_back(validity_buffer_span);
buffers.emplace_back(data_buffer_span);
}

const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(
validity_buffer_span,
record_batch.length()
);

ArrowArray array = make_arrow_array<arrow_array_private_data>(
record_batch.length(),
null_count,
0,
0,
nullptr,
nullptr,
std::move(buffers)
);
sparrow::arrow_proxy ap{std::move(array), std::move(schema)};
return sparrow::decimal_array<T>(std::move(ap));
}
}
53 changes: 53 additions & 0 deletions include/sparrow_ipc/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <cstdint>
#include <optional>
#include <string_view>
#include <vector>

#include <sparrow/record_batch.hpp>

Expand All @@ -13,6 +14,39 @@ namespace sparrow_ipc::utils
// Aligns a value to the next multiple of 8, as required by the Arrow IPC format for message bodies
SPARROW_IPC_API size_t align_to_8(const size_t n);

/**
* @brief Extracts words after ':' separated by ',' from a string.
*
* This function finds the position of ':' in the input string and then
* splits the remaining part by ',' to extract individual words.
*
* @param str Input string to parse (e.g., "prefix:word1,word2,word3")
* @return std::vector<std::string_view> Vector of string views containing the extracted words
* Returns an empty vector if ':' is not found or if there are no words after it
*
* @example
* extract_words_after_colon("d:128,10") returns {"128", "10"}
* extract_words_after_colon("w:256") returns {"256"}
* extract_words_after_colon("no_colon") returns {}
*/
SPARROW_IPC_API std::vector<std::string_view> extract_words_after_colon(std::string_view str);

/**
* @brief Parse a string_view to int32_t using std::from_chars.
*
* This function converts a string view to a 32-bit integer using std::from_chars
* for efficient parsing.
*
* @param str The string view to parse
* @return std::optional<int32_t> The parsed integer value, or std::nullopt if parsing fails
*
* @example
* parse_to_int32("123") returns std::optional<int32_t>(123)
* parse_to_int32("abc") returns std::nullopt
* parse_to_int32("") returns std::nullopt
*/
SPARROW_IPC_API std::optional<int32_t> parse_to_int32(std::string_view str);

/**
* @brief Checks if all record batches in a collection have consistent structure.
*
Expand Down Expand Up @@ -63,5 +97,24 @@ namespace sparrow_ipc::utils
// Parse the format string
// The format string is expected to be "w:size", "+w:size", "d:precision,scale", etc
std::optional<int32_t> parse_format(std::string_view format_str, std::string_view sep);

/**
* @brief Parse decimal format strings.
*
* This function parses decimal format strings which can be in two formats:
* - "d:precision,scale" (e.g., "d:19,10")
* - "d:precision,scale,bitWidth" (e.g., "d:19,10,128")
*
* @param format_str The format string to parse
* @return std::optional<std::tuple<int32_t, int32_t, std::optional<int32_t>>>
* A tuple containing (precision, scale, optional bitWidth), or std::nullopt if parsing fails
*
* @example
* parse_decimal_format("d:19,10") returns std::optional{std::tuple{19, 10, std::nullopt}}
* parse_decimal_format("d:19,10,128") returns std::optional{std::tuple{19, 10, std::optional{128}}}
* parse_decimal_format("invalid") returns std::nullopt
*/
SPARROW_IPC_API std::optional<std::tuple<int32_t, int32_t, std::optional<int32_t>>> parse_decimal_format(std::string_view format_str);

// size_t calculate_output_serialized_size(const sparrow::record_batch& record_batch);
}
68 changes: 68 additions & 0 deletions src/deserialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <sparrow/types/data_type.hpp>

#include "sparrow_ipc/deserialize_decimal_array.hpp"
#include "sparrow_ipc/deserialize_fixedsizebinary_array.hpp"
#include "sparrow_ipc/deserialize_primitive_array.hpp"
#include "sparrow_ipc/deserialize_variable_size_binary_array.hpp"
Expand Down Expand Up @@ -205,6 +206,73 @@ namespace sparrow_ipc
)
);
break;
case org::apache::arrow::flatbuf::Type::Decimal:
{
const auto decimal_field = field->type_as_Decimal();
const auto scale = decimal_field->scale();
const auto precision = decimal_field->precision();
if (decimal_field->bitWidth() == 32)
{
arrays.emplace_back(
deserialize_non_owning_decimal<sparrow::decimal<int32_t>>(
record_batch,
encapsulated_message.body(),
name,
metadata,
nullable,
buffer_index,
scale,
precision
)
);
}
else if (decimal_field->bitWidth() == 64)
{
arrays.emplace_back(
deserialize_non_owning_decimal<sparrow::decimal<int64_t>>(
record_batch,
encapsulated_message.body(),
name,
metadata,
nullable,
buffer_index,
scale,
precision
)
);
}
else if (decimal_field->bitWidth() == 128)
{
arrays.emplace_back(
deserialize_non_owning_decimal<sparrow::decimal<sparrow::int128_t>>(
record_batch,
encapsulated_message.body(),
name,
metadata,
nullable,
buffer_index,
scale,
precision
)
);
}
else if (decimal_field->bitWidth() == 256)
{
arrays.emplace_back(
deserialize_non_owning_decimal<sparrow::decimal<sparrow::int256_t>>(
record_batch,
encapsulated_message.body(),
name,
metadata,
nullable,
buffer_index,
scale,
precision
)
);
}
break;
}
default:
throw std::runtime_error("Unsupported type.");
}
Expand Down
31 changes: 14 additions & 17 deletions src/flatbuffer_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,37 +370,34 @@ namespace sparrow_ipc
}

// Creates a Flatbuffers Decimal type from a format string
// The format string is expected to be in the format "d:precision,scale"
// The format string is expected to be in the format "d:precision,scale" or "d:precision,scale,bitWidth"
std::pair<org::apache::arrow::flatbuf::Type, flatbuffers::Offset<void>> get_flatbuffer_decimal_type(
flatbuffers::FlatBufferBuilder& builder,
std::string_view format_str,
const int32_t bitWidth
)
{
// Decimal requires precision and scale. We need to parse the format_str.
// Format: "d:precision,scale"
const auto scale = utils::parse_format(format_str, ",");
if (!scale.has_value())
// Format: "d:precision,scale" or "d:precision,scale,bitWidth"
const auto parsed = utils::parse_decimal_format(format_str);
if (!parsed.has_value())
{
throw std::runtime_error(
"Failed to parse Decimal " + std::to_string(bitWidth)
+ " scale from format string: " + std::string(format_str)
);
}
const size_t comma_pos = format_str.find(',');
const auto precision = utils::parse_format(format_str.substr(0, comma_pos), ":");
if (!precision.has_value())
{
throw std::runtime_error(
"Failed to parse Decimal " + std::to_string(bitWidth)
+ " precision from format string: " + std::string(format_str)
+ " format string: " + std::string(format_str)
);
}

const auto& [precision, scale, parsed_bitwidth] = parsed.value();

// Use the bitWidth from the format string if provided, otherwise use the parameter
const int32_t actual_bitwidth = parsed_bitwidth.value_or(bitWidth);

const auto decimal_type = org::apache::arrow::flatbuf::CreateDecimal(
builder,
precision.value(),
scale.value(),
bitWidth
precision,
scale,
actual_bitwidth
);
return {org::apache::arrow::flatbuf::Type::Decimal, decimal_type.Union()};
}
Expand Down
Loading