Skip to content

Commit 3db6470

Browse files
committed
Added comments and added a test for level validation
1 parent cb85ed0 commit 3db6470

File tree

3 files changed

+84
-35
lines changed

3 files changed

+84
-35
lines changed

reproject/hips/_dask_array.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,19 @@
1+
import functools
12
import os
2-
import struct
33
import urllib
44
import uuid
5-
import functools
65

76
import numpy as np
8-
from astropy import units as u
97
from astropy.io import fits
8+
from astropy.utils.data import download_file
109
from astropy.wcs import WCS
1110
from astropy_healpix import HEALPix, level_to_nside
1211
from dask import array as da
13-
from astropy.utils.data import download_file
14-
from astropy.wcs.utils import celestial_frame_to_wcs
1512

16-
from .utils import is_url, load_properties, tile_filename, tile_header, map_header
1713
from .high_level import VALID_COORD_SYSTEM
14+
from .utils import is_url, load_properties, map_header, tile_filename
1815

19-
__all__ = ['hips_as_dask_and_wcs']
16+
__all__ = ["hips_as_dask_and_wcs"]
2017

2118

2219
class HiPSArray:
@@ -31,6 +28,17 @@ def __init__(self, directory_or_url, level=None):
3128

3229
self._tile_width = int(self._properties["hips_tile_width"])
3330
self._order = int(self._properties["hips_order"])
31+
if level is None:
32+
self._level = self._order
33+
else:
34+
if level > self._order:
35+
raise ValueError(
36+
f"HiPS dataset at {directory_or_url} does not contain level {level} data"
37+
)
38+
elif level < 0:
39+
raise ValueError("level should be positive")
40+
else:
41+
self._level = int(level)
3442
self._level = self._order if level is None else level
3543
self._tile_format = self._properties["hips_tile_format"]
3644
self._frame_str = self._properties["hips_frame"]
@@ -81,8 +89,12 @@ def __getitem__(self, item):
8189
else:
8290
raise NotImplementedError()
8391

84-
if np.all(np.isnan(lon) | np.isnan(lat)):
92+
invalid = np.isnan(lon) | np.isnan(lat)
93+
94+
if np.all(invalid):
8595
return self._nan
96+
elif np.any(invalid):
97+
coord = coord[~invalid]
8698

8799
index = self._hp.skycoord_to_healpix(coord)
88100

@@ -125,9 +137,12 @@ def _get_tile(self, *, level, index):
125137

126138
def hips_as_dask_and_wcs(directory_or_url, *, level=None):
127139
array_wrapper = HiPSArray(directory_or_url, level=level)
128-
return da.from_array(
129-
array_wrapper,
130-
chunks=array_wrapper.chunksize,
131-
name=str(uuid.uuid4()),
132-
meta=np.array([], dtype=float)
133-
), array_wrapper.wcs
140+
return (
141+
da.from_array(
142+
array_wrapper,
143+
chunks=array_wrapper.chunksize,
144+
name=str(uuid.uuid4()),
145+
meta=np.array([], dtype=float),
146+
),
147+
array_wrapper.wcs,
148+
)

reproject/hips/high_level.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,11 @@ def process(index):
253253
array_out2, footprint2 = reproject_function(
254254
(array_in, wcs_in_copy), header[1], **kwargs
255255
)
256-
array_out = (
257-
np.nan_to_num(array_out1) * footprint1 + np.nan_to_num(array_out2) * footprint2
258-
) / (footprint1 + footprint2)
259-
footprint = (footprint1 + footprint2) / 2
256+
with np.errstate(invalid="ignore"):
257+
array_out = (
258+
np.nan_to_num(array_out1) * footprint1 + np.nan_to_num(array_out2) * footprint2
259+
) / (footprint1 + footprint2)
260+
footprint = (footprint1 + footprint2) / 2
260261
header = header[0]
261262
else:
262263
array_out, footprint = reproject_function((array_in, wcs_in_copy), header, **kwargs)
Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,50 @@
1-
import pytest
21
import numpy as np
3-
4-
from astropy.wcs import WCS
2+
import pytest
53
from astropy.io import fits
4+
from astropy.utils.data import get_pkg_data_filename
5+
from astropy.wcs import WCS
66

