1
+ from importlib import reload
2
+
1
3
import numpy as np
2
4
import pytest
3
5
import xarray as xr
4
6
5
7
from xbatcher import BatchGenerator
6
- from xbatcher .loaders .torch import IterableDataset , MapDataset
8
+ from xbatcher .loaders .torch import IterableDataset , MapDataset , to_tensor
7
9
8
10
torch = pytest .importorskip ('torch' )
9
11
10
12
11
- @pytest .fixture (scope = 'module' )
12
- def ds_xy ():
13
def test_import_torch_failure(monkeypatch):
    """Reloading the torch loader without PyTorch must fail with a clear hint.

    PyTorch itself is installed in this test session, so its absence is
    simulated by masking the ``torch`` entry in ``sys.modules``.
    """
    import sys

    import xbatcher.loaders

    loader_module = xbatcher.loaders.torch

    # A ``None`` entry in ``sys.modules`` makes ``import torch`` raise ImportError.
    monkeypatch.setitem(sys.modules, 'torch', None)

    # The error raised by the reload has to tell the user what to install.
    with pytest.raises(ImportError, match='install PyTorch to proceed'):
        reload(loader_module)
24
+
25
+
26
def test_import_dask_failure(monkeypatch):
    """Check that the torch loader sets ``dask`` to ``None`` when dask is missing.

    The absence of dask is simulated by masking its ``sys.modules`` entry and
    reloading the loader module. Because ``reload`` re-executes the module in
    its existing namespace, the ``dask = None`` binding would otherwise outlive
    this test: ``monkeypatch`` only restores ``sys.modules``, not the module's
    globals. The loader is therefore reloaded once more after the patch has
    been undone, so later tests see the module in its normal state.
    """
    import sys

    import xbatcher.loaders

    try:
        # Scope the patch so that it is undone before the restoring reload.
        with monkeypatch.context() as patch:
            # A ``None`` entry in ``sys.modules`` makes ``import dask`` fail.
            patch.setitem(sys.modules, 'dask', None)
            reload(xbatcher.loaders.torch)

            assert xbatcher.loaders.torch.dask is None
    finally:
        # Re-execute the module against the real ``sys.modules`` so its
        # ``dask`` global does not stay ``None`` for the rest of the session.
        reload(xbatcher.loaders.torch)
35
+
36
+
37
+ @pytest .fixture (scope = 'module' , params = [True , False ])
38
+ def ds_xy (request ):
13
39
n_samples = 100
14
40
n_features = 5
15
41
ds = xr .Dataset (
@@ -21,17 +47,62 @@ def ds_xy():
21
47
'y' : (['sample' ], np .random .random (n_samples )),
22
48
},
23
49
)
50
+
51
+ if request .param :
52
+ ds = ds .chunk ({'sample' : 10 })
53
+
24
54
return ds
25
55
26
56
57
@pytest.mark.parametrize('x_var', ['x', ['x']])
def test_map_dataset_without_y(ds_xy, x_var) -> None:
    """Exercise ``MapDataset`` when only an X generator is supplied.

    Covers item access by integer and by one-element tensor, rejection of
    multi-element tensor indices, ``len``, and iteration through a torch
    ``DataLoader``. ``x_var`` selects either a DataArray (``'x'``) or a
    Dataset (``['x']``) from the fixture.
    """
    expected_shape = (10, 5)

    x_gen = BatchGenerator(ds_xy[x_var], {'sample': 10})
    dataset = MapDataset(x_gen)

    # __getitem__ accepts a plain integer as well as a one-element tensor
    for index in (0, torch.tensor([0])):
        x_batch = dataset[index]
        assert x_batch.shape == expected_shape  # type: ignore[union-attr]
        assert isinstance(x_batch, torch.Tensor)

    # ... but an index tensor holding several positions is not supported
    with pytest.raises(NotImplementedError):
        x_batch = dataset[torch.tensor([0, 1])]

    # __len__ mirrors the underlying generator
    assert len(dataset) == len(x_gen)

    # integration with torch DataLoader (batch_size=None: no extra collation)
    loader = torch.utils.data.DataLoader(dataset, batch_size=None)
    for x_batch in loader:
        assert x_batch.shape == expected_shape  # type: ignore[union-attr]
        assert isinstance(x_batch, torch.Tensor)

    # After the loop ``x_batch`` is the final batch yielded by the loader;
    # compare it against the final item of the generator.
    last_item = x_gen[-1]

    # Same shape ...
    assert tuple(last_item.sizes.values()) == x_batch.shape  # type: ignore[union-attr]

    # ... and same values (a Dataset is first collapsed into a single array)
    if hasattr(last_item, 'to_array'):
        gen_array = last_item.to_array().squeeze()
    else:
        gen_array = last_item
    np.testing.assert_array_equal(gen_array, x_batch)  # type: ignore
96
+
97
+
27
98
@pytest .mark .parametrize (
28
99
('x_var' , 'y_var' ),
29
100
[
30
101
('x' , 'y' ), # xr.DataArray
31
102
(['x' ], ['y' ]), # xr.Dataset
32
103
],
33
104
)
34
- def test_map_dataset (ds_xy , x_var , y_var ):
105
+ def test_map_dataset (ds_xy , x_var , y_var ) -> None :
35
106
x = ds_xy [x_var ]
36
107
y = ds_xy [y_var ]
37
108
@@ -73,7 +144,7 @@ def test_map_dataset(ds_xy, x_var, y_var):
73
144
gen_array = (
74
145
x_gen [- 1 ].to_array ().squeeze () if hasattr (x_gen [- 1 ], 'to_array' ) else x_gen [- 1 ]
75
146
)
76
- np .testing .assert_array_equal (gen_array , x_batch )
147
+ np .testing .assert_array_equal (gen_array , x_batch ) # type: ignore
77
148
78
149
79
150
@pytest .mark .parametrize (
@@ -83,18 +154,18 @@ def test_map_dataset(ds_xy, x_var, y_var):
83
154
(['x' ], ['y' ]), # xr.Dataset
84
155
],
85
156
)
86
- def test_map_dataset_with_transform (ds_xy , x_var , y_var ):
157
+ def test_map_dataset_with_transform (ds_xy , x_var , y_var ) -> None :
87
158
x = ds_xy [x_var ]
88
159
y = ds_xy [y_var ]
89
160
90
161
x_gen = BatchGenerator (x , {'sample' : 10 })
91
162
y_gen = BatchGenerator (y , {'sample' : 10 })
92
163
93
164
def x_transform (batch ):
94
- return batch * 0 + 1
165
+ return to_tensor ( batch * 0 + 1 )
95
166
96
167
def y_transform (batch ):
97
- return batch * 0 - 1
168
+ return to_tensor ( batch * 0 - 1 )
98
169
99
170
dataset = MapDataset (
100
171
x_gen , y_gen , transform = x_transform , target_transform = y_transform
0 commit comments