Skip to content

Commit e711ff7

Browse files
authored
Merge pull request #554 from astrofrog/hips2d-dask-array
Added 2-d HiPS dask array class and fix a number of bugs with HiPS generation
2 parents eedfc6c + fdc25ea commit e711ff7

File tree

6 files changed

+397
-30
lines changed

6 files changed

+397
-30
lines changed

reproject/hips/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .high_level import * # noqa
2+
from ._dask_array import hips_as_dask_array # noqa

reproject/hips/_dask_array.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import functools
2+
import os
3+
import urllib
4+
import uuid
5+
6+
import numpy as np
7+
from astropy.io import fits
8+
from astropy.utils.data import download_file
9+
from astropy.wcs import WCS
10+
from astropy_healpix import HEALPix, level_to_nside
11+
from dask import array as da
12+
13+
from .high_level import VALID_COORD_SYSTEM
14+
from .utils import is_url, load_properties, map_header, tile_filename
15+
16+
__all__ = ["hips_as_dask_array"]
17+
18+
19+
class HiPSArray:
    """
    Array-like wrapper exposing one level of a HiPS dataset as a 2-d image.

    The object provides ``shape``, ``dtype``, ``ndim`` and ``__getitem__`` so
    that it can be handed to ``dask.array.from_array`` (see
    ``hips_as_dask_array``), with one chunk per HiPS tile.

    Parameters
    ----------
    directory_or_url : str
        Location of the HiPS dataset; ``is_url`` decides whether it is
        treated as a URL or as a local directory.
    level : int, optional
        HiPS level to expose. Defaults to the ``hips_order`` value read from
        the dataset properties.

    Raises
    ------
    ValueError
        If ``level`` is larger than the dataset's ``hips_order`` or negative.
    """

    def __init__(self, directory_or_url, level=None):

        self._directory_or_url = directory_or_url

        self._is_url = is_url(directory_or_url)

        self._properties = load_properties(directory_or_url)

        self._tile_width = int(self._properties["hips_tile_width"])
        self._order = int(self._properties["hips_order"])

        # Resolve and validate the level exactly once; the validated value is
        # stored as an int and must not be overwritten afterwards.
        if level is None:
            self._level = self._order
        elif level > self._order:
            raise ValueError(
                f"HiPS dataset at {directory_or_url} does not contain level {level} data"
            )
        elif level < 0:
            raise ValueError("level should be positive")
        else:
            self._level = int(level)

        self._tile_format = self._properties["hips_tile_format"]
        self._frame_str = self._properties["hips_frame"]
        self._frame = VALID_COORD_SYSTEM[self._frame_str]

        self._hp = HEALPix(nside=level_to_nside(self._level), frame=self._frame, order="nested")

        self._header = map_header(level=self._level, frame=self._frame, tile_size=self._tile_width)

        self.wcs = WCS(self._header)
        self.shape = self.wcs.array_shape

        self.dtype = float
        self.ndim = 2

        # One chunk corresponds to one tile
        self.chunksize = (self._tile_width, self._tile_width)

        # Shared all-NaN tile returned whenever a chunk has no data
        self._nan = np.full(self.chunksize, np.nan, dtype=self.dtype)

        # Read-only NaN view with the full array shape, used to answer
        # zero-size slice requests without allocating anything
        self._blank = np.broadcast_to(np.nan, self.shape)

    def __getitem__(self, item):
        """
        Return the data for the chunk selected by ``item``.

        ``item`` is indexed as a pair of slices (rows, columns) with explicit
        ``start`` and ``stop`` values -- this is assumed to be what dask
        passes for a chunk request. The chunk is mapped to a single HEALPix
        tile, and an all-NaN tile is returned when no tile can be identified.
        """

        # Zero-size requests are answered from the blank view
        if item[0].start == item[0].stop or item[1].start == item[1].stop:
            return self._blank[item]

        # We use two points in different parts of the image because in some
        # cases using the exact center or corners can cause issues.

        istart = item[0].start
        irange = item[0].stop - item[0].start
        imid = np.array([istart + 0.25 * irange, istart + 0.75 * irange])

        jstart = item[1].start
        jrange = item[1].stop - item[1].start
        jmid = np.array([jstart + 0.25 * jrange, jstart + 0.75 * jrange])

        # Convert pixel coordinates to HEALPix indices

        coord = self.wcs.pixel_to_world(jmid, imid)

        if self._frame_str == "equatorial":
            lon, lat = coord.ra.deg, coord.dec.deg
        elif self._frame_str == "galactic":
            lon, lat = coord.l.deg, coord.b.deg
        else:
            raise NotImplementedError()

        # Sample points for which the WCS returned NaN coordinates are dropped
        invalid = np.isnan(lon) | np.isnan(lat)

        if np.all(invalid):
            return self._nan
        elif np.any(invalid):
            coord = coord[~invalid]

        index = self._hp.skycoord_to_healpix(coord)

        if np.all(index == -1):
            return self._nan

        # Taking the maximum discards any remaining -1 entries
        index = np.max(index)

        return self._get_tile(level=self._level, index=index)

    @functools.lru_cache(maxsize=128)  # noqa: B019
    def _get_tile(self, *, level, index):
        """
        Load the FITS tile ``index`` at ``level``, or an all-NaN tile if missing.

        Results are memoised by ``lru_cache``, so repeated requests return the
        same array object; callers should treat it as read-only.
        """

        # Imported here so that ``urllib.error`` is guaranteed to be loaded
        # rather than relying on another module having imported it.
        import urllib.error

        filename_or_url = tile_filename(
            level=level,
            index=index,
            output_directory=self._directory_or_url,
            extension="fits",
        )

        if self._is_url:
            try:
                filename = download_file(filename_or_url, cache=True)
            except urllib.error.HTTPError:
                return self._nan
        elif not os.path.exists(filename_or_url):
            return self._nan
        else:
            filename = filename_or_url

        with fits.open(filename) as hdulist:
            hdu = hdulist[0]
            data = hdu.data

        return data
