Skip to content

Commit

Permalink
Skip rechunking if source and target chunks are the same
Browse files (browse the repository at this point in the history)
This avoids an unnecessary shuffle.

PiperOrigin-RevId: 697040647
  • Loading branch information
shoyer authored and Xarray-Beam authors committed Nov 18, 2024
1 parent 762228b commit 63b6b85
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 16 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

setuptools.setup(
name='xarray-beam',
version='0.6.3',
version='0.6.4', # keep in sync with __init__.py
license='Apache 2.0',
author='Google LLC',
author_email='[email protected]',
Expand Down
2 changes: 1 addition & 1 deletion xarray_beam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@
DatasetToZarr,
)

__version__ = '0.6.3'
__version__ = '0.6.4' # keep in sync with setup.py
6 changes: 6 additions & 0 deletions xarray_beam/_src/rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,12 @@ def __init__(
self.dim_sizes = dim_sizes
self.source_chunks = normalize_chunks(source_chunks, dim_sizes)
self.target_chunks = normalize_chunks(target_chunks, dim_sizes)

if self.source_chunks == self.target_chunks:
self.stage_in = self.stage_out = []
logging.info(f'Rechunk with chunks {self.source_chunks} is a no-op')
return

plan = rechunking_plan(
dim_sizes,
self.source_chunks,
Expand Down
37 changes: 23 additions & 14 deletions xarray_beam/_src/rechunk_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_normalize_chunks_errors(self):

def test_rechunking_plan(self):
# this trivial case fits entirely into memory
plan, = rechunk.rechunking_plan(
(plan,) = rechunk.rechunking_plan(
dim_sizes={'x': 10, 'y': 20},
source_chunks={'x': 1, 'y': 20},
target_chunks={'x': 10, 'y': 1},
Expand All @@ -75,7 +75,7 @@ def test_rechunking_plan(self):
self.assertEqual(plan, expected)

# this harder case doesn't
(read_chunks, _, write_chunks), = rechunk.rechunking_plan(
((read_chunks, _, write_chunks),) = rechunk.rechunking_plan(
dim_sizes={'t': 1000, 'x': 200, 'y': 300},
source_chunks={'t': 1, 'x': 200, 'y': 300},
target_chunks={'t': 1000, 'x': 20, 'y': 20},
Expand Down Expand Up @@ -361,15 +361,11 @@ def test_consolidate_with_unchunked_vars(self):
]
with self.assertRaisesRegex(
ValueError,
re.escape(
textwrap.dedent(
"""
re.escape(textwrap.dedent("""
combining nested dataset chunks for vars=None with offsets={'x': [0, 10]} failed.
Leading datasets along dimension 'x':
<xarray.Dataset>
"""
).strip()
),
""").strip()),
):
inconsistent_inputs | xbeam.ConsolidateChunks({'x': -1})

Expand Down Expand Up @@ -449,14 +445,10 @@ def test_consolidate_variables_merge_fails(self):
]
with self.assertRaisesRegex(
ValueError,
re.escape(
textwrap.dedent(
"""
re.escape(textwrap.dedent("""
merging dataset chunks with variables [{'foo'}, {'bar'}] failed.
<xarray.Dataset>
"""
).strip()
),
""").strip()),
):
inputs | xbeam.ConsolidateVariables()

Expand Down Expand Up @@ -816,6 +808,23 @@ def test_rechunk_inconsistent_dimensions(self):
)
self.assertIdenticalChunks(actual, expected)

def test_rechunk_same_source_and_target_chunks(self):
  """Rechunk with identical source and target chunks builds no stages.

  Constructs a Rechunk transform whose source and target chunks are both
  {'x': 1}, checks that its stage_in and stage_out are empty lists, and
  checks that applying the transform yields chunks identical to its inputs.
  """
  values = np.random.RandomState(0).rand(2, 3)
  dataset = xarray.Dataset({'foo': (('x', 'y'), values)})
  pipeline = test_util.EagerPipeline()
  source = pipeline | xbeam.DatasetToChunks(
      dataset, {'x': 1}, split_vars=True
  )
  noop_rechunk = xbeam.Rechunk(
      dim_sizes=dataset.sizes,
      source_chunks={'x': 1},
      target_chunks={'x': 1},
      itemsize=8,
  )
  # Neither side of the transform should have any rechunk stages.
  for stages in (noop_rechunk.stage_in, noop_rechunk.stage_out):
    self.assertEqual(stages, [])
  rechunked = source | noop_rechunk
  self.assertIdenticalChunks(rechunked, source)


# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()

0 comments on commit 63b6b85

Please sign in to comment.