Initial support for multiple datasets in DatasetToChunks
Here is an initial implementation of google#68.
alxmrs committed Jan 25, 2023
1 parent 95ec55d commit 90aa3db
Showing 1 changed file with 87 additions and 47 deletions.
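
For orientation, here is a minimal usage sketch of the multi-dataset input this commit adds, based on the docstring changes below. The example datasets, the chunk sizes, and the `difference` function are illustrative assumptions rather than part of the commit; only `xarray_beam.DatasetToChunks` itself comes from the library.

import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

# Two datasets with identical sizes, as required by the new validation logic.
ds_a = xarray.Dataset({'foo': (('x',), np.arange(8))})
ds_b = xarray.Dataset({'foo': (('x',), 2.0 * np.arange(8))})

def difference(key, datasets):
  # When a tuple of datasets is passed in, each chunk arrives as
  # (Key, (xarray.Dataset, ...)); a single dataset still yields
  # (Key, xarray.Dataset) pairs as before.
  a, b = datasets
  return key, b - a

with beam.Pipeline() as p:
  _ = (
      p
      | xbeam.DatasetToChunks((ds_a, ds_b), chunks={'x': 4})
      | beam.MapTuple(difference)
  )
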
134 changes: 87 additions & 47 deletions xarray_beam/_src/core.py
@@ -15,15 +15,16 @@
import itertools
import math
from typing import (
AbstractSet,
Dict,
List,
Iterator,
Optional,
Mapping,
Sequence,
Tuple,
Union,
AbstractSet,
Dict,
List,
Iterator,
Optional,
Mapping,
Sequence,
Tuple,
TypeVar,
Union,
)

import apache_beam as beam
@@ -255,12 +256,26 @@ def normalize_expanded_chunks(
return result


def _all_equal(iterator):
"""Check if all values in a collection are equal."""
iterator = iter(iterator)
try:
first = next(iterator)
except StopIteration:
return True
return all(first == x for x in iterator)


T = TypeVar('T')
AtLeastOne = Union[T, Tuple[T, ...]]


class DatasetToChunks(beam.PTransform):
"""Split an xarray.Dataset into keyed chunks."""
"""Split one or more xarray.Datasets into keyed chunks."""

def __init__(
self,
dataset: xarray.Dataset,
dataset: Union[xarray.Dataset, Tuple[xarray.Dataset, ...]],
chunks: Optional[Mapping[str, Union[int, Tuple[int, ...]]]] = None,
split_vars: bool = False,
num_threads: Optional[int] = None,
@@ -269,13 +284,14 @@ def __init__(
"""Initialize DatasetToChunks.
Args:
dataset: dataset to split into (Key, xarray.Dataset) pairs.
dataset: dataset or datasets to split into (Key, xarray.Dataset) or
(Key, (xarray.Dataset, ...)) pairs.
chunks: optional chunking scheme. Required if the dataset is *not* already
chunked. If the dataset *is* already chunked with Dask, `chunks` takes
precedence over the existing chunks.
split_vars: whether to split the dataset into separate records for each
data variable or to keep all data variables together. This is
recommended if you don't need perform joint operations on different
recommended if you don't need to perform joint operations on different
dataset variables and individual variable chunks are sufficiently large.
num_threads: optional number of Dataset chunks to load in parallel per
worker. More threads can increase throughput, but also increases memory
@@ -287,18 +303,30 @@ def __init__(
rather than only on the host process. This is important for scaling
pipelines to millions of tasks.
"""
if type(dataset) is not tuple:
dataset = (dataset,)
elif not dataset:
raise ValueError('dataset tuple cannot be empty!')
if not _all_equal(ds.sizes for ds in dataset):
raise ValueError('all datasets must be the same size')
if split_vars and not _all_equal([(k, v.shape) for k, v in ds.items()]
for ds in dataset):
raise ValueError('when splitting variables, all datasets must have '
'the same data variables with equivalent shapes.')
if chunks is None:
chunks = dataset.chunks
if not _all_equal(ds.chunks for ds in dataset):
raise ValueError('all datasets must have the same chunks or chunks must be provided')
chunks = dataset[0].chunks
if chunks is None:
raise ValueError('dataset must be chunked or chunks must be provided')
expanded_chunks = normalize_expanded_chunks(chunks, dataset.sizes)
expanded_chunks = normalize_expanded_chunks(chunks, dataset[0].sizes)
self.dataset = dataset
self.expanded_chunks = expanded_chunks
self.split_vars = split_vars
self.num_threads = num_threads
self.shard_keys_threshold = shard_keys_threshold
# TODO(shoyer): consider recalculating these potentially large properties on
# each worker, rather than only once on the host.
self.offsets = _chunks_to_offsets(expanded_chunks)
self.offset_index = compute_offset_index(self.offsets)
# We use the simple heuristic of only sharding inputs along the dimension
@@ -313,7 +341,7 @@ def _task_count(self) -> int:
if not self.split_vars:
return int(np.prod(list(counts.values())))
total = 0
for variable in self.dataset.values():
for variable in self._first.values():
count_list = [v for k, v in counts.items() if k in variable.dims]
total += int(np.prod(count_list))
return total
@@ -328,7 +356,7 @@ def _shard_count(self) -> Optional[int]:
return math.ceil(task_count / self.shard_keys_threshold)

var_count = sum(
self.sharded_dim in var.dims for var in self.dataset.values()
self.sharded_dim in var.dims for var in self._first.values()
)
return math.ceil(task_count / (var_count * self.shard_keys_threshold))

@@ -337,7 +365,7 @@ def _iter_all_keys(self) -> Iterator[Key]:
if not self.split_vars:
yield from iter_chunk_keys(self.offsets)
else:
for name, variable in self.dataset.items():
for name, variable in self._first.items():
relevant_offsets = {
k: v for k, v in self.offsets.items() if k in variable.dims
}
@@ -350,7 +378,7 @@ def _iter_shard_keys(
if var_name is None:
offsets = self.offsets
else:
offsets = {dim: self.offsets[dim] for dim in self.dataset[var_name].dims}
offsets = {dim: self.offsets[dim] for dim in self._first[var_name].dims}

if shard_id is None:
assert self.split_vars
@@ -370,26 +398,36 @@ def _shard_inputs(self) -> List[Tuple[Optional[int], Optional[str]]]:
return [(i, None) for i in range(self.shard_count)]

inputs = []
for name, variable in self.dataset.items():
for name, variable in self._first.items():
if self.sharded_dim in variable.dims:
inputs.extend([(i, name) for i in range(self.shard_count)])
else:
inputs.append((None, name))
return inputs

def _key_to_chunks(self, key: Key) -> Iterator[Tuple[Key, xarray.Dataset]]:
def _key_to_chunks(self, key: Key) -> Iterator[Tuple[Key, AtLeastOne[xarray.Dataset]]]:
"""Convert a Key into an in-memory (Key, dataset or datasets) pair."""
sizes = {
dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
for dim, offset in key.offsets.items()
}
slices = offsets_to_slices(key.offsets, sizes)
dataset = self.dataset if key.vars is None else self.dataset[list(key.vars)]
chunk = dataset.isel(slices)
# Load the data, using a separate thread for each variable
num_threads = len(self.dataset)
result = chunk.chunk().compute(num_workers=num_threads)
yield key, result
results = []
for ds in self.dataset:
dataset = ds if key.vars is None else ds[list(key.vars)]
chunk = dataset.isel(slices)
# Load the data, using a separate thread for each variable
num_threads = len(dataset)
result = chunk.chunk().compute(num_workers=num_threads)
results.append(result)
if len(results) == 1:
yield key, results[0]
else:
yield key, tuple(results)

@property
def _first(self) -> xarray.Dataset:
return self.dataset[0]

def expand(self, pcoll):
if self.shard_count is None:
@@ -410,32 +448,34 @@ def expand(self, pcoll):
)


def validate_chunk(key: Key, dataset: xarray.Dataset) -> None:
"""Verify that keys correpond to Dataset properties."""
missing_keys = [repr(k) for k in key.offsets.keys() if k not in dataset.dims]
if missing_keys:
raise ValueError(
f"Key offset(s) {', '.join(missing_keys)} in {key} not found in Dataset"
f' dimensions: {dataset!r}'
)
def validate_chunk(key: Key, *datasets: xarray.Dataset) -> None:
"""Verify that keys correspond to Dataset properties."""
for dataset in datasets:
missing_keys = [repr(k) for k in key.offsets.keys() if k not in dataset.dims]
if missing_keys:
raise ValueError(
f"Key offset(s) {', '.join(missing_keys)} in {key} not found in Dataset"
f' dimensions: {dataset!r}'
)

if key.vars is None:
return
missing_vars = [repr(v) for v in key.vars if v not in dataset.data_vars]
if missing_vars:
raise ValueError(
f"Key var(s) {', '.join(missing_vars)} in {key} not found in Dataset"
f' data variables: {dataset!r}'
)
if key.vars is None:
return

missing_vars = [repr(v) for v in key.vars if v not in dataset.data_vars]
if missing_vars:
raise ValueError(
f"Key var(s) {', '.join(missing_vars)} in {key} not found in Dataset"
f' data variables: {dataset!r}'
)


class ValidateEachChunk(beam.PTransform):
"""Check that keys match the dataset for each key, dataset tuple."""

def _validate(self, key, dataset):
def _validate(self, key, *dataset):
# Other checks may come later...
validate_chunk(key, dataset)
return key, dataset
validate_chunk(key, *dataset)
return key, *dataset

def expand(self, pcoll):
return pcoll | beam.MapTuple(self._validate)
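
As a rough follow-on illustration (not part of this commit), the variadic validate_chunk above can also be called directly; the key and in-memory chunks here are hypothetical placeholders built for the example.

import numpy as np
import xarray
from xarray_beam._src.core import Key, validate_chunk

chunk_a = xarray.Dataset({'foo': (('x',), np.arange(4))})
chunk_b = xarray.Dataset({'foo': (('x',), 2.0 * np.arange(4))})
key = Key({'x': 0})

validate_chunk(key, chunk_a)           # single dataset: unchanged behavior
validate_chunk(key, chunk_a, chunk_b)  # each dataset is checked against the key
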
