diff --git a/src/astro_image_display_api/dummy_viewer.py b/src/astro_image_display_api/dummy_viewer.py index 39cf53b..e79b913 100644 --- a/src/astro_image_display_api/dummy_viewer.py +++ b/src/astro_image_display_api/dummy_viewer.py @@ -1,15 +1,18 @@ +import numbers import os from collections import defaultdict -from copy import deepcopy +from copy import copy, deepcopy from dataclasses import dataclass, field from pathlib import Path from typing import Any +from astropy import units as u from astropy.coordinates import SkyCoord from astropy.nddata import CCDData, NDData from astropy.table import Table, vstack from astropy.units import Quantity, get_physical_type from astropy.wcs import WCS +from astropy.wcs.utils import proj_plane_pixel_scales from astropy.visualization import AsymmetricPercentileInterval, BaseInterval, BaseStretch, LinearStretch, ManualInterval from numpy.typing import ArrayLike @@ -23,6 +26,18 @@ class CatalogInfo: style: dict[str, Any] = field(default_factory=dict) data: Table | None = None +@dataclass +class ViewportInfo: + """ + Class to hold image and viewport information. + """ + center: SkyCoord | tuple[numbers.Real, numbers.Real] | None = None + fov: float | Quantity | None = None + wcs: WCS | None = None + largest_dimension: int | None = None + stretch: BaseStretch | None = None + cuts: BaseInterval | tuple[numbers.Real, numbers.Real] | None = None + @dataclass class ImageViewer: """ @@ -45,7 +60,7 @@ class ImageViewer: # some internal variable for keeping track of viewer state _wcs: WCS | None = None - _center: tuple[float, float] = (0.0, 0.0) + _center: tuple[numbers.Real, numbers.Real] = (0.0, 0.0) def __post_init__(self): # This is a dictionary of marker sets. The keys are the names of the @@ -54,6 +69,12 @@ def __post_init__(self): self._catalogs[None].data = None self._catalogs[None].style = self._default_catalog_style.copy() + self._images = defaultdict(ViewportInfo) + self._images[None].center = None + self._images[None].fov = None + self._images[None].wcs = None + + def _user_catalog_labels(self) -> list[str]: """ Get the user-defined catalog labels. @@ -95,24 +116,39 @@ def _default_catalog_style(self) -> dict[str, Any]: "size": 5, } - def get_stretch(self) -> BaseStretch: - return self._stretch - def set_stretch(self, value: BaseStretch) -> None: - if not isinstance(value, BaseStretch): - raise ValueError(f"Stretch option {value} is not valid. Must be an Astropy.visualization Stretch object.") - self._stretch = value - def get_cuts(self) -> tuple: - return self._cuts + def get_stretch(self, image_label: str | None = None) -> BaseStretch: + image_label = self._resolve_image_label(image_label) + if image_label not in self._images: + raise ValueError(f"Image label '{image_label}' not found. Please load an image first.") + return self._images[image_label].stretch - def set_cuts(self, value: tuple[float, float] | BaseInterval) -> None: + def set_stretch(self, value: BaseStretch, image_label: str | None = None) -> None: + if not isinstance(value, BaseStretch): + raise ValueError(f"Stretch option {value} is not valid. Must be an Astropy.visualization Stretch object.") + image_label = self._resolve_image_label(image_label) + if image_label not in self._images: + raise ValueError(f"Image label '{image_label}' not found. Please load an image first.") + self._images[image_label].stretch = value + + def get_cuts(self, image_label: str | None = None) -> tuple: + image_label = self._resolve_image_label(image_label) + if image_label not in self._images: + raise ValueError(f"Image label '{image_label}' not found. Please load an image first.") + return self._images[image_label].cuts + + def set_cuts(self, value: tuple[numbers.Real, numbers.Real] | BaseInterval, image_label: str | None = None) -> None: if isinstance(value, tuple) and len(value) == 2: self._cuts = ManualInterval(value[0], value[1]) elif isinstance(value, BaseInterval): self._cuts = value else: raise ValueError("Cuts must be an Astropy.visualization Interval object or a tuple of two values.") + image_label = self._resolve_image_label(image_label) + if image_label not in self._images: + raise ValueError(f"Image label '{image_label}' not found. Please load an image first.") + self._images[image_label].cuts = self._cuts @property def cursor(self) -> str: @@ -183,7 +219,42 @@ def set_catalog_style( ) # Methods for loading data - def load_image(self, file: str | os.PathLike | ArrayLike | NDData) -> None: + def _user_image_labels(self) -> list[str]: + """ + Get the list of user-defined image labels. + + Returns + ------- + list of str + The list of user-defined image labels. + """ + return [label for label in self._images if label is not None] + + def _resolve_image_label(self, image_label: str | None) -> str: + """ + Figure out the image label if the user did not specify one. This + is needed so that the user gets what they expect in the simple case + where there is only one image loaded. In that case the user may + or may not have actually specified a image label. + """ + user_keys = self._user_image_labels() + if image_label is None: + match len(user_keys): + case 0: + # No user-defined image labels, so return the default label. + image_label = None + case 1: + # The user must have loaded a image, so return that instead of + # the default label, which live in the key None. + image_label = user_keys[0] + case _: + raise ValueError( + "Multiple image labels defined. Please specify a image_label to get the style." + ) + + return image_label + + def load_image(self, file: str | os.PathLike | ArrayLike | NDData, image_label: str | None = None) -> None: """ Load a FITS file into the viewer. @@ -192,32 +263,103 @@ def load_image(self, file: str | os.PathLike | ArrayLike | NDData) -> None: file : str or `astropy.io.fits.HDU` The FITS file to load. If a string, it can be a URL or a file path. + + image_label : str, optional + A label for the image. """ + image_label = self._resolve_image_label(image_label) + + # Delete the current viewport if it exists + if image_label in self._images: + del self._images[image_label] + if isinstance(file, (str, os.PathLike)): if isinstance(file, str): is_adsf = file.endswith(".asdf") else: is_asdf = file.suffix == ".asdf" if is_asdf: - self._load_asdf(file) + self._load_asdf(file, image_label) else: - self._load_fits(file) + self._load_fits(file, image_label) elif isinstance(file, NDData): - self._load_nddata(file) + self._load_nddata(file, image_label) else: # Assume it is a 2D array - self._load_array(file) + self._load_array(file, image_label) + + # This may eventually get pulled, but for now is needed to keep markers + # working with the new image. + self._wcs = self._images[image_label].wcs + + def _determine_largest_dimension(self, shape: tuple[int, int]) -> int: + """ + Determine which index is the largest dimension. + + Parameters + ---------- + shape : tuple of int + The shape of the image. - def _load_fits(self, file: str | os.PathLike) -> None: + Returns + ------- + int + The index of the largest dimension of the image, or 0 if square. + """ + return int(shape[1] > shape[0]) + + def _initialize_image_viewport_stretch_cuts( + self, + image_data: ArrayLike | NDData | CCDData, + image_label: str | None, + ) -> None: + """ + Initialize the viewport, stretch and cuts for an image. + + Parameters + ---------- + image_data : ArrayLike + The image data to initialize the viewport for. + image_label : str or None + The label for the image. If None, the default label will be used. + + Note + ---- + This method is called internally to set up the initial viewport, + stretch, and cuts for the image. It should be called AFTER setting + the WCS. + """ + + # Deal with the viewport first + height, width = image_data.shape + # Center the image in the viewport and show the whole image. + center = (width / 2, height / 2) + fov = max(image_data.shape) + self._images[image_label].largest_dimension = self._determine_largest_dimension(image_data.shape) + + wcs = self._images[image_label].wcs + # Is there a WCS set? If yes, make center a SkyCoord and fov a Quantity, + # otherwise leave them as pixels. + if wcs is not None: + center = wcs.pixel_to_world(center[0], center[1]) + fov = fov * u.degree / proj_plane_pixel_scales(wcs)[self._images[image_label].largest_dimension] + + self.set_viewport( + center=center, + fov=fov, + image_label=image_label + ) + + # Now set the stretch and cuts + self.set_cuts(AsymmetricPercentileInterval(1, 95), image_label=image_label) + self.set_stretch(LinearStretch(), image_label=image_label) + + def _load_fits(self, file: str | os.PathLike, image_label: str | None) -> None: ccd = CCDData.read(file) - self._wcs = ccd.wcs - self.image_height, self.image_width = ccd.shape - # Totally made up number...as currently defined, zoom_level means, esentially, ratio - # of image size to viewer size. - self.zoom_level = 1.0 - self.center_on((self.image_width / 2, self.image_height / 2)) - - def _load_array(self, array: ArrayLike) -> None: + self._images[image_label].wcs = ccd.wcs + self._initialize_image_viewport_stretch_cuts(ccd.data, image_label) + + def _load_array(self, array: ArrayLike, image_label: str | None) -> None: """ Load a 2D array into the viewer. @@ -226,14 +368,11 @@ def _load_array(self, array: ArrayLike) -> None: array : array-like The array to load. """ - self.image_height, self.image_width = array.shape - # Totally made up number...as currently defined, zoom_level means, esentially, ratio - # of image size to viewer size. - self.zoom_level = 1.0 - self.center_on((self.image_width / 2, self.image_height / 2)) + self._images[image_label].wcs = None # No WCS for raw arrays + self._images[image_label].largest_dimension = self._determine_largest_dimension(array.shape) + self._initialize_image_viewport_stretch_cuts(array, image_label) - - def _load_nddata(self, data: NDData) -> None: + def _load_nddata(self, data: NDData, image_label: str | None) -> None: """ Load an `astropy.nddata.NDData` object into the viewer. @@ -242,15 +381,12 @@ def _load_nddata(self, data: NDData) -> None: data : `astropy.nddata.NDData` The NDData object to load. """ - self._wcs = data.wcs + self._images[image_label].wcs = data.wcs + self._images[image_label].largest_dimension = self._determine_largest_dimension(data.data.shape) # Not all NDDData objects have a shape, apparently - self.image_height, self.image_width = data.data.shape - # Totally made up number...as currently defined, zoom_level means, esentially, ratio - # of image size to viewer size. - self.zoom_level = 1.0 - self.center_on((self.image_width / 2, self.image_height / 2)) + self._initialize_image_viewport_stretch_cuts(data.data, image_label) - def _load_asdf(self, asdf_file: str | os.PathLike) -> None: + def _load_asdf(self, asdf_file: str | os.PathLike, image_label: str | None) -> None: """ Not implementing some load types is fine. """ @@ -382,67 +518,115 @@ def get_catalog_names(self) -> list[str]: get_catalog_names.__doc__ = ImageViewerInterface.get_catalog_names.__doc__ # Methods that modify the view - def center_on(self, point: tuple | SkyCoord): - """ - Center the view on the point. - - Parameters - ---------- - tuple or `~astropy.coordinates.SkyCoord` - If tuple of ``(X, Y)`` is given, it is assumed - to be in data coordinates. - """ - # currently there is no way to get the position of the center, but we may as well make - # note of it - if isinstance(point, SkyCoord): - if self._wcs is not None: - point = self._wcs.world_to_pixel(point) + def set_viewport( + self, center: SkyCoord | tuple[numbers.Real, numbers.Real] | None = None, + fov: Quantity | numbers.Real | None = None, + image_label: str | None = None + ) -> None: + image_label = self._resolve_image_label(image_label) + + if image_label not in self._images: + raise ValueError(f"Image label '{image_label}' not found. Please load an image first.") + + # Get current center/fov, if any, so that the user may input only one of them + # after the initial setup if they wish. + current_viewport = copy(self._images[image_label]) + if center is None: + center = current_viewport.center + if fov is None: + fov = current_viewport.fov + + # If either center or fov is None these checks will raise an appropriate error + if not isinstance(center, (SkyCoord, tuple)): + raise TypeError("Invalid value for center. Center must be a SkyCoord or tuple of (X, Y).") + if not isinstance(fov, (Quantity, numbers.Real)): + raise TypeError("Invalid value for fov. fov must be an angular Quantity or float.") + + if isinstance(fov, Quantity) and not fov.unit.is_equivalent(u.deg): + raise u.UnitTypeError("Incorrect unit for fov. fov must be an angular Quantity or float.") + + # Check that the center and fov are compatible with the current image + if self._images[image_label].wcs is None: + if current_viewport.center is not None: + # If there is a WCS either input is fine. If there is no WCS then we only + # check wther the new center is the same type as the current center. + if isinstance(center, SkyCoord) and not isinstance(current_viewport.center, SkyCoord): + raise TypeError("Center must be a tuple for this image when WCS is not set.") + elif isinstance(center, tuple) and not isinstance(current_viewport.center, tuple): + raise TypeError("Center must be a SkyCoord for this image when WCS is not set.") + if current_viewport.fov is not None: + if isinstance(fov, Quantity) and not isinstance(current_viewport.fov, Quantity): + raise TypeError("FOV must be a float for this image when WCS is not set.") + elif isinstance(fov, numbers.Real) and not isinstance(current_viewport.fov, numbers.Real): + raise TypeError("FOV must be a float for this image when WCS is not set.") + + # 😅 if we made it this far we should be able to handle the actual setting + self._images[image_label].center = center + self._images[image_label].fov = fov + + + set_viewport.__doc__ = ImageViewerInterface.set_viewport.__doc__ + + def get_viewport( + self, sky_or_pixel: str | None = None, image_label: str | None = None + ) -> dict[str, Any]: + if sky_or_pixel not in (None, "sky", "pixel"): + raise ValueError("sky_or_pixel must be 'sky', 'pixel', or None.") + image_label = self._resolve_image_label(image_label) + + if image_label not in self._images: + raise ValueError(f"Image label '{image_label}' not found. Please load an image first.") + + viewport = self._images[image_label] + + # Figure out what to return if the user did not specify sky_or_pixel + if sky_or_pixel is None: + if isinstance(viewport.center, SkyCoord): + # Somebody set this to sky coordinates, so return sky coordinates + sky_or_pixel = "sky" + elif isinstance(viewport.center, tuple): + # Somebody set this to pixel coordinates, so return pixel coordinates + sky_or_pixel = "pixel" + + center = None + fov = None + if sky_or_pixel == "sky": + if isinstance(viewport.center, SkyCoord): + center = viewport.center + + if isinstance(viewport.fov, Quantity): + fov = viewport.fov + + if center is None or fov is None: + # At least one of center or fov is not set, which means at least one + # was not already sky, so we need to convert them or fail + if viewport.wcs is None: + raise ValueError("WCS is not set. Cannot convert pixel coordinates to sky coordinates.") + else: + center = viewport.wcs.pixel_to_world(viewport.center[0], viewport.center[1]) + pixel_scale = proj_plane_pixel_scales(viewport.wcs)[viewport.largest_dimension] + fov = pixel_scale * viewport.fov * u.degree + else: + # Pixel coordinates + if isinstance(viewport.center, SkyCoord): + if viewport.wcs is None: + raise ValueError("WCS is not set. Cannot convert sky coordinates to pixel coordinates.") + center = viewport.wcs.world_to_pixel(viewport.center) else: - raise ValueError("WCS is not set. Cannot convert to pixel coordinates.") - - self._center = point - - def offset_by(self, dx: float | Quantity, dy: float | Quantity) -> None: - """ - Move the center to a point that is given offset - away from the current center. - - Parameters - ---------- - dx, dy : float or `~astropy.units.Quantity` - Offset value. Without a unit, assumed to be pixel offsets. - If a unit is attached, offset by pixel or sky is assumed from - the unit. - """ - # Convert to quantity to make the rest of the processing uniform - dx = Quantity(dx) - dy = Quantity(dy) - - # This raises a UnitConversionError if the units are not compatible - dx.to(dy.unit) - - # Do we have an angle or pixel offset? - if get_physical_type(dx) == "angle": - # This is a sky offset - if self._wcs is not None: - old_center_coord = self._wcs.pixel_to_world(self._center[0], self._center[1]) - new_center = old_center_coord.spherical_offsets_by(dx, dy) - self.center_on(new_center) + center = viewport.center + if isinstance(viewport.fov, Quantity): + if viewport.wcs is None: + raise ValueError("WCS is not set. Cannot convert FOV to pixel coordinates.") + pixel_scale = proj_plane_pixel_scales(viewport.wcs)[viewport.largest_dimension] + fov = viewport.fov.value / pixel_scale else: - raise ValueError("WCS is not set. Cannot convert to pixel coordinates.") - else: - # This is a pixel offset - new_center = (self._center[0] + dx.value, self._center[1] + dy.value) - self.center_on(new_center) + fov = viewport.fov - def zoom(self, val) -> None: - """ - Zoom in or out by the given factor. + return dict( + center=center, + fov=fov, + image_label=image_label + ) - Parameters - ---------- - val : int - The zoom level to zoom the image. - See `zoom_level`. - """ - self.zoom_level *= val + + get_viewport.__doc__ = ImageViewerInterface.get_viewport.__doc__ diff --git a/src/astro_image_display_api/interface_definition.py b/src/astro_image_display_api/interface_definition.py index e4b2461..5077423 100644 --- a/src/astro_image_display_api/interface_definition.py +++ b/src/astro_image_display_api/interface_definition.py @@ -26,7 +26,6 @@ class ImageViewerInterface(Protocol): # do any checking at all of these types. image_width: int image_height: int - zoom_level: float cursor: str # Allowed locations for cursor display @@ -36,7 +35,7 @@ class ImageViewerInterface(Protocol): # Method for loading image data @abstractmethod - def load_image(self, data: Any) -> None: + def load_image(self, data: Any, image_label: str | None = None) -> None: """ Load data into the viewer. At a minimum, this should allow a FITS file to be loaded. Viewers may allow additional data types to be loaded, such as @@ -47,12 +46,19 @@ def load_image(self, data: Any) -> None: data : Any The data to load. This can be a FITS file, a 2D array, or an `astropy.nddata.NDData` object. + + image_label : str, optional + The label for the image. + + Notes + ----- + Loading an image should also set the viewport for that image. """ raise NotImplementedError # Setting and getting image properties @abstractmethod - def set_cuts(self, cuts: tuple | BaseInterval) -> None: + def set_cuts(self, cuts: tuple | BaseInterval, image_label: str | None = None) -> None: """ Set the cuts for the image. @@ -62,14 +68,25 @@ def set_cuts(self, cuts: tuple | BaseInterval) -> None: The cuts to set. If a tuple, it should be of the form ``(min, max)`` and will be interpreted as a `~astropy.visualization.ManualInterval`. + + image_label : str, optional + The label of the image to set the cuts for. If not given and there is + only one image loaded, the cuts for that image are set. """ raise NotImplementedError @abstractmethod - def get_cuts(self) -> BaseInterval: + def get_cuts(self, image_label: str | None = None) -> BaseInterval: """ Get the current cuts for the image. + Parameters + ---------- + image_label : str, optional + The label of the image to get the cuts for. If not given and there is + only one image loaded, the cuts for that image are returned. If there are + multiple images and no label is provided, an error is raised. + Returns ------- cuts : `~astropy.visualization.BaseInterval` @@ -78,7 +95,7 @@ def get_cuts(self) -> BaseInterval: raise NotImplementedError @abstractmethod - def set_stretch(self, stretch: BaseStretch) -> None: + def set_stretch(self, stretch: BaseStretch, image_label: str | None = None) -> None: """ Set the stretch for the image. @@ -87,14 +104,25 @@ def set_stretch(self, stretch: BaseStretch) -> None: stretch : Any stretch from `~astropy.visualization` The stretch to set. This can be any subclass of `~astropy.visualization.BaseStretch`. + + image_label : str, optional + The label of the image to set the cuts for. If not given and there is + only one image loaded, the cuts for that image are set. """ raise NotImplementedError @abstractmethod - def get_stretch(self) -> BaseStretch: + def get_stretch(self, image_label: str | None = None) -> BaseStretch: """ Get the current stretch for the image. + Parameters + ---------- + image_label : str, optional + The label of the image to get the cuts for. If not given and there is + only one image loaded, the cuts for that image are returned. If there are + multiple images and no label is provided, an error is raised. + Returns ------- stretch : `~astropy.visualization.BaseStretch` @@ -264,42 +292,73 @@ def get_catalog_names(self) -> list[str]: # Methods that modify the view @abstractmethod - def center_on(self, point: tuple | SkyCoord): + def set_viewport( + self, center: SkyCoord | tuple[float, float] | None = None, + fov: Quantity | float | None = None, + image_label: str | None = None + ) -> None: """ - Center the view on the point. + Set the viewport of the image, which defines the center and field of view. Parameters ---------- - tuple or `~astropy.coordinates.SkyCoord` - If tuple of ``(X, Y)`` is given, it is assumed - to be in data coordinates. - """ - raise NotImplementedError + center : `astropy.coordinates.SkyCoord` or tuple of float, optional + The center of the viewport. If not given, the current center is used. + fov : `astropy.units.Quantity` or float, optional + The field of view (FOV) of the viewport. If not given, the current FOV + is used. If a float is given, it is interpreted a size in pixels. For images that are + not square, the FOV is interpreted as the size of the longer side of the image. + image_label : str, optional + The label of the image to set the viewport for. If not given and there is + only one image loaded, the viewport for that image is set. If there are + multiple images and no label is provided, an error is raised. - @abstractmethod - def offset_by(self, dx: float | Quantity, dy: float | Quantity) -> None: - """ - Move the center to a point that is given offset - away from the current center. + Raises + ------ + TypeError + If the `center` is not a `SkyCoord` object or a tuple of floats, or if + the `fov` is not a angular `Quantity` or a float, or if there is no WCS + and the center or field of view require a WCS to be applied. - Parameters - ---------- - dx, dy : float or `~astropy.units.Quantity` - Offset value. Without a unit, assumed to be pixel offsets. - If a unit is attached, offset by pixel or sky is assumed from - the unit. + ValueError + If `image_label` is not provided when there are multiple images loaded. + + `astropy.units.UnitTypeError` + If the `fov` is a `Quantity` but does not have an angular unit. """ raise NotImplementedError @abstractmethod - def zoom(self, val: float) -> None: + def get_viewport(self, sky_or_pixel: str | None = None, image_label: str | None = None) -> dict[str, Any]: """ - Zoom in or out by the given factor. + Get the current viewport of the image. Parameters ---------- - val : float - The zoom level to zoom the image. - See `zoom_level`. + sky_or_pixel : str, optional + If 'sky', the center will be returned as a `SkyCoord` object. + If 'pixel', the center will be returned as a tuple of pixel coordinates. + If `None`, the default behavior is to return the center as a `SkyCoord` if + possible, or as a tuple of floats if the image is in pixel coordinates and has + no WCS information. + image_label : str, optional + The label of the image to get the viewport for. If not given and there is only one + image loaded, the viewport for that image is returned. If there are multiple images + and no label is provided, an error is raised. + + Returns + ------- + dict + A dictionary containing the current viewport settings. + The keys are 'center', 'fov', and 'image_label'. + - 'center' is an `astropy.coordinates.SkyCoord` object or a tuple of floats. + - 'fov' is an `astropy.units.Quantity` object or a float. + - 'image_label' is a string representing the label of the image. + + Raises + ------- + ValueError + If the `sky_or_pixel` parameter is not one of 'sky', 'pixel', or `None`, or if + the `image_label` is not provided when there are multiple images loaded. """ raise NotImplementedError diff --git a/src/astro_image_display_api/widget_api_test.py b/src/astro_image_display_api/widget_api_test.py index 9387b0c..9f2685a 100644 --- a/src/astro_image_display_api/widget_api_test.py +++ b/src/astro_image_display_api/widget_api_test.py @@ -1,14 +1,16 @@ +import numbers import pytest import numpy as np +from astropy.coordinates import SkyCoord from astropy.io import fits -from astropy.nddata import NDData +from astropy.nddata import CCDData, NDData from astropy.table import Table, vstack from astropy import units as u from astropy.wcs import WCS -from astropy.visualization import AsymmetricPercentileInterval, LogStretch, ManualInterval +from astropy.visualization import AsymmetricPercentileInterval, BaseInterval, BaseStretch, LogStretch, ManualInterval __all__ = ['ImageWidgetAPITest'] @@ -19,7 +21,7 @@ class ImageWidgetAPITest: @pytest.fixture def data(self): rng = np.random.default_rng(1234) - return rng.random((100, 100)) + return rng.random((100, 150)) @pytest.fixture def wcs(self): @@ -30,6 +32,7 @@ def wcs(self): w = WCS(naxis=2) # Set up an "Airy's zenithal" projection + # Note: WCS is 1-based, not 0-based w.wcs.crpix = [-234.75, 8.3393] w.wcs.cdelt = np.array([-0.066667, 0.066667]) w.wcs.crval = [0, -90] @@ -103,38 +106,253 @@ def test_load(self, data, tmp_path, load_type): self.image.load_image(load_arg) - def test_center_on(self): - self.image.center_on((10, 10)) # X, Y + def test_set_get_center_xy(self, data): + self.image.load_image(data, image_label='test') + self.image.set_viewport(center=(10, 10), image_label='test') # X, Y + vport = self.image.get_viewport(image_label='test') + assert vport['center'] == (10, 10) + assert vport['image_label'] == 'test' - def test_offset_by(self, data, wcs): - self.image.offset_by(10, 10) # dX, dY + def test_set_get_center_world(self, data, wcs): + self.image.load_image(NDData(data=data, wcs=wcs), image_label='test') + self.image.set_viewport(center=SkyCoord(*wcs.wcs.crval, unit='deg'), image_label='test') - # Testing offset by WCS requires a WCS. The viewer will (or ought to - # have) taken care of setting up the WCS internally if initialized with - # an NDData that has a WCS. - ndd = NDData(data=data, wcs=wcs) - self.image.load_image(ndd) + vport = self.image.get_viewport(image_label='test') + assert isinstance(vport['center'], SkyCoord) + assert vport['center'].ra.deg == pytest.approx(wcs.wcs.crval[0]) + assert vport['center'].dec.deg == pytest.approx(wcs.wcs.crval[1]) - self.image.offset_by(10 * u.arcmin, 10 * u.arcmin) - - # A mix of pixel and sky should produce an error - with pytest.raises(u.UnitConversionError, match='are not convertible'): - self.image.offset_by(10 * u.arcmin, 10) + def test_set_get_fov_pixel(self, data): + # Set data first, since that is needed to determine zoom level + self.image.load_image(data, image_label='test') - # A mix of inconsistent units should produce an error - with pytest.raises(u.UnitConversionError, match='are not convertible'): - self.image.offset_by(1 * u.arcsec, 1 * u.AA) + self.image.set_viewport(fov=100, image_label='test') + vport = self.image.get_viewport(image_label='test') + assert vport['fov'] == 100 + assert vport['image_label'] == 'test' - def test_zoom_level(self, data): + def test_set_get_fov_world(self, data, wcs): # Set data first, since that is needed to determine zoom level + self.image.load_image(NDData(data=data, wcs=wcs), image_label='test') + + # Set the FOV in world coordinates + self.image.set_viewport(fov=0.1 * u.deg, image_label='test') + vport = self.image.get_viewport(image_label='test') + assert isinstance(vport['fov'], u.Quantity) + assert len(np.atleast_1d(vport['fov'])) == 1 + assert vport['fov'].unit.physical_type == 'angle' + fov_degree = vport['fov'].to(u.degree).value + assert fov_degree == pytest.approx(0.1) + + def test_set_get_viewport_errors(self, data, wcs): + # Test several of the expected errors that can be raised + self.image.load_image(NDData(data=data, wcs=wcs), image_label='test') + + # fov can be float or an angular Qunatity + with pytest.raises(u.UnitTypeError, match='[Ii]ncorrect unit for fov'): + self.image.set_viewport(fov=100 * u.meter, image_label='test') + + # try an fov that is completely the wrong type + with pytest.raises(TypeError, match='[Ii]nvalid value for fov'): + self.image.set_viewport(fov='not a valid value', image_label='test') + + # center can be a SkyCoord or a tuple of floats. Try a value that is neither + with pytest.raises(TypeError, match='[Ii]nvalid value for center'): + self.image.set_viewport(center='not a valid value', image_label='test') + + # Check that an error is raised if a label is provided that does not + # match an image that is loaded. + with pytest.raises(ValueError, match='[Ii]mage label.*not found'): + self.image.set_viewport(center=(10, 10), fov=100, image_label='not a valid label') + + # Getting a viewport for an image_label that does not exist should raise an error + with pytest.raises(ValueError, match='[Ii]mage label.*not found'): + self.image.get_viewport(image_label='not a valid label') + + # If there are multiple images loaded, the image_label must be provided + self.image.load_image(data, image_label='another test') + + with pytest.raises(ValueError, match='Multiple image labels defined'): + self.image.get_viewport() + + # setting sky_or_pixel to something other than 'sky' or 'pixel' or None + # should raise an error + with pytest.raises(ValueError, match='[Ss]ky_or_pixel must be'): + self.image.get_viewport(sky_or_pixel='not a valid value') + + def test_set_get_viewport_errors_because_no_wcs(self, data): + # Check that errors are raised when they should be when calling + # get_viewport when no WCS is present. + + # Load the data without a WCS + self.image.load_image(data, image_label='test') + + # Set the viewport with a SkyCoord center + with pytest.raises(TypeError, match='Center must be a tuple'): + self.image.set_viewport(center=SkyCoord(ra=10, dec=20, unit='deg'), image_label='test') + + # Set the viewport with a Quantity fov + with pytest.raises(TypeError, match='FOV must be a float'): + self.image.set_viewport(fov=100 * u.arcmin, image_label='test') + + # Try getting the viewport as sky + with pytest.raises(ValueError, match='WCS is not set'): + self.image.get_viewport(image_label='test', sky_or_pixel='sky') + + @pytest.mark.parametrize("world", [True, False]) + def test_viewport_is_defined_after_loading_image(self, tmp_path, data, wcs, world): + # Check that the viewport is set to a default value when an image + # is loaded, even if no viewport is explicitly set. + + # Load the image from FITS to ensure that at least one image with WCS + # has been loaded from FITS. + wcs = wcs if world else None + ccd = CCDData(data=data, unit="adu", wcs=wcs) + + ccd_path = tmp_path / 'test.fits' + ccd.write(ccd_path) + self.image.load_image(ccd_path) + + # Getting the viewport should not fail... + vport = self.image.get_viewport() + + assert 'center' in vport + + assert 'fov' in vport + assert 'image_label' in vport + assert vport['image_label'] is None + if world: + assert isinstance(vport['center'], SkyCoord) + # fov should be a Quantity since WCS is present + assert isinstance(vport['fov'], u.Quantity) + else: + # No world, so center should be a tuple + assert isinstance(vport['center'], tuple) + # fov should be a float since no WCS + assert isinstance(vport['fov'], numbers.Real) + + def test_set_get_viewport_no_image_label(self, data): + # If there is only one image, the viewport should be able to be set + # and retrieved without an image label. + + # Add an image without an image label self.image.load_image(data) - self.image.zoom_level = 5 - assert self.image.zoom_level == 5 - def test_zoom(self): - self.image.zoom_level = 3 - self.image.zoom(2) - assert self.image.zoom_level == 6 # 3 x 2 + # Set the viewport without an image label + self.image.set_viewport(center=(10, 10), fov=100) + + # Getting the viewport again should return the same values + vport = self.image.get_viewport() + assert vport['center'] == (10, 10) + assert vport['fov'] == 100 + assert vport['image_label'] is None + + def test_set_get_viewport_single_label(self, data): + # If there is only one image, the viewport should be able to be set + # and retrieved without an image label as long as the image + # has an image label. + + # Add an image with an image label + self.image.load_image(data, image_label='test') + + # Getting the viewport should not fail... + vport = self.image.get_viewport() + assert 'center' in vport + assert 'fov' in vport + assert 'image_label' in vport + assert vport['image_label'] == 'test' + + # Set the viewport with an image label + self.image.set_viewport(center=(10, 10), fov=100) + + # Getting the viewport again should return the same values + vport = self.image.get_viewport() + assert vport['center'] == (10, 10) + assert vport['fov'] == 100 + assert vport['image_label'] == 'test' + + def test_get_viewport_sky_or_pixel(self, data, wcs): + # Check that the viewport can be retrieved in both pixel and world + # coordinates, depending on the WCS of the image. + + # Load the data with a WCS + self.image.load_image(NDData(data=data, wcs=wcs), image_label='test') + + input_center = SkyCoord(*wcs.wcs.crval, unit='deg') + input_fov = 2 * u.arcmin + self.image.set_viewport(center=input_center, fov=input_fov, image_label='test') + + # Get the viewport in pixel coordinates + vport_pixel = self.image.get_viewport(image_label='test', sky_or_pixel='pixel') + # The WCS set up for the tests is 1-based, rather than the usual 0-based, + # so we need to subtract 1 from the pixel coordinates. + assert all(vport_pixel['center'] == (wcs.wcs.crpix - 1)) + # tbh, not at all sure what the fov should be in pixel coordinates, + # so just check that it is a float. + assert isinstance(vport_pixel['fov'], numbers.Real) + + # Get the viewport in world coordinates + vport_world = self.image.get_viewport(image_label='test', sky_or_pixel='sky') + assert vport_world['center'] == input_center + assert vport_world['fov'] == input_fov + + @pytest.mark.parametrize("sky_or_pixel", ['sky', 'pixel']) + def test_get_viewport_no_sky_or_pixel(self, data, wcs, sky_or_pixel): + # Check that get_viewport returns the correct "default" sky_or_pixel + # value when the result ought to be unambiguous. + if sky_or_pixel == 'sky': + use_wcs = wcs + else: + use_wcs = None + + self.image.load_image(NDData(data=data, wcs=use_wcs), image_label='test') + + vport = self.image.get_viewport(image_label='test') + match sky_or_pixel: + case 'sky': + assert isinstance(vport['center'], SkyCoord) + assert vport['fov'].unit.physical_type == "angle" + case 'pixel': + assert isinstance(vport['center'], tuple) + assert isinstance(vport['fov'], numbers.Real) + + def test_get_viewport_with_wcs_set_pixel_or_world(self, data, wcs): + # Check that the viewport can be retrieved in both pixel and world + # after setting with the opposite if the WCS is set. + # Load the data with a WCS + self.image.load_image(NDData(data=data, wcs=wcs), image_label='test') + + # Set the viewport in world coordinates + input_center = SkyCoord(*wcs.wcs.crval, unit='deg') + input_fov = 2 * u.arcmin + self.image.set_viewport(center=input_center, fov=input_fov, image_label='test') + + # Get the viewport in pixel coordinates + vport_pixel = self.image.get_viewport(image_label='test', sky_or_pixel='pixel') + assert all(vport_pixel['center'] == (wcs.wcs.crpix - 1)) + assert isinstance(vport_pixel['fov'], numbers.Real) + + # Set the viewport in pixel coordinates + input_center_pixel = (wcs.wcs.crpix[0], wcs.wcs.crpix[1]) + input_fov_pixel = 100 # in pixels + self.image.set_viewport(center=input_center_pixel, fov=input_fov_pixel, image_label='test') + + # Get the viewport in world coordinates + vport_world = self.image.get_viewport(image_label='test', sky_or_pixel='sky') + assert vport_world['center'] == wcs.pixel_to_world(*input_center_pixel) + assert isinstance(vport_world['fov'], u.Quantity) + + def test_viewport_round_trips(self, data, wcs): + # Check that the viewport retrieved with get can be used to set + # the viewport again, and that the values are the same. + self.image.load_image(NDData(data=data, wcs=wcs), image_label='test') + self.image.set_viewport(center=(10, 10), fov=100, image_label='test') + vport = self.image.get_viewport(image_label='test') + # Set the viewport again using the values from the get_viewport + self.image.set_viewport(**vport) + # Get the viewport again and check that the values are the same + vport2 = self.image.get_viewport(image_label='test') + assert vport2 == vport def test_set_catalog_style_before_catalog_data_raises_error(self): # Make sure that adding a catalog style before adding any catalog @@ -453,6 +671,51 @@ def test_cuts(self, data): assert isinstance(self.image.get_cuts(), ManualInterval) assert self.image.get_cuts().get_limits(data) == (10, 100) + def test_stretch_cuts_labels(self, data): + # Check that stretch and cuts can be set with labels + self.image.load_image(data, image_label='test') + + # Set stretch and cuts with labels + self.image.set_stretch(LogStretch(), image_label='test') + self.image.set_cuts((10, 100), image_label='test') + + # Get stretch and cuts with labels + stretch = self.image.get_stretch(image_label='test') + cuts = self.image.get_cuts(image_label='test') + + assert isinstance(stretch, LogStretch) + assert isinstance(cuts, ManualInterval) + assert cuts.get_limits(data) == (10, 100) + + def test_stretch_cuts_are_set_after_loading_image(self, data): + # Check that stretch and cuts are set to default values after loading an image + self.image.load_image(data, image_label='test') + + stretch = self.image.get_stretch(image_label='test') + cuts = self.image.get_cuts(image_label='test') + + # Backends can set whatever stretch and cuts they want, so + # we just check that they are instances of the expected classes. + assert isinstance(stretch, BaseStretch) + assert isinstance(cuts, BaseInterval) + + def test_stretch_cuts_errors(self, data): + # Check that errors are raised when trying to get or set stretch or cuts + # for an image label that does not exist. + self.image.load_image(data, image_label='test') + + with pytest.raises(ValueError, match='[Ii]mage label.*not found'): + self.image.get_stretch(image_label='not a valid label') + + with pytest.raises(ValueError, match='[Ii]mage label.*not found'): + self.image.get_cuts(image_label='not a valid label') + + with pytest.raises(ValueError, match='[Ii]mage label.*not found'): + self.image.set_stretch(LogStretch(), image_label='not a valid label') + + with pytest.raises(ValueError, match='[Ii]mage label.*not found'): + self.image.set_cuts((10, 100), image_label='not a valid label') + @pytest.mark.skip(reason="Not clear whether colormap is part of the API") def test_colormap(self): cmap_desired = 'gray'