From 415e0205892939564f6713755897b63ad1c1bdc8 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Tue, 3 Dec 2024 21:13:51 -0700
Subject: [PATCH 1/2] Add dask caching, avoid rechunk

---
 opendap_protocol/protocol.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/opendap_protocol/protocol.py b/opendap_protocol/protocol.py
index 7cc9614..0a645d4 100644
--- a/opendap_protocol/protocol.py
+++ b/opendap_protocol/protocol.py
@@ -42,6 +42,7 @@
 
 import re
 from dataclasses import dataclass
+from dask.cache import Cache
 
 import dask.array as da
 import numpy as np
@@ -53,6 +54,13 @@
 @dataclass
 class Config:
     DASK_ENCODE_CHUNK_SIZE: int = 20e6
+    DASK_CACHE_SIZE: int = 120 * 1024 * 1024  # 120MB
+
+# we load one `DASK_ENCODE_CHUNK_SIZE`-sized block of linearized data
+# in to memory at one go. This may overlap with multiple dask chunks
+# so lets cache those chunks since we might come back to them.
+cache = Cache(Config.DASK_CACHE_SIZE)
+cache.register()
 
 
 class DAPError(Exception):
@@ -491,8 +499,9 @@ def dods_encode(data, dtype):
         if isinstance(data, da.Array):
             # Encode in chunks of a defined size if we work with dask.Array
             chunk_size = int(Config.DASK_ENCODE_CHUNK_SIZE / data.dtype.itemsize)
-            serialize_data = data.ravel().rechunk(chunk_size)
-            for block in serialize_data.blocks:
+            flat = data.ravel()
+            for start in range(0, data.size, chunk_size):
+                block = flat[slice(start, chunk_size)]
                 yield block.astype(dtype.str).compute().tobytes()
         else:
             # Make sure we always encode an array or we will get wrong results

From 370134e8390fa755f427b222a2523493a22c9ffd Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Wed, 4 Dec 2024 15:49:37 -0700
Subject: [PATCH 2/2] Always stream + support xarray Variables.
---
 opendap_protocol/protocol.py | 41 +++++++++++++++++++-----------------
 setup.py                     |  1 +
 tests/test_all.py            | 11 +++++++---
 3 files changed, 31 insertions(+), 22 deletions(-)

diff --git a/opendap_protocol/protocol.py b/opendap_protocol/protocol.py
index 0a645d4..657395c 100644
--- a/opendap_protocol/protocol.py
+++ b/opendap_protocol/protocol.py
@@ -40,6 +40,7 @@
 clients using the netCDF4 library. PyDAP client libraries are not supported.
 """
 
+import importlib
 import re
 from dataclasses import dataclass
 from dask.cache import Cache
@@ -47,20 +48,18 @@
 import dask.array as da
 import numpy as np
 
+has_xarray = bool(importlib.util.find_spec("xarray"))
+
+if has_xarray:
+    from xarray import Variable
+
 INDENT = '    '
 
 SLICE_CONSTRAINT_RE = r'\[([\d,\W]+)\]$'
 
 @dataclass
 class Config:
-    DASK_ENCODE_CHUNK_SIZE: int = 20e6
-    DASK_CACHE_SIZE: int = 120 * 1024 * 1024  # 120MB
-
-# we load one `DASK_ENCODE_CHUNK_SIZE`-sized block of linearized data
-# in to memory at one go. This may overlap with multiple dask chunks
-# so lets cache those chunks since we might come back to them.
-cache = Cache(Config.DASK_CACHE_SIZE)
-cache.register()
+    STREAMING_BLOCK_SIZE: int = 20_000_000
 
 
 class DAPError(Exception):
@@ -496,17 +495,21 @@ def dods_encode(data, dtype):
 
         yield packed_length
 
+        chunk_size = int(Config.STREAMING_BLOCK_SIZE / getattr(data, "dtype", np.dtype(dtype.str)).itemsize)
         if isinstance(data, da.Array):
-            # Encode in chunks of a defined size if we work with dask.Array
-            chunk_size = int(Config.DASK_ENCODE_CHUNK_SIZE / data.dtype.itemsize)
-            flat = data.ravel()
-            for start in range(0, data.size, chunk_size):
-                block = flat[slice(start, chunk_size)]
-                yield block.astype(dtype.str).compute().tobytes()
-        else:
-            # Make sure we always encode an array or we will get wrong results
-            data = np.asarray(data)
-            yield data.astype(dtype.str).tobytes()
+            data = data.ravel()
+
+        for start in range(0, np.size(data), chunk_size):
+            end = start + chunk_size
+            if isinstance(data, da.Array):
+                block = data[slice(start, end)].compute()
+            elif has_xarray and isinstance(data, Variable):
+                npidxr = np.unravel_index(np.arange(start, min(end, data.size)), shape=data.shape)
+                xridxr = tuple(Variable(dims="__points__", data=idxr) for idxr in npidxr)
+                block = data[xridxr].to_numpy()
+            else:
+                block = np.asarray(data).ravel()[slice(start, end)]
+            yield block.astype(dtype.str).tobytes()
 
 
 def parse_slice_constraint(constraint):
@@ -571,6 +574,6 @@ def set_dask_encoding_chunk_size(chunk_size: int):
     """
     chunk_size = int(chunk_size)
     if chunk_size > 0:
-        Config.DASK_ENCODE_CHUNK_SIZE = chunk_size
+        Config.STREAMING_BLOCK_SIZE = chunk_size
     else:
         raise ValueError('Encoding chunk size needs to be greather than 0.')
diff --git a/setup.py b/setup.py
index 7105303..2585ff5 100644
--- a/setup.py
+++ b/setup.py
@@ -43,6 +43,7 @@
 
 test_requirements = [
     'pytest',
+    'xarray',
 ]
 
 extras = {
diff --git a/tests/test_all.py b/tests/test_all.py
index 726ad03..f4ac1a7 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -32,6 +32,7 @@
 import dask.array as da
 import numpy as np
 import opendap_protocol as dap
+import xarray as xr
 
 import pytest
 
@@ -68,9 +69,13 @@ def test_dods_encode():
 
     data_vals = da.from_array(np_data,
                               chunks=(14, y_dim, 1, vertical_dim, 1, 1))
 
+    variable = xr.Variable(dims=("x", "y", "time", "vertical", "real", "ref_time"),
+                           data=np_data)
+
     x = dap.dods_encode(data_vals, dap.Int32)
     y = dap.dods_encode(np_data, dap.Int32)
-    assert b''.join(x) == b''.join(y)
+    z = dap.dods_encode(variable, dap.Int32)
+    assert b''.join(x) == b''.join(y) == b''.join(z)
 
     int_arrdata = np.arange(0, 20, 2, dtype='