
Commit 17a4099

Add read_parquet test
Edge cases still need to be sorted out.
1 parent: e399465


2 files changed: +112 -0 lines changed


cpp/src/streaming/cudf/parquet.cpp

Lines changed: 2 additions & 0 deletions
@@ -129,6 +129,8 @@ Node read_parquet(
   auto options_num_rows =
     options.get_num_rows().value_or(std::numeric_limits<int64_t>::max());
   std::uint64_t sequence_number = 0;
+  // TODO: Handle case where total num rows is zero and/or where we skip all the rows in
+  // the file.
   for (file_offset = 0; file_offset < files_per_rank; file_offset += files_per_split) {
     if (options_num_rows == 0) {
       break;
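
The new TODO covers the case where the effective row window is empty: the caller either asks for zero rows, or sets skip_rows past the last row of the selected files. A minimal sketch of how such reader options are built with pylibcudf (the calls mirror the test added below; the file path is a placeholder, and 200 is simply assumed to exceed the total row count):

    import pylibcudf as plc

    # Placeholder source for illustration; any parquet file(s) would do.
    source = plc.io.SourceInfo(["example.pq"])
    options = plc.io.parquet.ParquetReaderOptions.builder(source).build()

    # Either call leaves read_parquet with nothing to emit:
    options.set_num_rows(0)     # request zero rows outright
    options.set_skip_rows(200)  # or skip past the end of every file

In the test below, the parametrizations num_rows=0 and skip_rows=113 exercise exactly these paths, since the fixture writes 10 files of 10 rows each (100 rows in total).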
Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import operator
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pytest
+
+import pylibcudf as plc
+
+from rapidsmpf.streaming.core.channel import Channel
+from rapidsmpf.streaming.core.leaf_node import pull_from_channel
+from rapidsmpf.streaming.core.node import run_streaming_pipeline
+from rapidsmpf.streaming.cudf.parquet import read_parquet
+from rapidsmpf.streaming.cudf.table_chunk import TableChunk
+
+if TYPE_CHECKING:
+    from rapidsmpf.streaming.core.context import Context
+
+
+@pytest.fixture(scope="module")
+def source_full_table(
+    tmp_path_factory: pytest.TempPathFactory,
+) -> tuple[plc.io.SourceInfo, plc.Table]:
+    path = tmp_path_factory.mktemp("read_parquet")
+
+    nrows = 10
+    start = 0
+    sources = []
+    tables = []
+    for i in range(10):
+        table = plc.Table(
+            [plc.Column.from_array(np.arange(start, start + nrows, dtype="int32"))]
+        )
+        tables.append(table)
+        # gaps in the column numbering we produce
+        start += nrows + nrows // 2
+        filename = path / f"{i:3d}.pq"
+        sink = plc.io.SinkInfo([filename])
+        options = plc.io.parquet.ParquetWriterOptions.builder(sink, table).build()
+        plc.io.parquet.write_parquet(options)
+        sources.append(filename)
+    return plc.io.SourceInfo(sources), plc.concatenate.concatenate(tables)
+
+
+@pytest.mark.parametrize("skip_rows", [None, 7, 19, 113], ids=lambda s: f"skip_{s}")
+@pytest.mark.parametrize("num_rows", [None, 0, 3, 31, 83], ids=lambda s: f"nrows_{s}")
+def test_read_parquet(
+    context: Context,
+    source_full_table: tuple[plc.io.SourceInfo, plc.Table],
+    skip_rows: int | None,
+    num_rows: int | None,
+) -> None:
+    ch = Channel[TableChunk]()
+
+    source, expected = source_full_table
+    options = plc.io.parquet.ParquetReaderOptions.builder(source).build()
+
+    if skip_rows is not None:
+        options.set_skip_rows(skip_rows)
+        (expected,) = plc.copying.slice(
+            expected, [min(skip_rows, expected.num_rows()), expected.num_rows()]
+        )
+    if num_rows is not None:
+        options.set_num_rows(num_rows)
+        (expected,) = plc.copying.slice(
+            expected, [0, min(num_rows, expected.num_rows())]
+        )
+    producer = read_parquet(context, ch, 4, options, 3)
+
+    consumer, messages = pull_from_channel(context, ch)
+
+    run_streaming_pipeline(nodes=[producer, consumer])
+
+    chunks = [TableChunk.from_message(m) for m in messages.release()]
+    for chunk in chunks:
+        chunk.stream.synchronize()
+
+    views = [
+        chunk.table_view()
+        for chunk in sorted(chunks, key=operator.attrgetter("sequence_number"))
+    ]
+    if views:
+        result = plc.concatenate.concatenate(views)
+    else:
+        result = plc.Table([])
+    for chunk in chunks:
+        chunk.stream.synchronize()
+
+    assert result.num_rows() == expected.num_rows()
+    assert result.num_columns() == expected.num_columns()
+    assert result.num_columns() == 1
+
+    (got,) = result.columns()
+    (expect,) = expected.columns()
+
+    all_equal = plc.reduce.reduce(
+        plc.binaryop.binary_operation(
+            got,
+            expect,
+            plc.binaryop.BinaryOperator.EQUAL,
+            plc.DataType(plc.TypeId.BOOL8),
+        ),
+        plc.aggregation.all(),
+        plc.DataType(plc.TypeId.BOOL8),
+    )
+    assert all_equal.to_py()
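
Chunks may be pulled from the channel in any order, so the test sorts them by sequence_number (assigned by read_parquet, see the C++ change above) before concatenating. A condensed, hypothetical helper capturing that reassembly step (not part of this commit; it only reuses calls already present in the test):

    import operator

    import pylibcudf as plc

    from rapidsmpf.streaming.cudf.table_chunk import TableChunk


    def assemble(chunks: list[TableChunk]) -> plc.Table:
        # Restore production order; sequence_number increases with the file/row offset.
        ordered = sorted(chunks, key=operator.attrgetter("sequence_number"))
        views = [chunk.table_view() for chunk in ordered]
        # With an empty row window (e.g. num_rows=0) there may be no chunks at all.
        return plc.concatenate.concatenate(views) if views else plc.Table([])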
