11from __future__ import annotations
22
3- from collections .abc import Iterable , Mapping , MutableMapping
3+ from collections .abc import Iterable , Mapping
44from dataclasses import dataclass , replace
55from enum import Enum
66from functools import lru_cache
1515 ArrayBytesCodecPartialDecodeMixin ,
1616 ArrayBytesCodecPartialEncodeMixin ,
1717 Codec ,
18- CodecPipeline ,
1918)
2019from zarr .abc .store import (
2120 ByteGetter ,
22- ByteRequest ,
2321 ByteSetter ,
2422 RangeByteRequest ,
2523 SuffixByteRequest ,
3533 numpy_buffer_prototype ,
3634)
3735from zarr .core .chunk_grids import ChunkGrid , RegularChunkGrid
36+ from zarr .core .codec_pipeline import ChunkTransform , fill_value_or_default
3837from zarr .core .common import (
3938 ShapeLike ,
4039 parse_enum ,
5453)
5554from zarr .core .metadata .v3 import parse_codecs
5655from zarr .registry import get_ndbuffer_class , get_pipeline_class
57- from zarr .storage ._utils import _normalize_byte_range_index
5856
5957if TYPE_CHECKING :
6058 from collections .abc import Iterator
6563
6664MAX_UINT_64 = 2 ** 64 - 1
6765ShardMapping = Mapping [tuple [int , ...], Buffer | None ]
68- ShardMutableMapping = MutableMapping [tuple [int , ...], Buffer | None ]
6966
7067
7168class ShardingCodecIndexLocation (Enum ):
@@ -81,41 +78,6 @@ def parse_index_location(data: object) -> ShardingCodecIndexLocation:
8178 return parse_enum (data , ShardingCodecIndexLocation )
8279
8380
@dataclass(frozen=True)
class _ShardingByteGetter(ByteGetter):
    """Adapts one chunk slot of an in-memory shard mapping to the ``ByteGetter`` interface.

    Lookups are keyed by ``chunk_coords`` into ``shard_dict``; a missing or
    ``None`` entry means the chunk is absent from the shard.
    """

    shard_dict: ShardMapping
    chunk_coords: tuple[int, ...]

    async def get(
        self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
    ) -> Buffer | None:
        # Only the default buffer prototype is supported inside shards.
        assert prototype == default_buffer_prototype(), (
            f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}"
        )
        data = self.shard_dict.get(self.chunk_coords)
        # Absent chunk, or whole-chunk request: return as-is.
        if data is None or byte_range is None:
            return data
        # Partial read: translate the byte-range request into a slice.
        start, stop = _normalize_byte_range_index(data, byte_range)
        return data[start:stop]
102-
103-
@dataclass(frozen=True)
class _ShardingByteSetter(_ShardingByteGetter, ByteSetter):
    """Writable counterpart of ``_ShardingByteGetter``: stores chunk bytes into a mutable shard dict."""

    shard_dict: ShardMutableMapping

    async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None:
        # Partial (ranged) writes are not supported inside a shard.
        assert byte_range is None, "byte_range is not supported within shards"
        self.shard_dict[self.chunk_coords] = value

    async def delete(self) -> None:
        # Raises KeyError if the chunk slot is absent, matching `del`.
        self.shard_dict.pop(self.chunk_coords)

    async def set_if_not_exists(self, default: Buffer) -> None:
        # Only write when no entry exists yet for this chunk.
        if self.chunk_coords not in self.shard_dict:
            self.shard_dict[self.chunk_coords] = default
