Skip to content

Commit e1edb18

Browse files
committed
Expose read_parquet node to python
1 parent 679c4f4 commit e1edb18

File tree

4 files changed

+186
-1
lines changed

4 files changed

+186
-1
lines changed

python/rapidsmpf/rapidsmpf/streaming/cudf/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# cmake-format: on
66
# =================================================================================
77

8-
set(cython_modules partition.pyx table_chunk.pyx)
8+
set(cython_modules parquet.pyx partition.pyx table_chunk.pyx)
99

1010
rapids_cython_create_modules(
1111
CXX
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from pylibcudf.io.parquet import ParquetReaderOptions
5+
6+
from rapidsmpf.streaming.core.channel import Channel
7+
from rapidsmpf.streaming.core.context import Context
8+
from rapidsmpf.streaming.core.node import CppNode
9+
from rapidsmpf.streaming.cudf.table_chunk import TableChunk
10+
11+
def read_parquet(
    ctx: Context,
    ch_out: Channel[TableChunk],
    num_producers: int,
    options: ParquetReaderOptions,
    num_rows_per_chunk: int,
) -> CppNode:
    """
    Create a streaming node that reads parquet data.

    Parameters
    ----------
    ctx
        Streaming execution context.
    ch_out
        Output channel that receives the ``TableChunk`` messages.
    num_producers
        Number of concurrent producers of output chunks.
    options
        Reader options.
    num_rows_per_chunk
        Target (maximum) number of rows per output chunk.

    Returns
    -------
    A C++ streaming node to be executed in a streaming pipeline.

    Notes
    -----
    This is a collective operation: all ranks participating via the
    execution context's communicator must call it with the same options.
    """
    ...
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from libc.stddef cimport size_t
5+
from libcpp.memory cimport make_unique, shared_ptr
6+
from libcpp.utility cimport move
7+
from pylibcudf.io.parquet cimport ParquetReaderOptions
8+
from pylibcudf.libcudf.io.parquet cimport parquet_reader_options
9+
from pylibcudf.libcudf.types cimport size_type
10+
11+
from rapidsmpf.streaming.core.channel cimport Channel, cpp_Channel
12+
from rapidsmpf.streaming.core.context cimport Context, cpp_Context
13+
from rapidsmpf.streaming.core.node cimport CppNode, cpp_Node
14+
15+
16+
# Declaration of the C++ factory rapidsmpf::streaming::node::read_parquet,
# which builds the parquet-reading streaming node wrapped by the Python
# read_parquet() below.  C++ exceptions propagate as Python exceptions
# via `except +`.
cdef extern from "<rapidsmpf/streaming/cudf/parquet.hpp>" nogil:
    cdef cpp_Node cpp_read_parquet \
        "rapidsmpf::streaming::node::read_parquet"(
            shared_ptr[cpp_Context] ctx,
            shared_ptr[cpp_Channel] ch_out,
            size_t num_producers,
            parquet_reader_options options,
            size_type num_rows_per_chunk,
        ) except +
25+
26+
27+
def read_parquet(
    Context ctx not None,
    Channel ch_out not None,
    size_t num_producers,
    ParquetReaderOptions options not None,
    size_type num_rows_per_chunk
):
    """
    Build a streaming node that reads parquet data into a channel.

    Parameters
    ----------
    ctx
        Streaming execution context
    ch_out
        Output channel to receive the TableChunks.
    num_producers
        Number of concurrent producers of output chunks.
    options
        Reader options
    num_rows_per_chunk
        Target (maximum) number of rows per output chunk.

    Notes
    -----
    This is a collective operation, all ranks participating via the
    execution context's communicator must call it with the same options.
    """
    cdef cpp_Node node
    # Release the GIL while the C++ node is constructed; only cdef
    # attributes (handles and c_obj) are touched inside the block.
    with nogil:
        node = cpp_read_parquet(
            ctx._handle,
            ch_out._handle,
            num_producers,
            options.c_obj,
            num_rows_per_chunk,
        )
    # Move the node onto the heap and hand ownership to the Python wrapper.
    return CppNode.from_handle(make_unique[cpp_Node](move(node)), owner=None)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
import itertools
7+
from typing import TYPE_CHECKING
8+
9+
import numpy as np
10+
import pytest
11+
12+
import pylibcudf as plc
13+
14+
from rapidsmpf.streaming.core.channel import Channel
15+
from rapidsmpf.streaming.core.leaf_node import pull_from_channel
16+
from rapidsmpf.streaming.core.node import run_streaming_pipeline
17+
from rapidsmpf.streaming.cudf.parquet import read_parquet
18+
from rapidsmpf.streaming.cudf.table_chunk import TableChunk
19+
20+
if TYPE_CHECKING:
21+
from typing import Literal
22+
23+
from rapidsmpf.streaming.core.context import Context
24+
25+
26+
@pytest.fixture(scope="module")
def source(
    tmp_path_factory: pytest.TempPathFactory,
) -> plc.io.SourceInfo:
    """Write ten small single-column parquet files and return a SourceInfo over them.

    Each file holds ``nrows`` consecutive int32 values; the start value then
    jumps by an extra ``nrows // 2`` so the concatenated column has gaps.
    """
    path = tmp_path_factory.mktemp("read_parquet")

    nrows = 10
    start = 0
    sources = []
    for i in range(10):
        table = plc.Table(
            [plc.Column.from_array(np.arange(start, start + nrows, dtype="int32"))]
        )
        # gaps in the column numbering we produce
        start += nrows + nrows // 2
        # Zero-pad the index ("003.pq"): the original "{i:3d}" space-pads,
        # producing filenames with embedded leading spaces ("  0.pq").
        filename = path / f"{i:03d}.pq"
        sink = plc.io.SinkInfo([filename])
        options = plc.io.parquet.ParquetWriterOptions.builder(sink, table).build()
        plc.io.parquet.write_parquet(options)
        sources.append(filename)
    return plc.io.SourceInfo(sources)
47+
48+
49+
@pytest.mark.parametrize(
    "skip_rows", ["none", 7, 19, 113], ids=lambda s: f"skip_rows_{s}"
)
@pytest.mark.parametrize("num_rows", ["all", 0, 3, 31, 83], ids=lambda s: f"nrows_{s}")
def test_read_parquet(
    context: Context,
    source: plc.io.SourceInfo,
    skip_rows: int | Literal["none"],
    num_rows: int | Literal["all"],
) -> None:
    """Streamed read_parquet must match a single-shot pylibcudf read of the same options."""
    ch = Channel[TableChunk]()

    options = plc.io.parquet.ParquetReaderOptions.builder(source).build()

    # The "none"/"all" sentinels leave the corresponding option at its default.
    if skip_rows != "none":
        options.set_skip_rows(skip_rows)
    if num_rows != "all":
        options.set_num_rows(num_rows)
    # Reference result: eager, non-streaming read with identical options.
    expected = plc.io.parquet.read_parquet(options).tbl

    # 4 concurrent producers, at most 3 rows per output chunk, so the
    # expected table is split across many small chunks.
    producer = read_parquet(context, ch, 4, options, 3)

    consumer, deferred_messages = pull_from_channel(context, ch)

    run_streaming_pipeline(nodes=[producer, consumer])

    messages = deferred_messages.release()
    # Chunks must be delivered in strictly increasing sequence order.
    assert all(
        m1.sequence_number < m2.sequence_number
        for m1, m2 in itertools.pairwise(messages)
    )
    chunks = [TableChunk.from_message(m) for m in messages]
    # Wait for each chunk's device work before touching its table data.
    for chunk in chunks:
        chunk.stream.synchronize()

    got = plc.concatenate.concatenate([chunk.table_view() for chunk in chunks])
    # NOTE(review): second synchronize after concatenate — presumably keeps
    # the chunks' streams quiescent while `got` is validated; confirm whether
    # it is required or redundant.
    for chunk in chunks:
        chunk.stream.synchronize()

    assert got.num_rows() == expected.num_rows()
    assert got.num_columns() == expected.num_columns()
    assert got.num_columns() == 1

    # Element-wise EQUAL reduced with ALL: true iff every row matches.
    all_equal = plc.reduce.reduce(
        plc.binaryop.binary_operation(
            got.columns()[0],
            expected.columns()[0],
            plc.binaryop.BinaryOperator.EQUAL,
            plc.DataType(plc.TypeId.BOOL8),
        ),
        plc.aggregation.all(),
        plc.DataType(plc.TypeId.BOOL8),
    )
    assert all_equal.to_py()

0 commit comments

Comments
 (0)