3 | 3 | import logging |
4 | 4 | from typing import Any, Callable, Dict, Iterator, List, Optional, Union |
5 | 5 |
6 | | -import numpy as np |
7 | | -import pyarrow as pa |
8 | | - |
9 | 6 | # fs required to implicitly trigger S3 subsystem initialization |
10 | 7 | import pyarrow.fs # noqa: F401 pylint: disable=unused-import |
11 | | -import pyarrow.parquet as pq |
12 | | -from ray import cloudpickle # pylint: disable=wrong-import-order,ungrouped-imports |
| 8 | +import ray |
| 9 | +from ray.data._internal.output_buffer import BlockOutputBuffer |
| 10 | +from ray.data._internal.remote_fn import cached_remote_fn |
13 | 11 | from ray.data.block import Block, BlockAccessor, BlockMetadata |
14 | 12 | from ray.data.context import DatasetContext |
15 | | -from ray.data.datasource import BlockWritePathProvider, DefaultBlockWritePathProvider |
16 | | -from ray.data.datasource.datasource import ReadTask, WriteResult |
| 13 | +from ray.data.datasource import BlockWritePathProvider, DefaultBlockWritePathProvider, Reader |
| 14 | +from ray.data.datasource.datasource import WriteResult |
17 | 15 | from ray.data.datasource.file_based_datasource import ( |
18 | 16 | _resolve_paths_and_filesystem, |
19 | 17 | _S3FileSystemWrapper, |
20 | 18 | _wrap_s3_serialization_workaround, |
21 | 19 | ) |
22 | | -from ray.data.datasource.file_meta_provider import DefaultParquetMetadataProvider, ParquetMetadataProvider |
23 | 20 | from ray.data.datasource.parquet_datasource import ( |
24 | | - _deregister_parquet_file_fragment_serialization, |
25 | | - _register_parquet_file_fragment_serialization, |
| 21 | + PARQUET_READER_ROW_BATCH_SIZE, |
| 22 | + _deserialize_pieces_with_retry, |
| 23 | + _ParquetDatasourceReader, |
| 24 | + _SerializedPiece, |
26 | 25 | ) |
27 | | -from ray.data.impl.output_buffer import BlockOutputBuffer |
28 | | -from ray.data.impl.remote_fn import cached_remote_fn |
29 | 26 | from ray.types import ObjectRef |
30 | 27 |
31 | 28 | from awswrangler._arrow import _add_table_partitions |
32 | 29 |
33 | 30 | _logger: logging.Logger = logging.getLogger(__name__) |
34 | 31 |
35 | | -# The number of rows to read per batch. This is sized to generate 10MiB batches |
36 | | -# for rows about 1KiB in size. |
37 | | -PARQUET_READER_ROW_BATCH_SIZE = 100000 |
| 32 | + |
| 33 | +# Original implementation: |
| 34 | +# https://github.com/ray-project/ray/blob/releases/2.0.0/python/ray/data/datasource/parquet_datasource.py |
| 35 | +def _read_pieces( |
| 36 | + block_udf: Optional[Callable[[Block[Any]], Block[Any]]], |
| 37 | + reader_args: Any, |
| 38 | + columns: Optional[List[str]], |
| 39 | + schema: Optional[Union[type, "pyarrow.lib.Schema"]], |
| 40 | + serialized_pieces: List[_SerializedPiece], |
| 41 | +) -> Iterator["pyarrow.Table"]: |
| 42 | + # This import is necessary to load the tensor extension type. |
| 43 | + from ray.data.extensions.tensor_extension import ( # type: ignore # noqa: F401, E501 # pylint: disable=import-outside-toplevel, unused-import |
| 44 | + ArrowTensorType, |
| 45 | + ) |
| 46 | + |
| 47 | + # Deserialize after loading the filesystem class. |
| 48 | + pieces: List["pyarrow._dataset.ParquetFileFragment"] = _deserialize_pieces_with_retry(serialized_pieces) |
| 49 | + |
| 50 | + # Ensure that we're reading at least one dataset fragment. |
| 51 | + assert len(pieces) > 0 |
| 52 | + |
| 53 | + import pyarrow as pa # pylint: disable=import-outside-toplevel |
| 54 | + |
| 55 | + ctx = DatasetContext.get_current() |
| 56 | + output_buffer = BlockOutputBuffer( |
| 57 | + block_udf=block_udf, |
| 58 | + target_max_block_size=ctx.target_max_block_size, |
| 59 | + ) |
| 60 | + |
| 61 | + _logger.debug("Reading %s parquet pieces", len(pieces)) |
| 62 | + use_threads = reader_args.pop("use_threads", False) |
| 63 | + path_root = reader_args.pop("path_root", None) |
| 64 | + for piece in pieces: |
| 65 | + batches = piece.to_batches( |
| 66 | + use_threads=use_threads, |
| 67 | + columns=columns, |
| 68 | + schema=schema, |
| 69 | + batch_size=PARQUET_READER_ROW_BATCH_SIZE, |
| 70 | + **reader_args, |
| 71 | + ) |
| 72 | + for batch in batches: |
| 73 | + # Table creation is wrapped inside _add_table_partitions |
| 74 | + # to add columns with partition values when dataset=True |
| 75 | + # and cast them to categorical |
| 76 | + table = _add_table_partitions( |
| 77 | + table=pa.Table.from_batches([batch], schema=schema), |
| 78 | + path=f"s3://{piece.path}", |
| 79 | + path_root=path_root, |
| 80 | + ) |
| 81 | + # If the table is empty, drop it. |
| 82 | + if table.num_rows > 0: |
| 83 | + output_buffer.add_block(table) |
| 84 | + if output_buffer.has_next(): |
| 85 | + yield output_buffer.next() |
| 86 | + output_buffer.finalize() |
| 87 | + if output_buffer.has_next(): |
| 88 | + yield output_buffer.next() |
| 89 | + |
| 90 | + |
| 91 | +# Replace Ray's module-level _read_pieces so partition columns are added during reads |
| 92 | +ray.data.datasource.parquet_datasource._read_pieces = _read_pieces # pylint: disable=protected-access |
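With this patch in place, any _ParquetDatasourceReader built afterwards produces read tasks that run the _read_pieces defined above, so every yielded block carries the partition columns added by _add_table_partitions. Below is a minimal sketch of how the patched read path might be driven through Ray 2.0's read_datasource API; the bucket, prefix, and parallelism value are placeholders, and the extra keyword arguments are assumed to be forwarded to _ParquetDatasourceReader as reader_args:

    import ray.data

    # Placeholder dataset location; extra keyword arguments are assumed to reach
    # the patched _read_pieces through the reader's reader_args.
    ds = ray.data.read_datasource(
        ParquetDatasource(),
        parallelism=4,
        paths=["s3://my-bucket/my-prefix/"],
        use_threads=False,
        path_root="s3://my-bucket/my-prefix/",
    )
    print(ds.schema())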
38 | 93 |
39 | 94 |
40 | 95 | class UserProvidedKeyBlockWritePathProvider(BlockWritePathProvider): |
@@ -62,109 +117,9 @@ class ParquetDatasource: |
62 | 117 | def __init__(self) -> None: |
63 | 118 | self._write_paths: List[str] = [] |
64 | 119 |
65 | | - # Original: https://github.com/ray-project/ray/blob/releases/1.13.0/python/ray/data/datasource/parquet_datasource.py |
66 | | - def prepare_read( |
67 | | - self, |
68 | | - parallelism: int, |
69 | | - use_threads: Union[bool, int], |
70 | | - paths: Union[str, List[str]], |
71 | | - schema: "pyarrow.lib.Schema", |
72 | | - columns: Optional[List[str]] = None, |
73 | | - coerce_int96_timestamp_unit: Optional[str] = None, |
74 | | - path_root: Optional[str] = None, |
75 | | - filesystem: Optional["pyarrow.fs.FileSystem"] = None, |
76 | | - meta_provider: ParquetMetadataProvider = DefaultParquetMetadataProvider(), |
77 | | - _block_udf: Optional[Callable[..., Any]] = None, |
78 | | - ) -> List[ReadTask]: |
79 | | - """Create and return read tasks for a Parquet file-based datasource.""" |
80 | | - paths, filesystem = _resolve_paths_and_filesystem(paths, filesystem) |
81 | | - |
82 | | - parquet_dataset = pq.ParquetDataset( |
83 | | - path_or_paths=paths, |
84 | | - filesystem=filesystem, |
85 | | - partitioning=None, |
86 | | - coerce_int96_timestamp_unit=coerce_int96_timestamp_unit, |
87 | | - use_legacy_dataset=False, |
88 | | - ) |
89 | | - |
90 | | - def read_pieces(serialized_pieces: str) -> Iterator[pa.Table]: |
91 | | - # Deserialize after loading the filesystem class. |
92 | | - try: |
93 | | - _register_parquet_file_fragment_serialization() # type: ignore |
94 | | - pieces = cloudpickle.loads(serialized_pieces) |
95 | | - finally: |
96 | | - _deregister_parquet_file_fragment_serialization() # type: ignore |
97 | | - |
98 | | - # Ensure that we're reading at least one dataset fragment. |
99 | | - assert len(pieces) > 0 |
100 | | - |
101 | | - ctx = DatasetContext.get_current() |
102 | | - output_buffer = BlockOutputBuffer(block_udf=_block_udf, target_max_block_size=ctx.target_max_block_size) |
103 | | - |
104 | | - _logger.debug("Reading %s parquet pieces", len(pieces)) |
105 | | - for piece in pieces: |
106 | | - batches = piece.to_batches( |
107 | | - use_threads=use_threads, |
108 | | - columns=columns, |
109 | | - schema=schema, |
110 | | - batch_size=PARQUET_READER_ROW_BATCH_SIZE, |
111 | | - ) |
112 | | - for batch in batches: |
113 | | - # Table creation is wrapped inside _add_table_partitions |
114 | | - # to add columns with partition values when dataset=True |
115 | | - table = _add_table_partitions( |
116 | | - table=pa.Table.from_batches([batch], schema=schema), |
117 | | - path=f"s3://{piece.path}", |
118 | | - path_root=path_root, |
119 | | - ) |
120 | | - # If the table is empty, drop it. |
121 | | - if table.num_rows > 0: |
122 | | - output_buffer.add_block(table) |
123 | | - if output_buffer.has_next(): |
124 | | - yield output_buffer.next() |
125 | | - |
126 | | - output_buffer.finalize() |
127 | | - if output_buffer.has_next(): |
128 | | - yield output_buffer.next() |
129 | | - |
130 | | - if _block_udf is not None: |
131 | | - # Try to infer dataset schema by passing dummy table through UDF. |
132 | | - dummy_table = schema.empty_table() |
133 | | - try: |
134 | | - inferred_schema = _block_udf(dummy_table).schema |
135 | | - inferred_schema = inferred_schema.with_metadata(schema.metadata) |
136 | | - except Exception: # pylint: disable=broad-except |
137 | | - _logger.debug( |
138 | | - "Failed to infer schema of dataset by passing dummy table " |
139 | | - "through UDF due to the following exception:", |
140 | | - exc_info=True, |
141 | | - ) |
142 | | - inferred_schema = schema |
143 | | - else: |
144 | | - inferred_schema = schema |
145 | | - read_tasks = [] |
146 | | - metadata = meta_provider.prefetch_file_metadata(parquet_dataset.pieces) or [] |
147 | | - try: |
148 | | - _register_parquet_file_fragment_serialization() # type: ignore |
149 | | - for pieces, metadata in zip( # type: ignore |
150 | | - np.array_split(parquet_dataset.pieces, parallelism), |
151 | | - np.array_split(metadata, parallelism), |
152 | | - ): |
153 | | - if len(pieces) <= 0: |
154 | | - continue |
155 | | - serialized_pieces = cloudpickle.dumps(pieces) # type: ignore |
156 | | - input_files = [p.path for p in pieces] |
157 | | - meta = meta_provider( |
158 | | - input_files, |
159 | | - inferred_schema, |
160 | | - pieces=pieces, |
161 | | - prefetched_metadata=metadata, |
162 | | - ) |
163 | | - read_tasks.append(ReadTask(lambda p=serialized_pieces: read_pieces(p), meta)) # type: ignore |
164 | | - finally: |
165 | | - _deregister_parquet_file_fragment_serialization() # type: ignore |
166 | | - |
167 | | - return read_tasks |
| 120 | + def create_reader(self, **kwargs: Dict[str, Any]) -> Reader[Any]: |
| 121 | + """Return a Reader for the given read arguments.""" |
| 122 | + return _ParquetDatasourceReader(**kwargs) # type: ignore |
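For illustration, the reader returned here can also be built and inspected directly: get_read_tasks is part of Ray 2.0's Reader interface and splits the parquet pieces into at most parallelism read tasks. A hedged sketch, with a placeholder S3 path:

    # Assumed usage of the Ray 2.0 Reader interface; the path is a placeholder.
    reader = ParquetDatasource().create_reader(
        paths=["s3://my-bucket/my-prefix/"],
    )
    read_tasks = reader.get_read_tasks(parallelism=4)
    print(len(read_tasks))  # number of planned tasks, at most the requested parallelism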
168 | 123 |
169 | 124 | # Original implementation: |
170 | 125 | # https://github.com/ray-project/ray/blob/releases/1.13.0/python/ray/data/datasource/file_based_datasource.py |