131+
132+
133+
def hips_as_dask_array(directory_or_url, *, level=None):
    """
    Build a dask array, plus matching WCS, for one level of a HiPS dataset.

    Parameters
    ----------
    directory_or_url : str
        Location of the HiPS dataset, forwarded unchanged to ``HiPSArray``.
    level : int, optional
        HiPS level to expose, forwarded unchanged to ``HiPSArray``.

    Returns
    -------
    tuple
        ``(array, wcs)`` where ``array`` is the dask array wrapping the
        dataset and ``wcs`` is the ``wcs`` attribute of the wrapper.
    """
    wrapper = HiPSArray(directory_or_url, level=level)

    # Each call gets a fresh random name for the dask array -- presumably so
    # that dask does not try to derive a name from the wrapper object itself.
    unique_name = str(uuid.uuid4())

    lazy_array = da.from_array(
        wrapper,
        chunks=wrapper.chunksize,
        name=unique_name,
        meta=np.array([], dtype=float),
    )

    return lazy_array, wrapper.wcs

reproject/hips/high_level.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path
99

1010
import numpy as np
11+
from astropy import units as u
1112
from astropy.coordinates import ICRS, BarycentricTrueEcliptic, Galactic
1213
from astropy.io import fits
1314
from astropy.nddata import block_reduce
@@ -22,7 +23,9 @@
2223
from ..utils import as_transparent_rgb, is_jpeg, is_png, parse_input_data
2324
from ..wcs_utils import has_celestial, pixel_scale
2425
from .utils import (
26+
load_properties,
2527
make_tile_folders,
28+
save_properties,
2629
tile_filename,
2730
tile_header,
2831
)
@@ -202,6 +205,8 @@ def reproject_to_hips(
202205
# Determine center of image and radius to furthest corner, to determine
203206
# which HiPS tiles need to be generated
204207

208+
# TODO: this will fail for e.g. allsky maps
209+
205210
ny, nx = array_in.shape[-2:]
206211

207212
cen_x, cen_y = (nx - 1) / 2, (ny - 1) / 2
@@ -212,7 +217,30 @@ def reproject_to_hips(
212217
cen_world = wcs_in.pixel_to_world(cen_x, cen_y)
213218
cor_world = wcs_in.pixel_to_world(cor_x, cor_y)
214219

215-
radius = cor_world.separation(cen_world).max()
220+
separations = cor_world.separation(cen_world)
221+
222+
if np.any(np.isnan(separations)):
223+
224+
# At least one of the corners is outside of the region of validity of
225+
# the WCS, so we use a different approach where we randomly sample a
226+
# number of positions in the image and then check the maximum
227+
# separation between any pair of points.
228+
229+
n_ran = 1000
230+
ran_x = np.random.uniform(-0.5, nx - 0.5, n_ran)
231+
ran_y = np.random.uniform(-0.5, nx - 0.5, n_ran)
232+
233+
ran_world = wcs_in.pixel_to_world(ran_x, ran_y)
234+
235+
separations = ran_world[:, None].separation(ran_world[None, :])
236+
237+
max_separation = np.nanmax(separations)
238+
239+
else:
240+
241+
max_separation = separations.max()
242+
243+
radius = 1.5 * max_separation
216244

217245
# TODO: in future if astropy-healpix implements polygon searches, we could
218246
# use that instead
@@ -222,7 +250,10 @@ def reproject_to_hips(
222250
nside = level_to_nside(level)
223251
hp = HEALPix(nside=nside, order="nested", frame=frame)
224252

225-
indices = hp.cone_search_skycoord(cen_world, radius=radius)
253+
if radius > 120 * u.deg:
254+
indices = np.arange(hp.npix)
255+
else:
256+
indices = hp.cone_search_skycoord(cen_world, radius=radius)
226257

227258
logger.info(f"Found {len(indices)} tiles (at most) to generate at level {level}")
228259

@@ -234,12 +265,29 @@ def reproject_to_hips(
234265

235266
# Iterate over the tiles and generate them
236267
def process(index):
237-
header = tile_header(level=level, index=index, frame=frame, tile_size=tile_size)
238268
if hasattr(wcs_in, "deepcopy"):
239269
wcs_in_copy = wcs_in.deepcopy()
240270
else:
241271
wcs_in_copy = deepcopy(wcs_in)
242-
array_out, footprint = reproject_function((array_in, wcs_in_copy), header, **kwargs)
272+
273+
header = tile_header(level=level, index=index, frame=frame, tile_size=tile_size)
274+
275+
if isinstance(header, tuple):
276+
array_out1, footprint1 = reproject_function(
277+
(array_in, wcs_in_copy), header[0], **kwargs
278+
)
279+
array_out2, footprint2 = reproject_function(
280+
(array_in, wcs_in_copy), header[1], **kwargs
281+
)
282+
with np.errstate(invalid="ignore"):
283+
array_out = (
284+
np.nan_to_num(array_out1) * footprint1 + np.nan_to_num(array_out2) * footprint2
285+
) / (footprint1 + footprint2)
286+
footprint = (footprint1 + footprint2) / 2
287+
header = header[0]
288+
else:
289+
array_out, footprint = reproject_function((array_in, wcs_in_copy), header, **kwargs)
290+
243291
if tile_format != "png":
244292
array_out[np.isnan(array_out)] = 0.0
245293
if np.all(footprint == 0):
@@ -253,6 +301,7 @@ def process(index):
253301
extension=EXTENSION[tile_format],
254302
),
255303
array_out,
304+
header,
256305
)
257306
else:
258307
if tile_format == "png":
@@ -288,6 +337,9 @@ def process(index):
288337
indices = np.array(generated_indices)
289338

290339
# Iterate over higher levels and compute lower resolution tiles
340+
341+
half_tile_size = tile_size // 2
342+
291343
for ilevel in range(level - 1, -1, -1):
292344

293345
# Find index of tiles to produce at lower-resolution levels
@@ -299,6 +351,9 @@ def process(index):
299351

300352
header = tile_header(level=ilevel, index=index, frame=frame, tile_size=tile_size)
301353

354+
if isinstance(header, tuple):
355+
header = header[0]
356+
302357
if tile_format == "fits":
303358
array = np.zeros((tile_size, tile_size))
304359
elif tile_format == "png":
@@ -326,13 +381,13 @@ def process(index):
326381
)
327382

328383
if subindex == 0:
329-
array[256:, :256] = data
384+
array[half_tile_size:, :half_tile_size] = data
330385
elif subindex == 2:
331-
array[256:, 256:] = data
386+
array[half_tile_size:, half_tile_size:] = data
332387
elif subindex == 1:
333-
array[:256, :256] = data
388+
array[:half_tile_size, :half_tile_size] = data
334389
elif subindex == 3:
335-
array[:256, 256:] = data
390+
array[:half_tile_size, half_tile_size:] = data
336391

337392
if tile_format == "fits":
338393
fits.writeto(
@@ -403,21 +458,6 @@ def save_index(directory):
403458
f.write(INDEX_HTML)
404459

405460

406-
def save_properties(directory, properties):
    """
    Write ``properties`` to the file named ``properties`` inside ``directory``.

    One line is written per entry, as the key left-justified to 20 characters
    followed by `` = `` and the value.
    """
    target = os.path.join(directory, "properties")
    with open(target, "w") as handle:
        handle.writelines(f"{name:20s} = {setting}\n" for name, setting in properties.items())
410-
411-
412-
def load_properties(directory):
    """
    Read the ``properties`` file inside ``directory`` into a dictionary.

    Each ``key = value`` line becomes one entry, with surrounding whitespace
    stripped from both parts. Only the first ``=`` separates key from value,
    so values may themselves contain ``=``. Blank lines and lines starting
    with ``#`` are ignored.

    Raises
    ------
    ValueError
        If a non-blank, non-comment line contains no ``=``.
    """
    properties = {}
    with open(os.path.join(directory, "properties")) as f:
        for line in f:
            stripped = line.strip()
            # Skip blank lines and comments instead of failing to unpack them
            if not stripped or stripped.startswith("#"):
                continue
            key, value = stripped.split("=", 1)
            properties[key.strip()] = value.strip()
    return properties
419-
420-
421461
def coadd_hips(input_directories, output_directory):
422462
"""
423463
Given multiple HiPS directories, combine these into a single HiPS.

0 commit comments

Comments
 (0)