1
+ import numbers
1
2
import os
3
+ from collections import defaultdict
4
+ from copy import copy
2
5
from dataclasses import dataclass , field
3
6
from pathlib import Path
4
7
from typing import Any
5
8
9
+ from astropy import units as u
6
10
from astropy .coordinates import SkyCoord
7
11
from astropy .nddata import CCDData , NDData
8
12
from astropy .table import Table , vstack
9
13
from astropy .units import Quantity , get_physical_type
10
14
from astropy .wcs import WCS
15
+ from astropy .wcs .utils import proj_plane_pixel_scales
11
16
from astropy .visualization import AsymmetricPercentileInterval , BaseInterval , BaseStretch , LinearStretch , ManualInterval
12
17
from numpy .typing import ArrayLike
13
18
14
19
from .interface_definition import ImageViewerInterface
15
20
21
+ @dataclass
22
+ class ViewportInfo :
23
+ """
24
+ Class to hold image and viewport information.
25
+ """
26
+ center : SkyCoord | tuple [numbers .Real , numbers .Real ] | None = None
27
+ fov : float | Quantity | None = None
28
+ wcs : WCS | None = None
16
29
17
30
@dataclass
18
31
class ImageViewer :
@@ -28,7 +41,7 @@ class ImageViewer:
28
41
zoom_level : float = 1
29
42
_cursor : str = ImageViewerInterface .ALLOWED_CURSOR_LOCATIONS [0 ]
30
43
marker : Any = "marker"
31
- _cuts : BaseInterval | tuple [float , float ] = AsymmetricPercentileInterval (upper_percentile = 95 )
44
+ _cuts : BaseInterval | tuple [numbers . Real , numbers . Real ] = AsymmetricPercentileInterval (upper_percentile = 95 )
32
45
_stretch : BaseStretch = LinearStretch
33
46
# viewer: Any
34
47
@@ -46,7 +59,15 @@ class ImageViewer:
46
59
_previous_marker : Any = ""
47
60
_markers : dict [str , Table ] = field (default_factory = dict )
48
61
_wcs : WCS | None = None
49
- _center : tuple [float , float ] = (0.0 , 0.0 )
62
+ _center : tuple [numbers .Real , numbers .Real ] = (0.0 , 0.0 )
63
+
64
+
65
+ def __post_init__ (self ):
66
+ # Set up the initial state of the viewer
67
+ self ._images = defaultdict (ViewportInfo )
68
+ self ._images [None ].center = None
69
+ self ._images [None ].fov = None
70
+ self ._images [None ].wcs = None
50
71
51
72
def get_stretch (self ) -> BaseStretch :
52
73
return self ._stretch
@@ -59,7 +80,7 @@ def set_stretch(self, value: BaseStretch) -> None:
59
80
def get_cuts (self ) -> tuple :
60
81
return self ._cuts
61
82
62
- def set_cuts (self , value : tuple [float , float ] | BaseInterval ) -> None :
83
+ def set_cuts (self , value : tuple [numbers . Real , numbers . Real ] | BaseInterval ) -> None :
63
84
if isinstance (value , tuple ) and len (value ) == 2 :
64
85
self ._cuts = ManualInterval (value [0 ], value [1 ])
65
86
elif isinstance (value , BaseInterval ):
@@ -80,7 +101,42 @@ def cursor(self, value: str) -> None:
80
101
# The methods, grouped loosely by purpose
81
102
82
103
# Methods for loading data
83
- def load_image (self , file : str | os .PathLike | ArrayLike | NDData ) -> None :
104
+ def _user_image_labels (self ) -> list [str ]:
105
+ """
106
+ Get the list of user-defined image labels.
107
+
108
+ Returns
109
+ -------
110
+ list of str
111
+ The list of user-defined image labels.
112
+ """
113
+ return [label for label in self ._images if label is not None ]
114
+
115
+ def _resolve_image_label (self , image_label : str | None ) -> str :
116
+ """
117
+ Figure out the catalog label if the user did not specify one. This
118
+ is needed so that the user gets what they expect in the simple case
119
+ where there is only one catalog loaded. In that case the user may
120
+ or may not have actually specified a catalog label.
121
+ """
122
+ user_keys = self ._user_image_labels ()
123
+ if image_label is None :
124
+ match len (user_keys ):
125
+ case 0 :
126
+ # No user-defined catalog labels, so return the default label.
127
+ image_label = None
128
+ case 1 :
129
+ # The user must have loaded a catalog, so return that instead of
130
+ # the default label, which live in the key None.
131
+ image_label = user_keys [0 ]
132
+ case _:
133
+ raise ValueError (
134
+ "Multiple catalog styles defined. Please specify a image_label to get the style."
135
+ )
136
+
137
+ return image_label
138
+
139
+ def load_image (self , file : str | os .PathLike | ArrayLike | NDData , image_label : str | None = None ) -> None :
84
140
"""
85
141
Load a FITS file into the viewer.
86
142
@@ -89,32 +145,42 @@ def load_image(self, file: str | os.PathLike | ArrayLike | NDData) -> None:
89
145
file : str or `astropy.io.fits.HDU`
90
146
The FITS file to load. If a string, it can be a URL or a
91
147
file path.
148
+
149
+ image_label : str, optional
150
+ A label for the image.
92
151
"""
152
+ image_label = self ._resolve_image_label (image_label )
153
+
154
+ # Delete the current viewport if it exists
155
+ if image_label in self ._images :
156
+ del self ._images [image_label ]
157
+
93
158
if isinstance (file , (str , os .PathLike )):
94
159
if isinstance (file , str ):
95
160
is_adsf = file .endswith (".asdf" )
96
161
else :
97
162
is_asdf = file .suffix == ".asdf"
98
163
if is_asdf :
99
- self ._load_asdf (file )
164
+ self ._load_asdf (file , image_label )
100
165
else :
101
- self ._load_fits (file )
166
+ self ._load_fits (file , image_label )
102
167
elif isinstance (file , NDData ):
103
- self ._load_nddata (file )
168
+ self ._load_nddata (file , image_label )
104
169
else :
105
170
# Assume it is a 2D array
106
- self ._load_array (file )
171
+ self ._load_array (file , image_label )
107
172
108
- def _load_fits (self , file : str | os .PathLike ) -> None :
173
+ def _load_fits (self , file : str | os .PathLike , image_label : str | None ) -> None :
109
174
ccd = CCDData .read (file )
110
- self ._wcs = ccd .wcs
111
- self .image_height , self .image_width = ccd .shape
112
- # Totally made up number...as currently defined, zoom_level means, esentially, ratio
113
- # of image size to viewer size.
114
- self .zoom_level = 1.0
115
- self .center_on ((self .image_width / 2 , self .image_height / 2 ))
116
-
117
- def _load_array (self , array : ArrayLike ) -> None :
175
+ height , width = ccd .shape
176
+ self ._images [image_label ].wcs = ccd .wcs
177
+ self .set_viewport (
178
+ center = (width / 2 , height / 2 ),
179
+ fov = max (ccd .shape ),
180
+ image_label = image_label
181
+ )
182
+
183
+ def _load_array (self , array : ArrayLike , image_label : str | None ) -> None :
118
184
"""
119
185
Load a 2D array into the viewer.
120
186
@@ -123,14 +189,15 @@ def _load_array(self, array: ArrayLike) -> None:
123
189
array : array-like
124
190
The array to load.
125
191
"""
126
- self .image_height , self .image_width = array .shape
127
- # Totally made up number...as currently defined, zoom_level means, esentially, ratio
128
- # of image size to viewer size.
129
- self .zoom_level = 1.0
130
- self .center_on ((self .image_width / 2 , self .image_height / 2 ))
131
-
192
+ height , width = array .shape
193
+ self ._images [image_label ].wcs = None # No WCS for raw arrays
194
+ self .set_viewport (
195
+ center = (width / 2 , height / 2 ),
196
+ fov = max (array .shape ),
197
+ image_label = image_label
198
+ )
132
199
133
- def _load_nddata (self , data : NDData ) -> None :
200
+ def _load_nddata (self , data : NDData , image_label : str | None ) -> None :
134
201
"""
135
202
Load an `astropy.nddata.NDData` object into the viewer.
136
203
@@ -139,15 +206,16 @@ def _load_nddata(self, data: NDData) -> None:
139
206
data : `astropy.nddata.NDData`
140
207
The NDData object to load.
141
208
"""
142
- self ._wcs = data .wcs
209
+ self ._images [ image_label ]. wcs = data .wcs
143
210
# Not all NDDData objects have a shape, apparently
144
- self .image_height , self .image_width = data .data .shape
145
- # Totally made up number...as currently defined, zoom_level means, esentially, ratio
146
- # of image size to viewer size.
147
- self .zoom_level = 1.0
148
- self .center_on ((self .image_width / 2 , self .image_height / 2 ))
211
+ height , width = data .data .shape
212
+ self .set_viewport (
213
+ center = (width / 2 , height / 2 ),
214
+ fov = max (data .data .shape ),
215
+ image_label = image_label
216
+ )
149
217
150
- def _load_asdf (self , asdf_file : str | os .PathLike ) -> None :
218
+ def _load_asdf (self , asdf_file : str | os .PathLike , image_label : str | None ) -> None :
151
219
"""
152
220
Not implementing some load types is fine.
153
221
"""
@@ -313,67 +381,94 @@ def get_markers(self, x_colname: str = 'x', y_colname: str = 'y',
313
381
314
382
315
383
# Methods that modify the view
316
- def center_on (self , point : tuple | SkyCoord ):
317
- """
318
- Center the view on the point.
319
-
320
- Parameters
321
- ----------
322
- tuple or `~astropy.coordinates.SkyCoord`
323
- If tuple of ``(X, Y)`` is given, it is assumed
324
- to be in data coordinates.
325
- """
326
- # currently there is no way to get the position of the center, but we may as well make
327
- # note of it
328
- if isinstance (point , SkyCoord ):
329
- if self ._wcs is not None :
330
- point = self ._wcs .world_to_pixel (point )
384
+ def set_viewport (
385
+ self , center : SkyCoord | tuple [numbers .Real , numbers .Real ] | None = None ,
386
+ fov : Quantity | numbers .Real | None = None ,
387
+ image_label : str | None = None
388
+ ) -> None :
389
+ image_label = self ._resolve_image_label (image_label )
390
+
391
+ # Get current center/fov, if any, so that the user may input only one of them
392
+ # after the initial setup if they wish.
393
+ current_viewport = copy (self ._images [image_label ])
394
+ if center is None :
395
+ center = current_viewport .center
396
+ if fov is None :
397
+ fov = current_viewport .fov
398
+
399
+ # If either center or fov is None these checks will raise an appropriate error
400
+ if not isinstance (center , (SkyCoord , tuple )):
401
+ raise TypeError ("Invalid value for center. Center must be a SkyCoord or tuple of (X, Y)." )
402
+ if not isinstance (fov , (Quantity , numbers .Real )):
403
+ raise TypeError ("Invalid value for fov. FOV must be a Quantity or float." )
404
+
405
+ # Check that the center and fov are compatible with the current image
406
+ if self ._images [image_label ].wcs is None :
407
+ if current_viewport .center is not None :
408
+ # If there is a WCS either input is fine. If there is no WCS then we only
409
+ # check wther the new center is the same type as the current center.
410
+ if isinstance (center , SkyCoord ) and not isinstance (current_viewport .center , SkyCoord ):
411
+ raise ValueError ("Center must be a SkyCoord for this image when WCS is not set." )
412
+ elif isinstance (center , tuple ) and not isinstance (current_viewport .center , tuple ):
413
+ raise ValueError ("Center must be a tuple of (X, Y) for this image when WCS is not set." )
414
+ if current_viewport .fov is not None :
415
+ if isinstance (fov , Quantity ) and not isinstance (current_viewport .fov , Quantity ):
416
+ raise ValueError ("FOV must be a angular Quantity for this image when WCS is not set." )
417
+ elif isinstance (fov , numbers .Real ) and not isinstance (current_viewport .fov , numbers .Real ):
418
+ raise ValueError ("FOV must be a float for this image when WCS is set." )
419
+
420
+ # 😅 if we made it this far we should be able to handle the actual setting
421
+ self ._images [image_label ].center = center
422
+ self ._images [image_label ].fov = fov
423
+
424
+
425
+ set_viewport .__doc__ = ImageViewerInterface .set_viewport .__doc__
426
+
427
+ def get_viewport (
428
+ self , sky_or_pixel : str | None = None , image_label : str | None = None
429
+ ) -> dict [str , Any ]:
430
+ if sky_or_pixel not in (None , "sky" , "pixel" ):
431
+ raise ValueError ("sky_or_pixel must be 'sky', 'pixel', or None." )
432
+ image_label = self ._resolve_image_label (image_label )
433
+
434
+ viewport = self ._images [image_label ]
435
+ if sky_or_pixel == "sky" :
436
+ if isinstance (viewport .center , SkyCoord ):
437
+ center = viewport .center
438
+ elif isinstance (viewport .center , tuple ):
439
+ # If the center is a tuple, we need to convert it to SkyCoord
440
+ if viewport .wcs is None :
441
+ raise ValueError ("WCS is not set. Cannot convert pixel coordinates to sky coordinates." )
442
+ center = viewport .wcs .pixel_to_world (viewport .center [0 ], viewport .center [1 ])
443
+ if isinstance (viewport .fov , Quantity ):
444
+ fov = viewport .fov
445
+ elif isinstance (viewport .fov , numbers .Real ):
446
+ if viewport .wcs is None :
447
+ raise ValueError ("WCS is not set. Cannot convert FOV to sky coordinates." )
448
+ pixel_scale = proj_plane_pixel_scales (viewport .wcs )
449
+ fov = pixel_scale * viewport .fov * u .degree
450
+ else :
451
+ # Pixel coordinates
452
+ if isinstance (viewport .center , SkyCoord ):
453
+ if viewport .wcs is None :
454
+ raise ValueError ("WCS is not set. Cannot convert sky coordinates to pixel coordinates." )
455
+ center = viewport .wcs .world_to_pixel (viewport .center )
331
456
else :
332
- raise ValueError ("WCS is not set. Cannot convert to pixel coordinates." )
333
-
334
- self ._center = point
335
-
336
- def offset_by (self , dx : float | Quantity , dy : float | Quantity ) -> None :
337
- """
338
- Move the center to a point that is given offset
339
- away from the current center.
340
-
341
- Parameters
342
- ----------
343
- dx, dy : float or `~astropy.units.Quantity`
344
- Offset value. Without a unit, assumed to be pixel offsets.
345
- If a unit is attached, offset by pixel or sky is assumed from
346
- the unit.
347
- """
348
- # Convert to quantity to make the rest of the processing uniform
349
- dx = Quantity (dx )
350
- dy = Quantity (dy )
351
-
352
- # This raises a UnitConversionError if the units are not compatible
353
- dx .to (dy .unit )
354
-
355
- # Do we have an angle or pixel offset?
356
- if get_physical_type (dx ) == "angle" :
357
- # This is a sky offset
358
- if self ._wcs is not None :
359
- old_center_coord = self ._wcs .pixel_to_world (self ._center [0 ], self ._center [1 ])
360
- new_center = old_center_coord .spherical_offsets_by (dx , dy )
361
- self .center_on (new_center )
457
+ center = viewport .center
458
+ if isinstance (viewport .fov , Quantity ):
459
+ if viewport .wcs is None :
460
+ raise ValueError ("WCS is not set. Cannot convert FOV to pixel coordinates." )
461
+ pixel_scale = proj_plane_pixel_scales (viewport .wcs )
462
+ fov = viewport .fov / pixel_scale
362
463
else :
363
- raise ValueError ("WCS is not set. Cannot convert to pixel coordinates." )
364
- else :
365
- # This is a pixel offset
366
- new_center = (self ._center [0 ] + dx .value , self ._center [1 ] + dy .value )
367
- self .center_on (new_center )
464
+ fov = viewport .fov
368
465
369
- def zoom (self , val ) -> None :
370
- """
371
- Zoom in or out by the given factor.
466
+ return dict (
467
+ center = center ,
468
+ fov = fov ,
469
+ wcs = viewport .wcs ,
470
+ image_label = image_label
471
+ )
372
472
373
- Parameters
374
- ----------
375
- val : int
376
- The zoom level to zoom the image.
377
- See `zoom_level`.
378
- """
379
- self .zoom_level *= val
473
+
474
+ get_viewport .__doc__ = ImageViewerInterface .get_viewport .__doc__
0 commit comments