Skip to content

Commit 1b108d4

Browse files
committed
wip
1 parent 93994d3 commit 1b108d4

File tree

1 file changed

+59
-38
lines changed

1 file changed

+59
-38
lines changed

src/deserialize.cpp

Lines changed: 59 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,23 @@
1212

1313
namespace sparrow_ipc
1414
{
15+
namespace
16+
{
17+
// Integer bit width constants
18+
constexpr int32_t BIT_WIDTH_8 = 8;
19+
constexpr int32_t BIT_WIDTH_16 = 16;
20+
constexpr int32_t BIT_WIDTH_32 = 32;
21+
constexpr int32_t BIT_WIDTH_64 = 64;
22+
23+
// End-of-stream marker size in bytes
24+
constexpr size_t END_OF_STREAM_MARKER_SIZE = 8;
25+
}
1526
const org::apache::arrow::flatbuf::RecordBatch*
1627
deserialize_record_batch_message(std::span<const uint8_t> data, size_t& current_offset)
1728
{
1829
current_offset += sizeof(uint32_t);
19-
const auto batch_message = org::apache::arrow::flatbuf::GetMessage(data.data() + current_offset);
30+
const auto message_data = data.subspan(current_offset);
31+
const auto* batch_message = org::apache::arrow::flatbuf::GetMessage(message_data.data());
2032
if (batch_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::RecordBatch)
2133
{
2234
throw std::runtime_error("Expected RecordBatch message, but got a different type.");
@@ -29,21 +41,21 @@ namespace sparrow_ipc
2941
*
3042
* This function processes each field in the schema and deserializes the corresponding
3143
* data from the RecordBatch into sparrow::array objects. It handles various Arrow data
32-
* types including primitive types (bool, integers, floating point), binary data, and
33-
* string data with their respective size variants.
44+
* types including primitive types (bool, integers, floating point), binary data, string
45+
* data, fixed-size binary data, and interval types.
3446
*
3547
* @param record_batch The Apache Arrow FlatBuffer RecordBatch containing the serialized data
3648
* @param schema The Apache Arrow FlatBuffer Schema defining the structure and types of the data
3749
* @param encapsulated_message The message containing the binary data buffers
38-
* @param field_metadata Metadata for each field
50+
* @param field_metadata Metadata associated with each field in the schema
3951
*
4052
* @return std::vector<sparrow::array> A vector of deserialized arrays, one for each field in the schema
4153
*
42-
* @throws std::runtime_error If an unsupported data type, integer bit width, or floating point precision
43-
* is encountered
54+
* @throws std::runtime_error If an unsupported data type, integer bit width, floating point precision,
55+
* or interval unit is encountered
4456
*
45-
* The function maintains a buffer index that is incremented as it processes each field
46-
* to correctly map data buffers to their corresponding arrays.
57+
* @note The function maintains a buffer index that is incremented as it processes each field
58+
* to correctly map data buffers to their corresponding arrays.
4759
*/
4860
std::vector<sparrow::array> get_arrays_from_record_batch(
4961
const org::apache::arrow::flatbuf::RecordBatch& record_batch,
@@ -65,7 +77,7 @@ namespace sparrow_ipc
6577
const std::string name = field->name() == nullptr ? "" : field->name()->str();
6678
const bool nullable = field->nullable();
6779
const auto field_type = field->type_type();
68-
// TODO rename all the deserialize_non_owning... fcts since this is not correct anymore
80+
6981
const auto deserialize_non_owning_primitive_array_lambda = [&]<typename T>()
7082
{
7183
return deserialize_non_owning_primitive_array<T>(
@@ -86,7 +98,7 @@ namespace sparrow_ipc
8698
break;
8799
case org::apache::arrow::flatbuf::Type::Int:
88100
{
89-
const auto int_type = field->type_as_Int();
101+
const auto* int_type = field->type_as_Int();
90102
const auto bit_width = int_type->bitWidth();
91103
const bool is_signed = int_type->is_signed();
92104

@@ -95,11 +107,11 @@ namespace sparrow_ipc
95107
switch (bit_width)
96108
{
97109
// clang-format off
98-
case 8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int8_t>()); break;
99-
case 16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int16_t>()); break;
100-
case 32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int32_t>()); break;
101-
case 64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int64_t>()); break;
102-
default: throw std::runtime_error("Unsupported integer bit width.");
110+
case BIT_WIDTH_8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int8_t>()); break;
111+
case BIT_WIDTH_16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int16_t>()); break;
112+
case BIT_WIDTH_32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int32_t>()); break;
113+
case BIT_WIDTH_64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int64_t>()); break;
114+
default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width));
103115
// clang-format on
104116
}
105117
}
@@ -108,19 +120,19 @@ namespace sparrow_ipc
108120
switch (bit_width)
109121
{
110122
// clang-format off
111-
case 8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint8_t>()); break;
112-
case 16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint16_t>()); break;
113-
case 32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint32_t>()); break;
114-
case 64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint64_t>()); break;
115-
default: throw std::runtime_error("Unsupported integer bit width.");
123+
case BIT_WIDTH_8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint8_t>()); break;
124+
case BIT_WIDTH_16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint16_t>()); break;
125+
case BIT_WIDTH_32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint32_t>()); break;
126+
case BIT_WIDTH_64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint64_t>()); break;
127+
default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width));
116128
// clang-format on
117129
}
118130
}
119131
}
120132
break;
121133
case org::apache::arrow::flatbuf::Type::FloatingPoint:
122134
{
123-
const auto float_type = field->type_as_FloatingPoint();
135+
const auto* float_type = field->type_as_FloatingPoint();
124136
switch (float_type->precision())
125137
{
126138
// clang-format off
@@ -134,14 +146,17 @@ namespace sparrow_ipc
134146
arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<double>());
135147
break;
136148
default:
137-
throw std::runtime_error("Unsupported floating point precision.");
149+
throw std::runtime_error(
150+
"Unsupported floating point precision: "
151+
+ std::to_string(static_cast<int>(float_type->precision()))
152+
);
138153
// clang-format on
139154
}
140155
break;
141156
}
142157
case org::apache::arrow::flatbuf::Type::FixedSizeBinary:
143158
{
144-
const auto fixed_size_binary_field = field->type_as_FixedSizeBinary();
159+
const auto* fixed_size_binary_field = field->type_as_FixedSizeBinary();
145160
arrays.emplace_back(deserialize_non_owning_fixedwidthbinary(
146161
record_batch,
147162
encapsulated_message.body(),
@@ -203,8 +218,8 @@ namespace sparrow_ipc
203218
break;
204219
case org::apache::arrow::flatbuf::Type::Interval:
205220
{
206-
const auto interval_type = field->type_as_Interval();
207-
org::apache::arrow::flatbuf::IntervalUnit interval_unit = interval_type->unit();
221+
const auto* interval_type = field->type_as_Interval();
222+
const org::apache::arrow::flatbuf::IntervalUnit interval_unit = interval_type->unit();
208223
switch (interval_unit)
209224
{
210225
case org::apache::arrow::flatbuf::IntervalUnit::YEAR_MONTH:
@@ -241,12 +256,18 @@ namespace sparrow_ipc
241256
);
242257
break;
243258
default:
244-
throw std::runtime_error("Unsupported interval unit.");
259+
throw std::runtime_error(
260+
"Unsupported interval unit: "
261+
+ std::to_string(static_cast<int>(interval_unit))
262+
);
245263
}
246264
}
247265
break;
248266
default:
249-
throw std::runtime_error("Unsupported type.");
267+
throw std::runtime_error(
268+
"Unsupported field type: " + std::to_string(static_cast<int>(field_type))
269+
+ " for field '" + name + "'"
270+
);
250271
}
251272
}
252273
return arrays;
@@ -260,11 +281,12 @@ namespace sparrow_ipc
260281
std::vector<bool> fields_nullable;
261282
std::vector<sparrow::data_type> field_types;
262283
std::vector<std::optional<std::vector<sparrow::metadata_pair>>> fields_metadata;
263-
do
284+
285+
while (!data.empty())
264286
{
265-
// Check for end-of-stream marker here as data could contain only that (if no record batches
266-
// present/written)
267-
if (data.size() >= 8 && is_end_of_stream(data.subspan(0, 8)))
287+
// Check for end-of-stream marker
288+
if (data.size() >= END_OF_STREAM_MARKER_SIZE
289+
&& is_end_of_stream(data.subspan(0, END_OF_STREAM_MARKER_SIZE)))
268290
{
269291
break;
270292
}
@@ -312,34 +334,33 @@ namespace sparrow_ipc
312334
{
313335
if (schema == nullptr)
314336
{
315-
throw std::runtime_error("Schema message is missing.");
337+
throw std::runtime_error("RecordBatch encountered before Schema message.");
316338
}
317-
const auto record_batch = message->header_as_RecordBatch();
339+
const auto* record_batch = message->header_as_RecordBatch();
318340
if (record_batch == nullptr)
319341
{
320-
throw std::runtime_error("RecordBatch message is missing.");
342+
throw std::runtime_error("RecordBatch message header is null.");
321343
}
322344
std::vector<sparrow::array> arrays = get_arrays_from_record_batch(
323345
*record_batch,
324346
*schema,
325347
encapsulated_message,
326348
fields_metadata
327349
);
328-
auto names_copy = field_names; // TODO: Remove when issue with the to_vector of
329-
// record_batch is fixed
350+
auto names_copy = field_names;
330351
sparrow::record_batch sp_record_batch(std::move(names_copy), std::move(arrays));
331352
record_batches.emplace_back(std::move(sp_record_batch));
332353
}
333354
break;
334355
case org::apache::arrow::flatbuf::MessageHeader::Tensor:
335356
case org::apache::arrow::flatbuf::MessageHeader::DictionaryBatch:
336357
case org::apache::arrow::flatbuf::MessageHeader::SparseTensor:
337-
throw std::runtime_error("Not supported");
358+
throw std::runtime_error("Unsupported message type: Tensor, DictionaryBatch, or SparseTensor");
338359
default:
339360
throw std::runtime_error("Unknown message header type.");
340361
}
341362
data = rest;
342-
} while (!data.empty());
363+
}
343364
return record_batches;
344365
}
345366
}

0 commit comments

Comments
 (0)