Skip to content

Commit 4bdded0

Browse files
committed
wire chunktransform up to sharding
1 parent b387aeb commit 4bdded0

File tree

3 files changed

+116
-121
lines changed

3 files changed

+116
-121
lines changed

src/zarr/abc/codec.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -245,22 +245,6 @@ class PreparedWrite:
245245
class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]):
246246
"""Base class for array-to-bytes codecs."""
247247

248-
@property
249-
def inner_codec_chain(self) -> SupportsChunkCodec | None:
250-
"""The codec chain for decoding inner chunks after deserialization.
251-
252-
Returns ``None`` by default, meaning the pipeline should use its own
253-
codec chain. ``ShardingCodec`` overrides this to return its inner
254-
codec chain.
255-
256-
Returns
257-
-------
258-
SupportsChunkCodec or None
259-
A [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] instance,
260-
or ``None``.
261-
"""
262-
return None
263-
264248
def deserialize(
265249
self, raw: Buffer | None, chunk_spec: ArraySpec
266250
) -> dict[tuple[int, ...], Buffer | None]:

src/zarr/codecs/sharding.py

Lines changed: 71 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Iterable, Mapping, MutableMapping
3+
from collections.abc import Iterable, Mapping
44
from dataclasses import dataclass, replace
55
from enum import Enum
66
from functools import lru_cache
@@ -15,11 +15,9 @@
1515
ArrayBytesCodecPartialDecodeMixin,
1616
ArrayBytesCodecPartialEncodeMixin,
1717
Codec,
18-
CodecPipeline,
1918
)
2019
from zarr.abc.store import (
2120
ByteGetter,
22-
ByteRequest,
2321
ByteSetter,
2422
RangeByteRequest,
2523
SuffixByteRequest,
@@ -35,6 +33,7 @@
3533
numpy_buffer_prototype,
3634
)
3735
from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid
36+
from zarr.core.codec_pipeline import ChunkTransform, fill_value_or_default
3837
from zarr.core.common import (
3938
ShapeLike,
4039
parse_enum,
@@ -54,7 +53,6 @@
5453
)
5554
from zarr.core.metadata.v3 import parse_codecs
5655
from zarr.registry import get_ndbuffer_class, get_pipeline_class
57-
from zarr.storage._utils import _normalize_byte_range_index
5856

5957
if TYPE_CHECKING:
6058
from collections.abc import Iterator
@@ -65,7 +63,6 @@
6563

6664
MAX_UINT_64 = 2**64 - 1
6765
ShardMapping = Mapping[tuple[int, ...], Buffer | None]
68-
ShardMutableMapping = MutableMapping[tuple[int, ...], Buffer | None]
6966

7067

7168
class ShardingCodecIndexLocation(Enum):
@@ -81,41 +78,6 @@ def parse_index_location(data: object) -> ShardingCodecIndexLocation:
8178
return parse_enum(data, ShardingCodecIndexLocation)
8279

8380

84-
@dataclass(frozen=True)
85-
class _ShardingByteGetter(ByteGetter):
86-
shard_dict: ShardMapping
87-
chunk_coords: tuple[int, ...]
88-
89-
async def get(
90-
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
91-
) -> Buffer | None:
92-
assert prototype == default_buffer_prototype(), (
93-
f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}"
94-
)
95-
value = self.shard_dict.get(self.chunk_coords)
96-
if value is None:
97-
return None
98-
if byte_range is None:
99-
return value
100-
start, stop = _normalize_byte_range_index(value, byte_range)
101-
return value[start:stop]
102-
103-
104-
@dataclass(frozen=True)
105-
class _ShardingByteSetter(_ShardingByteGetter, ByteSetter):
106-
shard_dict: ShardMutableMapping
107-
108-
async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None:
109-
assert byte_range is None, "byte_range is not supported within shards"
110-
self.shard_dict[self.chunk_coords] = value
111-
112-
async def delete(self) -> None:
113-
del self.shard_dict[self.chunk_coords]
114-
115-
async def set_if_not_exists(self, default: Buffer) -> None:
116-
self.shard_dict.setdefault(self.chunk_coords, default)
117-
118-
11981
class _ShardIndex(NamedTuple):
12082
# dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)
12183
offsets_and_lengths: npt.NDArray[np.uint64]
@@ -354,9 +316,8 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
354316
_, configuration_parsed = parse_named_configuration(data, "sharding_indexed")
355317
return cls(**configuration_parsed) # type: ignore[arg-type]
356318

