Skip to content

Commit 05a57f6

Browse files
committed
WIP Implementation of get/set viewport in dummy_viewer
1 parent 428d510 commit 05a57f6

File tree

1 file changed

+186
-91
lines changed

1 file changed

+186
-91
lines changed

src/astro_image_display_api/dummy_viewer.py

Lines changed: 186 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,31 @@
1+
import numbers
12
import os
3+
from collections import defaultdict
4+
from copy import copy
25
from dataclasses import dataclass, field
36
from pathlib import Path
47
from typing import Any
58

9+
from astropy import units as u
610
from astropy.coordinates import SkyCoord
711
from astropy.nddata import CCDData, NDData
812
from astropy.table import Table, vstack
913
from astropy.units import Quantity, get_physical_type
1014
from astropy.wcs import WCS
15+
from astropy.wcs.utils import proj_plane_pixel_scales
1116
from astropy.visualization import AsymmetricPercentileInterval, BaseInterval, BaseStretch, LinearStretch, ManualInterval
1217
from numpy.typing import ArrayLike
1318

1419
from .interface_definition import ImageViewerInterface
1520

21+
@dataclass
22+
class ViewportInfo:
23+
"""
24+
Class to hold image and viewport information.
25+
"""
26+
center: SkyCoord | tuple[numbers.Real, numbers.Real] | None = None
27+
fov: float | Quantity | None = None
28+
wcs: WCS | None = None
1629

1730
@dataclass
1831
class ImageViewer:
@@ -28,7 +41,7 @@ class ImageViewer:
2841
zoom_level: float = 1
2942
_cursor: str = ImageViewerInterface.ALLOWED_CURSOR_LOCATIONS[0]
3043
marker: Any = "marker"
31-
_cuts: BaseInterval | tuple[float, float] = AsymmetricPercentileInterval(upper_percentile=95)
44+
_cuts: BaseInterval | tuple[numbers.Real, numbers.Real] = AsymmetricPercentileInterval(upper_percentile=95)
3245
_stretch: BaseStretch = LinearStretch
3346
# viewer: Any
3447

@@ -46,7 +59,15 @@ class ImageViewer:
4659
_previous_marker: Any = ""
4760
_markers: dict[str, Table] = field(default_factory=dict)
4861
_wcs: WCS | None = None
49-
_center: tuple[float, float] = (0.0, 0.0)
62+
_center: tuple[numbers.Real, numbers.Real] = (0.0, 0.0)
63+
64+
65+
def __post_init__(self):
66+
# Set up the initial state of the viewer
67+
self._images = defaultdict(ViewportInfo)
68+
self._images[None].center = None
69+
self._images[None].fov = None
70+
self._images[None].wcs = None
5071

5172
def get_stretch(self) -> BaseStretch:
5273
return self._stretch
@@ -59,7 +80,7 @@ def set_stretch(self, value: BaseStretch) -> None:
5980
def get_cuts(self) -> tuple:
6081
return self._cuts
6182

62-
def set_cuts(self, value: tuple[float, float] | BaseInterval) -> None:
83+
def set_cuts(self, value: tuple[numbers.Real, numbers.Real] | BaseInterval) -> None:
6384
if isinstance(value, tuple) and len(value) == 2:
6485
self._cuts = ManualInterval(value[0], value[1])
6586
elif isinstance(value, BaseInterval):
@@ -80,7 +101,42 @@ def cursor(self, value: str) -> None:
80101
# The methods, grouped loosely by purpose
81102