77
from reproject import reproject_interp
88
from reproject.hips import reproject_to_hips
99
from reproject.hips._dask_array import hips_as_dask_and_wcs
10-
from astropy.utils.data import get_pkg_data_filename
10+
1111

1212
class TestHIPSDaskArray:
1313

1414
def setup_method(self):
15-
16-
hdu = fits.open(get_pkg_data_filename('allsky/allsky_rosat.fits'))[0]
17-
self.original_header = hdu.header
15+
# We use an all-sky WCS image as input since this will test all parts
16+
# of the HiPS projection (some issues happen around boundaries for instance)
17+
hdu = fits.open(get_pkg_data_filename("allsky/allsky_rosat.fits"))[0]
1818
self.original_wcs = WCS(hdu.header)
1919
self.original_array = hdu.data.size + np.arange(hdu.data.size).reshape(hdu.data.shape)
2020

21-
@pytest.mark.parametrize('frame', ('galactic', 'equatorial'))
22-
@pytest.mark.parametrize('level', (0, 1))
21+
@pytest.mark.parametrize("frame", ("galactic", "equatorial"))
22+
@pytest.mark.parametrize("level", (0, 1))
2323
def test_roundtrip(self, tmp_path, frame, level):
2424

25-
self.output_directory = tmp_path / 'roundtrip'
25+
output_directory = tmp_path / "roundtrip"
2626

27+
# Note that we always use level=1 to generate, but use a variable level
28+
# to construct the dask array - this is deliberate and ensure that the
29+
# dask array has a proper separation of maximum and current level.
2730
reproject_to_hips(
2831
(self.original_array, self.original_wcs),
2932
coord_system_out=frame,
30-
level=level,
33+
level=1,
3134
reproject_function=reproject_interp,
32-
output_directory=self.output_directory,
35+
output_directory=output_directory,
36+
tile_size=256,
3337
)
3438

35-
dask_array, wcs = hips_as_dask_and_wcs(self.output_directory, level=level)
39+
# Represent the HiPS as a dask array
40+
dask_array, wcs = hips_as_dask_and_wcs(output_directory, level=level)
3641

37-
final_array, footprint = reproject_interp((dask_array, wcs), self.original_wcs, shape_out=self.original_array.shape)
42+
# Reproject back to the original WCS
43+
final_array, footprint = reproject_interp(
44+
(dask_array, wcs),
45+
self.original_wcs,
46+
shape_out=self.original_array.shape,
47+
)
3848

3949
# FIXME: Due to boundary effects and the fact there are NaN values in
4050
# the whole-map dask array, there are a few NaN pixels in the image in
@@ -47,10 +57,33 @@ def test_roundtrip(self, tmp_path, frame, level):
4757
# values.
4858

4959
valid = ~np.isnan(final_array)
50-
5160
assert np.sum(valid) > 90400
52-
5361
np.testing.assert_allclose(final_array[valid], self.original_array[valid], rtol=0.01)
5462

63+
def test_level_validation(self, tmp_path):
64+
65+
output_directory = tmp_path / "levels"
66+
67+
reproject_to_hips(
68+
(self.original_array, self.original_wcs),
69+
coord_system_out="equatorial",
70+
level=1,
71+
reproject_function=reproject_interp,
72+
output_directory=output_directory,
73+
tile_size=32,
74+
)
75+
76+
dask_array, wcs = hips_as_dask_and_wcs(output_directory, level=0)
77+
assert dask_array.shape == (160, 160)
78+
79+
dask_array, wcs = hips_as_dask_and_wcs(output_directory, level=1)
80+
assert dask_array.shape == (320, 320)
81+
82+
dask_array, wcs = hips_as_dask_and_wcs(output_directory)
83+
assert dask_array.shape == (320, 320)
84+
85+
with pytest.raises(Exception, match=r"does not contain level 2 data"):
86+
hips_as_dask_and_wcs(output_directory, level=2)
5587

56-
# VALIDATE LEVEL
88+
with pytest.raises(Exception, match=r"should be positive"):
89+
hips_as_dask_and_wcs(output_directory, level=-1)

0 commit comments

Comments
 (0)