Commit c3a4fbd

Add read_parquet test
1 parent 45fc307 commit c3a4fbd

File tree

1 file changed: +102 −0 lines changed

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import operator
from typing import TYPE_CHECKING

import numpy as np
import pytest

import pylibcudf as plc

from rapidsmpf.streaming.core.channel import Channel
from rapidsmpf.streaming.core.leaf_node import pull_from_channel
from rapidsmpf.streaming.core.node import run_streaming_pipeline
from rapidsmpf.streaming.cudf.parquet import read_parquet
from rapidsmpf.streaming.cudf.table_chunk import TableChunk

if TYPE_CHECKING:
    from typing import Literal

    from rapidsmpf.streaming.core.context import Context


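# Write ten single-column parquet files of 10 rows each (100 rows in
# total) and return a SourceInfo spanning all of them.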
@pytest.fixture(scope="module")
def source(
    tmp_path_factory: pytest.TempPathFactory,
) -> plc.io.SourceInfo:
    path = tmp_path_factory.mktemp("read_parquet")

    nrows = 10
    start = 0
    sources = []
    for i in range(10):
        table = plc.Table(
            [plc.Column.from_array(np.arange(start, start + nrows, dtype="int32"))]
        )
        # Leave gaps in the values between consecutive files so that
        # each file's contents are distinguishable.
        start += nrows + nrows // 2
        filename = path / f"{i:03d}.pq"
        sink = plc.io.SinkInfo([filename])
        options = plc.io.parquet.ParquetWriterOptions.builder(sink, table).build()
        plc.io.parquet.write_parquet(options)
        sources.append(filename)
    return plc.io.SourceInfo(sources)
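

# The skip_rows/num_rows values are chosen to land inside a file, to
# cross the 10-row file boundaries, and (skip_rows=113) to exceed the
# 100 rows the fixture writes, exercising the reader's edge cases.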
@pytest.mark.parametrize(
    "skip_rows", ["none", 7, 19, 113], ids=lambda s: f"skip_rows_{s}"
)
@pytest.mark.parametrize("num_rows", ["all", 0, 3, 31, 83], ids=lambda s: f"nrows_{s}")
def test_read_parquet(
    context: Context,
    source: plc.io.SourceInfo,
    skip_rows: int | Literal["none"],
    num_rows: int | Literal["all"],
) -> None:
    ch = Channel[TableChunk]()

    options = plc.io.parquet.ParquetReaderOptions.builder(source).build()

    if skip_rows != "none":
        options.set_skip_rows(skip_rows)
    if num_rows != "all":
        options.set_num_rows(num_rows)
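    # Expected result: one non-streaming read of the same source with
    # the same options, done directly through pylibcudf.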
    expected = plc.io.parquet.read_parquet(options).tbl

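    # Producer streams the parquet source into ch as TableChunk
    # messages; pull_from_channel collects everything the producer sends.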
    producer = read_parquet(context, ch, 4, options, 3)

    consumer, messages = pull_from_channel(context, ch)

    run_streaming_pipeline(nodes=[producer, consumer])
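
    # Reassemble the streamed output: synchronize each chunk's stream
    # before touching its data, then concatenate in sequence order.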
    chunks = [TableChunk.from_message(m) for m in messages.release()]
    for chunk in chunks:
        chunk.stream.synchronize()

    got = plc.concatenate.concatenate(
        [
            chunk.table_view()
            for chunk in sorted(chunks, key=operator.attrgetter("sequence_number"))
        ]
    )
    for chunk in chunks:
        chunk.stream.synchronize()

    assert got.num_rows() == expected.num_rows()
    assert got.num_columns() == expected.num_columns()
    assert got.num_columns() == 1

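    # Contents must match: an element-wise EQUAL between the two
    # columns, reduced with an ALL aggregation, should come out true.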
    all_equal = plc.reduce.reduce(
        plc.binaryop.binary_operation(
            got.columns()[0],
            expected.columns()[0],
            plc.binaryop.BinaryOperator.EQUAL,
            plc.DataType(plc.TypeId.BOOL8),
        ),
        plc.aggregation.all(),
        plc.DataType(plc.TypeId.BOOL8),
    )
    assert all_equal.to_py()

0 commit comments