82103
# Methods for loading data
83-
def load_image(self, file: str | os.PathLike | ArrayLike | NDData) -> None:
104+
def _user_image_labels(self) -> list[str]:
105+
"""
106+
Get the list of user-defined image labels.
107+
108+
Returns
109+
-------
110+
list of str
111+
The list of user-defined image labels.
112+
"""
113+
return [label for label in self._images if label is not None]
114+
115+
def _resolve_image_label(self, image_label: str | None) -> str:
116+
"""
117+
Figure out the catalog label if the user did not specify one. This
118+
is needed so that the user gets what they expect in the simple case
119+
where there is only one catalog loaded. In that case the user may
120+
or may not have actually specified a catalog label.
121+
"""
122+
user_keys = self._user_image_labels()
123+
if image_label is None:
124+
match len(user_keys):
125+
case 0:
126+
# No user-defined catalog labels, so return the default label.
127+
image_label = None
128+
case 1:
129+
# The user must have loaded a catalog, so return that instead of
130+
# the default label, which live in the key None.
131+
image_label = user_keys[0]
132+
case _:
133+
raise ValueError(
134+
"Multiple catalog styles defined. Please specify a image_label to get the style."
135+
)
136+
137+
return image_label
138+
139+
def load_image(self, file: str | os.PathLike | ArrayLike | NDData, image_label: str | None = None) -> None:
84140
"""
85141
Load a FITS file into the viewer.
86142
@@ -89,32 +145,42 @@ def load_image(self, file: str | os.PathLike | ArrayLike | NDData) -> None:
89145
file : str or `astropy.io.fits.HDU`
90146
The FITS file to load. If a string, it can be a URL or a
91147
file path.
148+
149+
image_label : str, optional
150+
A label for the image.
92151
"""
152+
image_label = self._resolve_image_label(image_label)
153+
154+
# Delete the current viewport if it exists
155+
if image_label in self._images:
156+
del self._images[image_label]
157+
93158
if isinstance(file, (str, os.PathLike)):
94159
if isinstance(file, str):
95160
is_adsf = file.endswith(".asdf")
96161
else:
97162
is_asdf = file.suffix == ".asdf"
98163
if is_asdf:
99-
self._load_asdf(file)
164+
self._load_asdf(file, image_label)
100165
else:
101-
self._load_fits(file)
166+
self._load_fits(file, image_label)
102167
elif isinstance(file, NDData):
103-
self._load_nddata(file)
168+
self._load_nddata(file, image_label)
104169
else:
105170
# Assume it is a 2D array
106-
self._load_array(file)
171+
self._load_array(file, image_label)
107172

108-
def _load_fits(self, file: str | os.PathLike) -> None:
173+
def _load_fits(self, file: str | os.PathLike, image_label: str | None) -> None:
109174
ccd = CCDData.read(file)
110-
self._wcs = ccd.wcs
111-
self.image_height, self.image_width = ccd.shape
112-
# Totally made up number...as currently defined, zoom_level means, esentially, ratio
113-
# of image size to viewer size.
114-
self.zoom_level = 1.0
115-
self.center_on((self.image_width / 2, self.image_height / 2))
116-
117-
def _load_array(self, array: ArrayLike) -> None:
175+
height, width = ccd.shape
176+
self._images[image_label].wcs = ccd.wcs
177+
self.set_viewport(
178+
center=(width / 2, height / 2),
179+
fov=max(ccd.shape),
180+
image_label=image_label
181+
)
182+
183+
def _load_array(self, array: ArrayLike, image_label: str | None) -> None:
118184
"""
119185
Load a 2D array into the viewer.
120186
@@ -123,14 +189,15 @@ def _load_array(self, array: ArrayLike) -> None:
123189
array : array-like
124190
The array to load.
125191
"""
126-
self.image_height, self.image_width = array.shape
127-
# Totally made up number...as currently defined, zoom_level means, esentially, ratio
128-
# of image size to viewer size.
129-
self.zoom_level = 1.0
130-
self.center_on((self.image_width / 2, self.image_height / 2))
131-
192+
height, width = array.shape
193+
self._images[image_label].wcs = None # No WCS for raw arrays
194+
self.set_viewport(
195+
center=(width / 2, height / 2),
196+
fov=max(array.shape),
197+
image_label=image_label
198+
)
132199