117-
118-
11981class _ShardIndex (NamedTuple ):
12082 # dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)
12183 offsets_and_lengths : npt .NDArray [np .uint64 ]
@@ -354,9 +316,8 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
354316 _ , configuration_parsed = parse_named_configuration (data , "sharding_indexed" )
355317 return cls (** configuration_parsed ) # type: ignore[arg-type]
356318
357- @property
358- def codec_pipeline (self ) -> CodecPipeline :
359- return get_pipeline_class ().from_codecs (self .codecs )
    def _get_chunk_transform(self, chunk_spec: ArraySpec) -> ChunkTransform:
        """Return a ``ChunkTransform`` applying this codec's inner ``codecs`` chain to one inner chunk.

        Parameters
        ----------
        chunk_spec : ArraySpec
            Spec (shape/dtype/etc.) of a single inner chunk of the shard.
        """
        return ChunkTransform(codecs=self.codecs, array_spec=chunk_spec)
360321
361322 def to_dict (self ) -> dict [str , JSON ]:
362323 return {
@@ -430,20 +391,15 @@ async def _decode_single(
430391 out .fill (shard_spec .fill_value )
431392 return out
432393
433- # decoding chunks and writing them into the output buffer
434- await self .codec_pipeline .read (
435- [
436- (
437- _ShardingByteGetter (shard_dict , chunk_coords ),
438- chunk_spec ,
439- chunk_selection ,
440- out_selection ,
441- is_complete_shard ,
442- )
443- for chunk_coords , chunk_selection , out_selection , is_complete_shard in indexer
444- ],
445- out ,
446- )
394+ transform = self ._get_chunk_transform (chunk_spec )
395+ fill_value = fill_value_or_default (chunk_spec )
396+ for chunk_coords , chunk_selection , out_selection , _ in indexer :
397+ chunk_bytes = shard_dict .get (chunk_coords )
398+ if chunk_bytes is not None :
399+ chunk_array = await transform .decode_chunk_async (chunk_bytes )
400+ out [out_selection ] = chunk_array [chunk_selection ]
401+ else :
402+ out [out_selection ] = fill_value
447403
448404 return out
449405
@@ -502,20 +458,16 @@ async def _decode_partial_single(
502458 if chunk_bytes :
503459 shard_dict [chunk_coords ] = chunk_bytes
504460
505- # decoding chunks and writing them into the output buffer
506- await self .codec_pipeline .read (
507- [
508- (
509- _ShardingByteGetter (shard_dict , chunk_coords ),
510- chunk_spec ,
511- chunk_selection ,
512- out_selection ,
513- is_complete_shard ,
514- )
515- for chunk_coords , chunk_selection , out_selection , is_complete_shard in indexer
516- ],
517- out ,
518- )
461+ # decode chunks and write them into the output buffer
462+ transform = self ._get_chunk_transform (chunk_spec )
463+ fill_value = fill_value_or_default (chunk_spec )
464+ for chunk_coords , chunk_selection , out_selection , _ in indexed_chunks :
465+ chunk_bytes = shard_dict .get (chunk_coords )
466+ if chunk_bytes is not None :
467+ chunk_array = await transform .decode_chunk_async (chunk_bytes )
468+ out [out_selection ] = chunk_array [chunk_selection ]
469+ else :
470+ out [out_selection ] = fill_value
519471
520472 if hasattr (indexer , "sel_shape" ):
521473 return out .reshape (indexer .sel_shape )
@@ -532,29 +484,23 @@ async def _encode_single(
532484 chunks_per_shard = self ._get_chunks_per_shard (shard_spec )
533485 chunk_spec = self ._get_chunk_spec (shard_spec )
534486
535- indexer = list (
536- BasicIndexer (
537- tuple (slice (0 , s ) for s in shard_shape ),
538- shape = shard_shape ,
539- chunk_grid = RegularChunkGrid (chunk_shape = chunk_shape ),
540- )
487+ indexer = BasicIndexer (
488+ tuple (slice (0 , s ) for s in shard_shape ),
489+ shape = shard_shape ,
490+ chunk_grid = RegularChunkGrid (chunk_shape = chunk_shape ),
541491 )
542492
543- shard_builder = dict .fromkeys (morton_order_iter (chunks_per_shard ))
493+ transform = self ._get_chunk_transform (chunk_spec )
494+ fill_value = fill_value_or_default (chunk_spec )
495+ shard_builder : dict [tuple [int , ...], Buffer | None ] = {}
544496
545- await self .codec_pipeline .write (
546- [
547- (
548- _ShardingByteSetter (shard_builder , chunk_coords ),
549- chunk_spec ,
550- chunk_selection ,
551- out_selection ,
552- is_complete_shard ,
553- )
554- for chunk_coords , chunk_selection , out_selection , is_complete_shard in indexer
555- ],
556- shard_array ,
557- )
497+ for chunk_coords , _ , out_selection , _is_complete in indexer :
498+ chunk_array = shard_array [out_selection ]
499+ if not chunk_spec .config .write_empty_chunks and chunk_array .all_equal (fill_value ):
500+ continue
501+ encoded = await transform .encode_chunk_async (chunk_array )
502+ if encoded is not None :
503+ shard_builder [chunk_coords ] = encoded
558504
559505 return await self ._encode_shard_dict (
560506 shard_builder ,
@@ -581,27 +527,47 @@ async def _encode_partial_single(
581527 )
582528 shard_reader = shard_reader or _ShardReader .create_empty (chunks_per_shard )
583529 # Use vectorized lookup for better performance
584- shard_dict = shard_reader .to_dict_vectorized (np .asarray (_morton_order (chunks_per_shard )))
530+ shard_dict : dict [tuple [int , ...], Buffer | None ] = shard_reader .to_dict_vectorized (
531+ np .asarray (_morton_order (chunks_per_shard ))
532+ )
585533
586534 indexer = list (
587535 get_indexer (
588536 selection , shape = shard_shape , chunk_grid = RegularChunkGrid (chunk_shape = chunk_shape )
589537 )
590538 )
591539
592- await self .codec_pipeline .write (
593- [
594- (
595- _ShardingByteSetter (shard_dict , chunk_coords ),
596- chunk_spec ,
597- chunk_selection ,
598- out_selection ,
599- is_complete_shard ,
600- )
601- for chunk_coords , chunk_selection , out_selection , is_complete_shard in indexer
602- ],
603- shard_array ,
604- )
540+ transform = self ._get_chunk_transform (chunk_spec )
541+ fill_value = fill_value_or_default (chunk_spec )
542+
543+ is_scalar = len (shard_array .shape ) == 0
544+ for chunk_coords , chunk_selection , out_selection , is_complete_chunk in indexer :
545+ value = shard_array if is_scalar else shard_array [out_selection ]
546+ if is_complete_chunk and not is_scalar and value .shape == chunk_spec .shape :
547+ # Complete overwrite with matching shape — use value directly
548+ chunk_data = value
549+ else :
550+ # Read-modify-write: decode existing or create new, merge data
551+ if is_complete_chunk :
552+ existing_bytes = None
553+ else :
554+ existing_bytes = shard_dict .get (chunk_coords )
555+ if existing_bytes is not None :
556+ chunk_data = (await transform .decode_chunk_async (existing_bytes )).copy ()
557+ else :
558+ chunk_data = chunk_spec .prototype .nd_buffer .create (
559+ shape = chunk_spec .shape ,
560+ dtype = chunk_spec .dtype .to_native_dtype (),
561+ order = chunk_spec .order ,
562+ fill_value = fill_value ,
563+ )
564+ chunk_data [chunk_selection ] = value
565+
566+ if not chunk_spec .config .write_empty_chunks and chunk_data .all_equal (fill_value ):
567+ shard_dict [chunk_coords ] = None
568+ else :
569+ shard_dict [chunk_coords ] = await transform .encode_chunk_async (chunk_data )
570+
605571 buf = await self ._encode_shard_dict (
606572 shard_dict ,
607573 chunks_per_shard = chunks_per_shard ,
0 commit comments