From 63b6b854bca6eb4712911b88ce88e587bb90343a Mon Sep 17 00:00:00 2001
From: Stephan Hoyer
Date: Fri, 15 Nov 2024 17:04:19 -0800
Subject: [PATCH] Skip rechunking if source and target chunks are the same

This avoids an unnecessary shuffle.

PiperOrigin-RevId: 697040647
---
 setup.py                         |  2 +-
 xarray_beam/__init__.py          |  2 +-
 xarray_beam/_src/rechunk.py      |  6 ++++++
 xarray_beam/_src/rechunk_test.py | 37 ++++++++++++++++++++------------
 4 files changed, 31 insertions(+), 16 deletions(-)

diff --git a/setup.py b/setup.py
index 01154be..823e6f3 100644
--- a/setup.py
+++ b/setup.py
@@ -41,7 +41,7 @@
 
 setuptools.setup(
     name='xarray-beam',
-    version='0.6.3',
+    version='0.6.4',  # keep in sync with __init__.py
     license='Apache 2.0',
     author='Google LLC',
     author_email='noreply@google.com',
diff --git a/xarray_beam/__init__.py b/xarray_beam/__init__.py
index 791fee1..b2452b9 100644
--- a/xarray_beam/__init__.py
+++ b/xarray_beam/__init__.py
@@ -51,4 +51,4 @@
     DatasetToZarr,
 )
 
-__version__ = '0.6.3'
+__version__ = '0.6.4'  # keep in sync with setup.py
diff --git a/xarray_beam/_src/rechunk.py b/xarray_beam/_src/rechunk.py
index fcee0bc..cb2eaf2 100644
--- a/xarray_beam/_src/rechunk.py
+++ b/xarray_beam/_src/rechunk.py
@@ -547,6 +547,12 @@ def __init__(
     self.dim_sizes = dim_sizes
     self.source_chunks = normalize_chunks(source_chunks, dim_sizes)
     self.target_chunks = normalize_chunks(target_chunks, dim_sizes)
+
+    if self.source_chunks == self.target_chunks:
+      self.stage_in = self.stage_out = []
+      logging.info(f'Rechunk with chunks {self.source_chunks} is a no-op')
+      return
+
     plan = rechunking_plan(
         dim_sizes,
         self.source_chunks,
diff --git a/xarray_beam/_src/rechunk_test.py b/xarray_beam/_src/rechunk_test.py
index 07ca684..6409301 100644
--- a/xarray_beam/_src/rechunk_test.py
+++ b/xarray_beam/_src/rechunk_test.py
@@ -63,7 +63,7 @@ def test_normalize_chunks_errors(self):
 
   def test_rechunking_plan(self):
     # this trivial case fits entirely into memory
-    plan, = rechunk.rechunking_plan(
+    (plan,) = rechunk.rechunking_plan(
         dim_sizes={'x': 10, 'y': 20},
         source_chunks={'x': 1, 'y': 20},
         target_chunks={'x': 10, 'y': 1},
@@ -75,7 +75,7 @@ def test_rechunking_plan(self):
     self.assertEqual(plan, expected)
 
     # this harder case doesn't
-    (read_chunks, _, write_chunks), = rechunk.rechunking_plan(
+    ((read_chunks, _, write_chunks),) = rechunk.rechunking_plan(
         dim_sizes={'t': 1000, 'x': 200, 'y': 300},
         source_chunks={'t': 1, 'x': 200, 'y': 300},
         target_chunks={'t': 1000, 'x': 20, 'y': 20},
@@ -361,15 +361,11 @@ def test_consolidate_with_unchunked_vars(self):
     ]
     with self.assertRaisesRegex(
         ValueError,
-        re.escape(
-            textwrap.dedent(
-                """
+        re.escape(textwrap.dedent("""
             combining nested dataset chunks for vars=None with offsets={'x':
             [0, 10]} failed. Leading datasets along dimension 'x':
-            """
-            ).strip()
-        ),
+            """).strip()),
     ):
       inconsistent_inputs | xbeam.ConsolidateChunks({'x': -1})
 
@@ -449,14 +445,10 @@ def test_consolidate_variables_merge_fails(self):
     ]
     with self.assertRaisesRegex(
         ValueError,
-        re.escape(
-            textwrap.dedent(
-                """
+        re.escape(textwrap.dedent("""
             merging dataset chunks with variables [{'foo'}, {'bar'}]
             failed.
-            """
-            ).strip()
-        ),
+            """).strip()),
     ):
       inputs | xbeam.ConsolidateVariables()
 
@@ -816,6 +808,23 @@ def test_rechunk_inconsistent_dimensions(self):
     )
     self.assertIdenticalChunks(actual, expected)
 
+  def test_rechunk_same_source_and_target_chunks(self):
+    rs = np.random.RandomState(0)
+    ds = xarray.Dataset({'foo': (('x', 'y'), rs.rand(2, 3))})
+    p = test_util.EagerPipeline()
+    inputs = p | xbeam.DatasetToChunks(ds, {'x': 1}, split_vars=True)
+    rechunk_transform = xbeam.Rechunk(
+        dim_sizes=ds.sizes,
+        source_chunks={'x': 1},
+        target_chunks={'x': 1},
+        itemsize=8,
+    )
+    # no rechunk stages
+    self.assertEqual(rechunk_transform.stage_in, [])
+    self.assertEqual(rechunk_transform.stage_out, [])
+    outputs = inputs | rechunk_transform
+    self.assertIdenticalChunks(outputs, inputs)
+
 
 if __name__ == '__main__':
   absltest.main()