357-
@property
358-
def codec_pipeline(self) -> CodecPipeline:
359-
return get_pipeline_class().from_codecs(self.codecs)
319+
def _get_chunk_transform(self, chunk_spec: ArraySpec) -> ChunkTransform:
    """Build the transform used to encode/decode the chunks inside a shard.

    Parameters
    ----------
    chunk_spec : ArraySpec
        Spec of a single inner chunk; passed to ``ChunkTransform`` as
        ``array_spec``.

    Returns
    -------
    ChunkTransform
        A new transform constructed from this codec's ``codecs`` and
        ``chunk_spec``. A fresh instance is created on every call.
    """
    return ChunkTransform(codecs=self.codecs, array_spec=chunk_spec)
360321

361322
def to_dict(self) -> dict[str, JSON]:
362323
return {
@@ -430,20 +391,15 @@ async def _decode_single(
430391
out.fill(shard_spec.fill_value)
431392
return out
432393

433-
# decoding chunks and writing them into the output buffer
434-
await self.codec_pipeline.read(
435-
[
436-
(
437-
_ShardingByteGetter(shard_dict, chunk_coords),
438-
chunk_spec,
439-
chunk_selection,
440-
out_selection,
441-
is_complete_shard,
442-
)
443-
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
444-
],
445-
out,
446-
)
394+
transform = self._get_chunk_transform(chunk_spec)
395+
fill_value = fill_value_or_default(chunk_spec)
396+
for chunk_coords, chunk_selection, out_selection, _ in indexer:
397+
chunk_bytes = shard_dict.get(chunk_coords)
398+
if chunk_bytes is not None:
399+
chunk_array = await transform.decode_chunk_async(chunk_bytes)
400+
out[out_selection] = chunk_array[chunk_selection]
401+
else:
402+
out[out_selection] = fill_value
447403

448404
return out
449405

@@ -502,20 +458,16 @@ async def _decode_partial_single(
502458
if chunk_bytes:
503459
shard_dict[chunk_coords] = chunk_bytes
504460

505-
# decoding chunks and writing them into the output buffer
506-
await self.codec_pipeline.read(
507-
[
508-
(
509-
_ShardingByteGetter(shard_dict, chunk_coords),
510-
chunk_spec,
511-
chunk_selection,
512-
out_selection,
513-
is_complete_shard,
514-
)
515-
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
516-
],
517-
out,
518-
)
461+
# decode chunks and write them into the output buffer
462+
transform = self._get_chunk_transform(chunk_spec)
463+
fill_value = fill_value_or_default(chunk_spec)
464+
for chunk_coords, chunk_selection, out_selection, _ in indexed_chunks:
465+
chunk_bytes = shard_dict.get(chunk_coords)
466+
if chunk_bytes is not None:
467+
chunk_array = await transform.decode_chunk_async(chunk_bytes)
468+
out[out_selection] = chunk_array[chunk_selection]
469+
else:
470+
out[out_selection] = fill_value
519471

520472
if hasattr(indexer, "sel_shape"):
521473
return out.reshape(indexer.sel_shape)
@@ -532,29 +484,23 @@ async def _encode_single(
532484
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
533485
chunk_spec = self._get_chunk_spec(shard_spec)
534486

535-
indexer = list(
536-
BasicIndexer(
537-
tuple(slice(0, s) for s in shard_shape),
538-
shape=shard_shape,
539-
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
540-
)
487+
indexer = BasicIndexer(
488+
tuple(slice(0, s) for s in shard_shape),
489+
shape=shard_shape,
490+
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
541491
)
542492

543-
shard_builder = dict.fromkeys(morton_order_iter(chunks_per_shard))
493+
transform = self._get_chunk_transform(chunk_spec)
494+
fill_value = fill_value_or_default(chunk_spec)
495+
shard_builder: dict[tuple[int, ...], Buffer | None] = {}
544496

545-
await self.codec_pipeline.write(
546-
[
547-
(
548-
_ShardingByteSetter(shard_builder, chunk_coords),
549-
chunk_spec,
550-
chunk_selection,
551-
out_selection,
552-
is_complete_shard,
553-
)
554-
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
555-
],
556-
shard_array,
557-
)
497+
for chunk_coords, _, out_selection, _is_complete in indexer:
498+
chunk_array = shard_array[out_selection]
499+
if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal(fill_value):
500+
continue
501+
encoded = await transform.encode_chunk_async(chunk_array)
502+
if encoded is not None:
503+
shard_builder[chunk_coords] = encoded
558504

559505
return await self._encode_shard_dict(
560506
shard_builder,
@@ -581,27 +527,47 @@ async def _encode_partial_single(
581527
)
582528
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
583529
# Use vectorized lookup for better performance
584-
shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard)))
530+
shard_dict: dict[tuple[int, ...], Buffer | None] = shard_reader.to_dict_vectorized(
531+
np.asarray(_morton_order(chunks_per_shard))
532+
)
585533

586534
indexer = list(
587535
get_indexer(
588536
selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape)
589537
)
590538
)
591539

592-
await self.codec_pipeline.write(
593-
[
594-
(
595-
_ShardingByteSetter(shard_dict, chunk_coords),
596-
chunk_spec,
597-
chunk_selection,
598-
out_selection,
599-
is_complete_shard,
600-
)
601-
for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer
602-
],
603-
shard_array,
604-
)
540+
transform = self._get_chunk_transform(chunk_spec)
541+
fill_value = fill_value_or_default(chunk_spec)
542+
543+
is_scalar = len(shard_array.shape) == 0
544+
for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer:
545+
value = shard_array if is_scalar else shard_array[out_selection]
546+
if is_complete_chunk and not is_scalar and value.shape == chunk_spec.shape:
547+
# Complete overwrite with matching shape — use value directly
548+
chunk_data = value
549+
else:
550+
# Read-modify-write: decode existing or create new, merge data
551+
if is_complete_chunk:
552+
existing_bytes = None
553+
else:
554+
existing_bytes = shard_dict.get(chunk_coords)
555+
if existing_bytes is not None:
556+
chunk_data = (await transform.decode_chunk_async(existing_bytes)).copy()
557+
else:
558+
chunk_data = chunk_spec.prototype.nd_buffer.create(
559+
shape=chunk_spec.shape,
560+
dtype=chunk_spec.dtype.to_native_dtype(),
561+
order=chunk_spec.order,
562+
fill_value=fill_value,
563+
)
564+
chunk_data[chunk_selection] = value
565+
566+
if not chunk_spec.config.write_empty_chunks and chunk_data.all_equal(fill_value):
567+
shard_dict[chunk_coords] = None
568+
else:
569+
shard_dict[chunk_coords] = await transform.encode_chunk_async(chunk_data)
570+
605571
buf = await self._encode_shard_dict(
606572
shard_dict,
607573
chunks_per_shard=chunks_per_shard,

src/zarr/core/codec_pipeline.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,51 @@ def encode_chunk(
168168

169169
return bb_out # type: ignore[no-any-return]
170170

171+
async def decode_chunk_async(
    self,
    chunk_bytes: Buffer,
) -> NDBuffer:
    """Decode a single chunk through the full codec chain, asynchronously.

    Needed when the codec chain contains async-only codecs (e.g. nested sharding).

    Parameters
    ----------
    chunk_bytes : Buffer
        The encoded bytes of one chunk.

    Returns
    -------
    NDBuffer
        The result of running ``chunk_bytes`` through every codec's
        ``_decode_single`` in reverse chain order.
    """
    # Decoding undoes encoding, so each stage is walked in reverse:
    # first the ``_bb_codecs`` (all called with ``_ab_spec``) ...
    bb_out: Any = chunk_bytes
    for bb_codec in reversed(self._bb_codecs):
        bb_out = await bb_codec._decode_single(bb_out, self._ab_spec)

    # ... then the single ``_ab_codec`` ...
    ab_out: Any = await self._ab_codec._decode_single(bb_out, self._ab_spec)

    # ... and finally each ``(codec, spec)`` pair in ``layers``, using the
    # spec stored alongside that codec.
    for aa_codec, spec in reversed(self.layers):
        ab_out = await aa_codec._decode_single(ab_out, spec)

    return ab_out  # type: ignore[no-any-return]
189+
190+
async def encode_chunk_async(
    self,
    chunk_array: NDBuffer,
) -> Buffer | None:
    """Encode a single chunk through the full codec chain, asynchronously.

    Needed when the codec chain contains async-only codecs (e.g. nested sharding).

    Parameters
    ----------
    chunk_array : NDBuffer
        The chunk data to encode.

    Returns
    -------
    Buffer or None
        The output of the last codec's ``_encode_single``. ``None`` is
        returned as soon as any stage yields ``None``; the remaining codecs
        are then not called.
    """
    aa_out: Any = chunk_array

    # Stage 1: each ``(codec, spec)`` pair in ``layers``, in forward order.
    for aa_codec, spec in self.layers:
        if aa_out is None:
            return None
        aa_out = await aa_codec._encode_single(aa_out, spec)

    # Also covers the case where ``layers`` is empty or its last codec
    # produced ``None``.
    if aa_out is None:
        return None

    # Stage 2: the single ``_ab_codec``.
    bb_out: Any = await self._ab_codec._encode_single(aa_out, self._ab_spec)

    # Stage 3: every codec in ``_bb_codecs``, all called with ``_ab_spec``.
    for bb_codec in self._bb_codecs:
        if bb_out is None:
            return None
        bb_out = await bb_codec._encode_single(bb_out, self._ab_spec)

    return bb_out  # type: ignore[no-any-return]
215+
171216
def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
172217
for codec in self.codecs:
173218
byte_length = codec.compute_encoded_size(byte_length, array_spec)

0 commit comments

Comments
 (0)