133-
def _load_nddata(self, data: NDData) -> None:
200+
def _load_nddata(self, data: NDData, image_label: str | None) -> None:
134201
"""
135202
Load an `astropy.nddata.NDData` object into the viewer.
136203
@@ -139,15 +206,16 @@ def _load_nddata(self, data: NDData) -> None:
139206
data : `astropy.nddata.NDData`
140207
The NDData object to load.
141208
"""
142-
self._wcs = data.wcs
209+
self._images[image_label].wcs = data.wcs
143210
# Not all NDDData objects have a shape, apparently
144-
self.image_height, self.image_width = data.data.shape
145-
# Totally made up number...as currently defined, zoom_level means, esentially, ratio
146-
# of image size to viewer size.
147-
self.zoom_level = 1.0
148-
self.center_on((self.image_width / 2, self.image_height / 2))
211+
height, width = data.data.shape
212+
self.set_viewport(
213+
center=(width / 2, height / 2),
214+
fov=max(data.data.shape),
215+
image_label=image_label
216+
)
149217

150-
def _load_asdf(self, asdf_file: str | os.PathLike) -> None:
218+
def _load_asdf(self, asdf_file: str | os.PathLike, image_label: str | None) -> None:
151219
"""
152220
Not implementing some load types is fine.
153221
"""
@@ -313,67 +381,94 @@ def get_markers(self, x_colname: str = 'x', y_colname: str = 'y',
313381

314382

315383
# Methods that modify the view
316-
def center_on(self, point: tuple | SkyCoord):
317-
"""
318-
Center the view on the point.
319-
320-
Parameters
321-
----------
322-
tuple or `~astropy.coordinates.SkyCoord`
323-
If tuple of ``(X, Y)`` is given, it is assumed
324-
to be in data coordinates.
325-
"""
326-
# currently there is no way to get the position of the center, but we may as well make
327-
# note of it
328-
if isinstance(point, SkyCoord):
329-
if self._wcs is not None:
330-
point = self._wcs.world_to_pixel(point)
384+
def set_viewport(
385+
self, center: SkyCoord | tuple[numbers.Real, numbers.Real] | None = None,
386+
fov: Quantity | numbers.Real | None = None,
387+
image_label: str | None = None
388+
) -> None:
389+
image_label = self._resolve_image_label(image_label)
390+
391+
# Get current center/fov, if any, so that the user may input only one of them
392+
# after the initial setup if they wish.
393+
current_viewport = copy(self._images[image_label])
394+
if center is None:
395+
center = current_viewport.center
396+
if fov is None:
397+
fov = current_viewport.fov
398+
399+
# If either center or fov is None these checks will raise an appropriate error
400+
if not isinstance(center, (SkyCoord, tuple)):
401+
raise TypeError("Invalid value for center. Center must be a SkyCoord or tuple of (X, Y).")
402+
if not isinstance(fov, (Quantity, numbers.Real)):
403+
raise TypeError("Invalid value for fov. FOV must be a Quantity or float.")
404+
405+
# Check that the center and fov are compatible with the current image
406+
if self._images[image_label].wcs is None:
407+
if current_viewport.center is not None:
408+
# If there is a WCS either input is fine. If there is no WCS then we only
409+
# check wther the new center is the same type as the current center.
410+
if isinstance(center, SkyCoord) and not isinstance(current_viewport.center, SkyCoord):
411+
raise ValueError("Center must be a SkyCoord for this image when WCS is not set.")
412+
elif isinstance(center, tuple) and not isinstance(current_viewport.center, tuple):
413+
raise ValueError("Center must be a tuple of (X, Y) for this image when WCS is not set.")
414+
if current_viewport.fov is not None:
415+
if isinstance(fov, Quantity) and not isinstance(current_viewport.fov, Quantity):
416+
raise ValueError("FOV must be a angular Quantity for this image when WCS is not set.")
417+
elif isinstance(fov, numbers.Real) and not isinstance(current_viewport.fov, numbers.Real):
418+
raise ValueError("FOV must be a float for this image when WCS is set.")
419+
420+
# 😅 if we made it this far we should be able to handle the actual setting
421+
self._images[image_label].center = center
422+
self._images[image_label].fov = fov
423+
424+
425+
set_viewport.__doc__ = ImageViewerInterface.set_viewport.__doc__
426+
427+
def get_viewport(
428+
self, sky_or_pixel: str | None = None, image_label: str | None = None
429+
) -> dict[str, Any]:
430+
if sky_or_pixel not in (None, "sky", "pixel"):
431+
raise ValueError("sky_or_pixel must be 'sky', 'pixel', or None.")
432+
image_label = self._resolve_image_label(image_label)
433+
434+
viewport = self._images[image_label]
435+
if sky_or_pixel == "sky":
436+
if isinstance(viewport.center, SkyCoord):
437+
center = viewport.center
438+
elif isinstance(viewport.center, tuple):
439+
# If the center is a tuple, we need to convert it to SkyCoord
440+
if viewport.wcs is None:
441+
raise ValueError("WCS is not set. Cannot convert pixel coordinates to sky coordinates.")
442+
center = viewport.wcs.pixel_to_world(viewport.center[0], viewport.center[1])
443+
if isinstance(viewport.fov, Quantity):
444+
fov = viewport.fov
445+
elif isinstance(viewport.fov, numbers.Real):
446+
if viewport.wcs is None:
447+
raise ValueError("WCS is not set. Cannot convert FOV to sky coordinates.")
448+
pixel_scale = proj_plane_pixel_scales(viewport.wcs)
449+
fov = pixel_scale * viewport.fov * u.degree
450+
else:
451+
# Pixel coordinates
452+
if isinstance(viewport.center, SkyCoord):
453+
if viewport.wcs is None:
454+
raise ValueError("WCS is not set. Cannot convert sky coordinates to pixel coordinates.")
455+
center = viewport.wcs.world_to_pixel(viewport.center)
331456
else:
332-
raise ValueError("WCS is not set. Cannot convert to pixel coordinates.")
333-
334-
self._center = point
335-
336-
def offset_by(self, dx: float | Quantity, dy: float | Quantity) -> None:
337-
"""
338-
Move the center to a point that is given offset
339-
away from the current center.
340-
341-
Parameters
342-
----------
343-
dx, dy : float or `~astropy.units.Quantity`
344-
Offset value. Without a unit, assumed to be pixel offsets.
345-
If a unit is attached, offset by pixel or sky is assumed from
346-
the unit.
347-
"""
348-
# Convert to quantity to make the rest of the processing uniform
349-
dx = Quantity(dx)
350-
dy = Quantity(dy)
351-
352-
# This raises a UnitConversionError if the units are not compatible
353-
dx.to(dy.unit)
354-
355-
# Do we have an angle or pixel offset?
356-
if get_physical_type(dx) == "angle":
357-
# This is a sky offset
358-
if self._wcs is not None:
359-
old_center_coord = self._wcs.pixel_to_world(self._center[0], self._center[1])
360-
new_center = old_center_coord.spherical_offsets_by(dx, dy)
361-
self.center_on(new_center)
457+
center = viewport.center
458+
if isinstance(viewport.fov, Quantity):
459+
if viewport.wcs is None:
460+
raise ValueError("WCS is not set. Cannot convert FOV to pixel coordinates.")
461+
pixel_scale = proj_plane_pixel_scales(viewport.wcs)
462+
fov = viewport.fov / pixel_scale
362463
else:
363-
raise ValueError("WCS is not set. Cannot convert to pixel coordinates.")
364-
else:
365-
# This is a pixel offset
366-
new_center = (self._center[0] + dx.value, self._center[1] + dy.value)
367-
self.center_on(new_center)
464+
fov = viewport.fov
368465

369-
def zoom(self, val) -> None:
370-
"""
371-
Zoom in or out by the given factor.
466+
return dict(
467+
center=center,
468+
fov=fov,
469+
wcs=viewport.wcs,
470+
image_label=image_label
471+
)
372472

373-
Parameters
374-
----------
375-
val : int
376-
The zoom level to zoom the image.
377-
See `zoom_level`.
378-
"""
379-
self.zoom_level *= val
473+
474+
get_viewport.__doc__ = ImageViewerInterface.get_viewport.__doc__

0 commit comments

Comments
 (0)