Commit dcd09a0

jhamman, andersy005, and pre-commit-ci[bot] authored

Refactor Pytorch dataset (#202)

Co-authored-by: Anderson Banihirwe <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Anderson Banihirwe <[email protected]>
1 parent f86f47b commit dcd09a0

File tree

2 files changed: +128 -25 lines

xbatcher/loaders/torch.py (+49 -17)

@@ -1,5 +1,11 @@
+from __future__ import annotations
+
 from collections.abc import Callable
-from typing import Any
+from types import ModuleType
+
+import xarray as xr
+
+from xbatcher import BatchGenerator
 
 try:
     import torch
@@ -9,6 +15,13 @@
         'install PyTorch to proceed.'
     ) from exc
 
+try:
+    import dask
+except ImportError:
+    dask: ModuleType | None = None  # type: ignore[no-redef]
+
+T_DataArrayOrSet = xr.DataArray | xr.Dataset
+
 # Notes:
 # This module includes two PyTorch datasets.
 #  - The MapDataset provides an indexable interface
@@ -20,13 +33,22 @@
 #  - need to test with additional dataset parameters (e.g. transforms)
 
 
+def to_tensor(xr_obj: T_DataArrayOrSet) -> torch.Tensor:
+    """Convert this DataArray or Dataset to a torch.Tensor"""
+    if isinstance(xr_obj, xr.Dataset):
+        xr_obj = xr_obj.to_array().squeeze(dim='variable')
+    if isinstance(xr_obj, xr.DataArray):
+        xr_obj = xr_obj.data
+    return torch.tensor(xr_obj)
+
+
 class MapDataset(torch.utils.data.Dataset):
     def __init__(
         self,
-        X_generator,
-        y_generator,
-        transform: Callable | None = None,
-        target_transform: Callable | None = None,
+        X_generator: BatchGenerator,
+        y_generator: BatchGenerator | None = None,
+        transform: Callable[[T_DataArrayOrSet], torch.Tensor] = to_tensor,
+        target_transform: Callable[[T_DataArrayOrSet], torch.Tensor] = to_tensor,
     ) -> None:
         """
         PyTorch Dataset adapter for Xbatcher
@@ -35,10 +57,8 @@ def __init__(
         ----------
         X_generator : xbatcher.BatchGenerator
         y_generator : xbatcher.BatchGenerator
-        transform : callable, optional
-            A function/transform that takes in an array and returns a transformed version.
-        target_transform : callable, optional
-            A function/transform that takes in the target and transforms it.
+        transform, target_transform : callable, optional
+            A function/transform that takes in an Xarray object and returns a transformed version in the form of a torch.Tensor.
         """
         self.X_generator = X_generator
         self.y_generator = y_generator
@@ -48,7 +68,7 @@ def __init__(
     def __len__(self) -> int:
         return len(self.X_generator)
 
-    def __getitem__(self, idx) -> tuple[Any, Any]:
+    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
         if torch.is_tensor(idx):
             idx = idx.tolist()
         if len(idx) == 1:
@@ -58,15 +78,27 @@ def __getitem__(self, idx) -> tuple[Any, Any]:
                 f'{type(self).__name__}.__getitem__ currently requires a single integer key'
             )
 
-        X_batch = self.X_generator[idx].torch.to_tensor()
-        y_batch = self.y_generator[idx].torch.to_tensor()
+        # generate batch (or batches)
+        if self.y_generator is not None:
+            X_batch, y_batch = self.X_generator[idx], self.y_generator[idx]
+        else:
+            X_batch, y_batch = self.X_generator[idx], None
+
+        # load batch (or batches) with dask if possible
+        if dask is not None:
+            X_batch, y_batch = dask.compute(X_batch, y_batch)
+
+        # apply transformation(s)
+        X_batch_tensor = self.transform(X_batch)
+        if y_batch is not None:
+            y_batch_tensor = self.target_transform(y_batch)
 
-        if self.transform:
-            X_batch = self.transform(X_batch)
+        assert isinstance(X_batch_tensor, torch.Tensor), self.transform
 
-        if self.target_transform:
-            y_batch = self.target_transform(y_batch)
-        return X_batch, y_batch
+        if y_batch is None:
+            return X_batch_tensor
+        assert isinstance(y_batch_tensor, torch.Tensor)
+        return X_batch_tensor, y_batch_tensor
 
 
 class IterableDataset(torch.utils.data.IterableDataset):
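For context, here is a minimal usage sketch of the refactored MapDataset. The toy dataset and variable names are illustrative, modeled on the ds_xy fixture in the tests that follow; transform and target_transform now default to to_tensor, and when dask is installed each batch is loaded with dask.compute before the transforms are applied.

import numpy as np
import torch
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import MapDataset

# Illustrative toy data, mirroring the ds_xy test fixture
ds = xr.Dataset(
    {
        'x': (['sample', 'feature'], np.random.random((100, 5))),
        'y': (['sample'], np.random.random(100)),
    }
)

# One BatchGenerator per input; y_generator is now optional
x_gen = BatchGenerator(ds['x'], {'sample': 10})
y_gen = BatchGenerator(ds['y'], {'sample': 10})

# transform/target_transform default to to_tensor, so each item is a pair of torch.Tensors
dataset = MapDataset(x_gen, y_gen)
loader = torch.utils.data.DataLoader(dataset, batch_size=None)

for x_batch, y_batch in loader:
    assert x_batch.shape == (10, 5)
    assert isinstance(y_batch, torch.Tensor)

Constructing MapDataset(x_gen) without a y_generator makes __getitem__ return only the input tensor, as exercised by test_map_dataset_without_y below.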

xbatcher/tests/test_torch_loaders.py (+79 -8)

@@ -1,15 +1,41 @@
+from importlib import reload
+
 import numpy as np
 import pytest
 import xarray as xr
 
 from xbatcher import BatchGenerator
-from xbatcher.loaders.torch import IterableDataset, MapDataset
+from xbatcher.loaders.torch import IterableDataset, MapDataset, to_tensor
 
 torch = pytest.importorskip('torch')
 
 
-@pytest.fixture(scope='module')
-def ds_xy():
+def test_import_torch_failure(monkeypatch):
+    import sys
+
+    import xbatcher.loaders
+
+    monkeypatch.setitem(sys.modules, 'torch', None)
+
+    with pytest.raises(ImportError) as excinfo:
+        reload(xbatcher.loaders.torch)
+
+    assert 'install PyTorch to proceed' in str(excinfo.value)
+
+
+def test_import_dask_failure(monkeypatch):
+    import sys
+
+    import xbatcher.loaders
+
+    monkeypatch.setitem(sys.modules, 'dask', None)
+    reload(xbatcher.loaders.torch)
+
+    assert xbatcher.loaders.torch.dask is None
+
+
+@pytest.fixture(scope='module', params=[True, False])
+def ds_xy(request):
     n_samples = 100
     n_features = 5
     ds = xr.Dataset(
@@ -21,17 +47,62 @@ def ds_xy():
             'y': (['sample'], np.random.random(n_samples)),
         },
     )
+
+    if request.param:
+        ds = ds.chunk({'sample': 10})
+
     return ds
 
 
+@pytest.mark.parametrize('x_var', ['x', ['x']])
+def test_map_dataset_without_y(ds_xy, x_var) -> None:
+    x = ds_xy[x_var]
+
+    x_gen = BatchGenerator(x, {'sample': 10})
+
+    dataset = MapDataset(x_gen)
+
+    # test __getitem__
+    x_batch = dataset[0]
+    assert x_batch.shape == (10, 5)  # type: ignore[union-attr]
+    assert isinstance(x_batch, torch.Tensor)
+
+    idx = torch.tensor([0])
+    x_batch = dataset[idx]
+    assert x_batch.shape == (10, 5)
+    assert isinstance(x_batch, torch.Tensor)
+
+    with pytest.raises(NotImplementedError):
+        idx = torch.tensor([0, 1])
+        x_batch = dataset[idx]
+
+    # test __len__
+    assert len(dataset) == len(x_gen)
+
+    # test integration with torch DataLoader
+    loader = torch.utils.data.DataLoader(dataset, batch_size=None)
+
+    for x_batch in loader:
+        assert x_batch.shape == (10, 5)  # type: ignore[union-attr]
+        assert isinstance(x_batch, torch.Tensor)
+
+    # Check that array shape of last item in generator is same as the batch image
+    assert tuple(x_gen[-1].sizes.values()) == x_batch.shape  # type: ignore[union-attr]
+    # Check that array values from last item in generator and batch are the same
+    gen_array = (
+        x_gen[-1].to_array().squeeze() if hasattr(x_gen[-1], 'to_array') else x_gen[-1]
+    )
+    np.testing.assert_array_equal(gen_array, x_batch)  # type: ignore
+
+
 @pytest.mark.parametrize(
     ('x_var', 'y_var'),
     [
         ('x', 'y'),  # xr.DataArray
         (['x'], ['y']),  # xr.Dataset
     ],
 )
-def test_map_dataset(ds_xy, x_var, y_var):
+def test_map_dataset(ds_xy, x_var, y_var) -> None:
     x = ds_xy[x_var]
     y = ds_xy[y_var]
 
@@ -73,7 +144,7 @@ def test_map_dataset(ds_xy, x_var, y_var):
     gen_array = (
         x_gen[-1].to_array().squeeze() if hasattr(x_gen[-1], 'to_array') else x_gen[-1]
    )
-    np.testing.assert_array_equal(gen_array, x_batch)
+    np.testing.assert_array_equal(gen_array, x_batch)  # type: ignore
 
 
 @pytest.mark.parametrize(
@@ -83,18 +154,18 @@ def test_map_dataset(ds_xy, x_var, y_var):
         (['x'], ['y']),  # xr.Dataset
     ],
 )
-def test_map_dataset_with_transform(ds_xy, x_var, y_var):
+def test_map_dataset_with_transform(ds_xy, x_var, y_var) -> None:
     x = ds_xy[x_var]
     y = ds_xy[y_var]
 
     x_gen = BatchGenerator(x, {'sample': 10})
     y_gen = BatchGenerator(y, {'sample': 10})
 
     def x_transform(batch):
-        return batch * 0 + 1
+        return to_tensor(batch * 0 + 1)
 
     def y_transform(batch):
-        return batch * 0 - 1
+        return to_tensor(batch * 0 - 1)
 
     dataset = MapDataset(
         x_gen, y_gen, transform=x_transform, target_transform=y_transform
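A short sketch of supplying custom transforms, mirroring test_map_dataset_with_transform above: after this refactor a transform receives the xarray batch and is expected to return a torch.Tensor itself, for example by finishing with to_tensor. The toy data and the normalize helper below are illustrative only.

import numpy as np
import torch
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import MapDataset, to_tensor

# Same illustrative toy data as the earlier sketch
ds = xr.Dataset(
    {
        'x': (['sample', 'feature'], np.random.random((100, 5))),
        'y': (['sample'], np.random.random(100)),
    }
)

x_gen = BatchGenerator(ds['x'], {'sample': 10})
y_gen = BatchGenerator(ds['y'], {'sample': 10})


def normalize(batch):
    # The transform gets the xarray batch and must return a torch.Tensor
    return to_tensor((batch - batch.mean()) / batch.std())


dataset = MapDataset(x_gen, y_gen, transform=normalize, target_transform=to_tensor)

x_batch, y_batch = dataset[0]
assert isinstance(x_batch, torch.Tensor)
assert isinstance(y_batch, torch.Tensor)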
