Source code for reproject.mosaicking.wcs_helpers

# Licensed under a 3-clause BSD style license - see LICENSE.rst

import numpy as np
from astropy import units as u
from astropy.coordinates import SkyCoord, frame_transform_graph
from astropy.io.fits import Header
from astropy.utils import isiterable
from astropy.wcs import WCS
from astropy.wcs.utils import (
    celestial_frame_to_wcs,
    pixel_to_skycoord,
    proj_plane_pixel_scales,
    skycoord_to_pixel,
    wcs_to_celestial_frame,
)
from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS

from ..utils import parse_input_shape

__all__ = ["find_optimal_celestial_wcs"]


def find_optimal_celestial_wcs(
    input_data,
    hdu_in=None,
    frame=None,
    auto_rotate=False,
    projection="TAN",
    resolution=None,
    reference=None,
):
    """
    Given one or more images, return an optimal WCS projection object and
    shape.

    This currently only works with 2-d images with celestial WCS.

    Parameters
    ----------
    input_data : iterable
        One or more input data specifications to include in the calculation of
        the final WCS. This should be an iterable containing one entry for each
        specification, where a single data specification is one of:

            * The name of a FITS file as a `str` or a `pathlib.Path` object
            * An `~astropy.io.fits.HDUList` object
            * An image HDU object such as a `~astropy.io.fits.PrimaryHDU`,
              `~astropy.io.fits.ImageHDU`, or `~astropy.io.fits.CompImageHDU`
              instance
            * A tuple where the first element is a NumPy array shape tuple and
              the second element is either a
              `~astropy.wcs.wcsapi.BaseLowLevelWCS`,
              `~astropy.wcs.wcsapi.BaseHighLevelWCS`, or a
              `~astropy.io.fits.Header` object
            * A tuple where the first element is a `~numpy.ndarray` and the
              second element is either a
              `~astropy.wcs.wcsapi.BaseLowLevelWCS`,
              `~astropy.wcs.wcsapi.BaseHighLevelWCS`, or a
              `~astropy.io.fits.Header` object
            * An `~astropy.nddata.NDData` object from which the ``.data`` and
              ``.wcs`` attributes will be used as the input data.
            * A `~astropy.wcs.wcsapi.BaseLowLevelWCS` object with
              ``array_shape`` set or a
              `~astropy.wcs.wcsapi.BaseHighLevelWCS` object whose underlying
              low level WCS object has ``array_shape`` set.

        If only one input data needs to be provided, it is also possible to
        pass it in without including it in an iterable.
    hdu_in : int or str, optional
        If ``input_data`` is a FITS file or an `~astropy.io.fits.HDUList`
        instance, specifies the HDU to use.
    frame : str or `~astropy.coordinates.BaseCoordinateFrame`, optional
        The coordinate system for the final image (defaults to the frame of
        the first image specified).
    auto_rotate : bool, optional
        Whether to rotate the header to minimize the final image area (if
        `True`, requires shapely>=1.6 to be installed).
    projection : str, optional
        Three-letter code for the WCS projection.
    resolution : `~astropy.units.Quantity`, optional
        The resolution of the final image. If not specified, this is the
        smallest resolution of the input images.
    reference : `~astropy.coordinates.SkyCoord`, optional
        The reference coordinate for the final header. If not specified, this
        is determined automatically from the input images.

    Returns
    -------
    wcs : :class:`~astropy.wcs.WCS`
        The optimal WCS determined from the input images.
    shape : tuple
        The optimal shape required to cover all the output.

    Raises
    ------
    ValueError
        If ``input_data`` is an empty list or tuple, if ``frame`` is a string
        that does not name a known coordinate frame, if an input shape is not
        2-dimensional, or if an input WCS is not 2-dimensional.
    TypeError
        If an input WCS does not have celestial components.
    """

    # TODO: support higher-dimensional datasets in future

    # TODO: take into account NaN values when determining the extent of the
    # final WCS

    # Fail fast on an empty collection: with nothing to loop over, the code
    # below would otherwise only fail later, while stacking an empty list of
    # coordinates, with an error message unrelated to the actual problem.
    if isinstance(input_data, (list, tuple)) and not input_data:
        raise ValueError("input_data is empty: at least one input image is required")

    if isinstance(frame, str):
        # lookup_name returns None for unrecognized names; check for this
        # explicitly instead of failing with "'NoneType' object is not
        # callable".
        frame_cls = frame_transform_graph.lookup_name(frame)
        if frame_cls is None:
            raise ValueError(f"Unknown coordinate frame name: {frame!r}")
        frame = frame_cls()

    # Determine whether an iterable of input values was given or a single
    # input data.
    # NOTE(review): an HDUList is itself iterable, so it presumably takes the
    # 'iterable' branch below and is treated as a collection of HDUs rather
    # than as a single input - confirm this is the intended behavior.

    if isinstance(input_data, str):
        # Handle this explicitly as isiterable(str) is True
        iterable = False
    elif isiterable(input_data):
        if len(input_data) == 2 and isinstance(
            input_data[1], BaseLowLevelWCS | BaseHighLevelWCS | Header
        ):
            # Since 2-element tuples are valid single inputs we need to check for this
            iterable = False
        else:
            iterable = True
    else:
        iterable = False

    if iterable:
        input_shapes = [parse_input_shape(shape, hdu_in=hdu_in) for shape in input_data]
    else:
        input_shapes = [parse_input_shape(input_data, hdu_in=hdu_in)]

    # We start off by looping over images, checking that they are indeed
    # celestial images, and building up a list of all corners and all reference
    # coordinates in the frame of reference of the first image.

    corners = []
    references = []
    resolutions = []

    for shape, wcs in input_shapes:
        if len(shape) != 2:
            raise ValueError(f"Input data is not 2-dimensional (got shape {shape!r})")

        if wcs.pixel_n_dim != 2 or wcs.world_n_dim != 2:
            raise ValueError("Input WCS is not 2-dimensional")

        if isinstance(wcs, WCS):
            if not wcs.has_celestial:
                raise TypeError("WCS does not have celestial components")

            # Determine frame if it wasn't specified
            if frame is None:
                frame = wcs_to_celestial_frame(wcs)

        else:
            # Convert a single position to determine type of output and make
            # sure there is only a single SkyCoord returned.
            coord = wcs.pixel_to_world(0, 0)

            if not isinstance(coord, SkyCoord):
                raise TypeError("WCS does not have celestial components")

            if frame is None:
                frame = coord.frame.replicate_without_data()

        # Find pixel coordinates of corners. In future if we are worried about
        # significant distortions of the edges in the reprojection process we
        # could simply add arbitrary numbers of midpoints to this list.
        ny, nx = shape
        xc = np.array([-0.5, nx - 0.5, nx - 0.5, -0.5])
        yc = np.array([-0.5, -0.5, ny - 0.5, ny - 0.5])

        # We have to do .frame here to make sure that we get a frame object
        # without any 'hidden' attributes, otherwise the stacking below won't
        # work.
        corners.append(wcs.pixel_to_world(xc, yc).transform_to(frame).frame)

        if isinstance(wcs, WCS):
            # We now figure out the reference coordinate for the image in the
            # frame of the first image. The easiest way to do this is actually
            # to use pixel_to_skycoord with the reference position in pixel
            # coordinates. We have to set origin=1 because crpix values are
            # 1-based.
            xp, yp = wcs.wcs.crpix
            references.append(pixel_to_skycoord(xp, yp, wcs, origin=1).transform_to(frame).frame)

            # Find the pixel scale at the reference position - we take the minimum
            # since we are going to set up a header with 'square' pixels with the
            # smallest resolution specified.
            scales = proj_plane_pixel_scales(wcs)
            resolutions.append(np.min(np.abs(scales)))

        else:
            # There is no crpix for a generic WCS, so use the central pixel of
            # the image as the reference position.
            xp, yp = (nx - 1) / 2, (ny - 1) / 2
            references.append(wcs.pixel_to_world(xp, yp).transform_to(frame).frame)

            # Estimate the pixel scale along each axis as the on-sky
            # separation between the central pixel and its neighbour one
            # pixel further along y (cs[1]) and along x (cs[2]).
            xs = np.array([xp, xp, xp + 1])
            ys = np.array([yp, yp + 1, yp])
            cs = wcs.pixel_to_world(xs, ys)
            dx = abs(cs[0].separation(cs[2]).deg)
            dy = abs(cs[0].separation(cs[1]).deg)
            resolutions.append(min(dx, dy))

    # We now stack the coordinates - however the frame classes can't do this
    # so we have to use the high-level SkyCoord class.
    corners = SkyCoord(corners)
    references = SkyCoord(references)

    # If no reference coordinate has been passed in for the final header, we
    # determine the reference coordinate as the mean of all the reference
    # positions. This choice is as good as any and if the user really cares,
    # they can set it manually.
    if reference is None:
        reference = SkyCoord(references.data.mean(), frame=references.frame)

    # In any case, we need to convert the reference coordinate (either
    # specified or automatically determined) to the requested final frame.
    reference = reference.transform_to(frame)

    # Determine resolution if not specified
    if resolution is None:
        resolution = np.min(resolutions) * u.deg

    # Construct WCS object centered on position
    wcs_final = celestial_frame_to_wcs(frame, projection=projection)

    # Fall back to degrees for any axis whose unit was left unset
    if wcs_final.wcs.cunit[0] == "":
        wcs_final.wcs.cunit[0] = "deg"

    if wcs_final.wcs.cunit[1] == "":
        wcs_final.wcs.cunit[1] = "deg"

    rep = reference.represent_as("unitspherical")
    wcs_final.wcs.crval = (
        rep.lon.to_value(wcs_final.wcs.cunit[0]),
        rep.lat.to_value(wcs_final.wcs.cunit[1]),
    )
    # The first axis gets a negative increment, the second a positive one
    wcs_final.wcs.cdelt = (
        -resolution.to_value(wcs_final.wcs.cunit[0]),
        resolution.to_value(wcs_final.wcs.cunit[1]),
    )

    # For now, set crpix to (1, 1) and we'll then figure out where all the
    # images fall in this projection, then we'll adjust crpix.
    wcs_final.wcs.crpix = (1, 1)

    # Find pixel coordinates of all corners in the final WCS projection. We use
    # origin=1 since we are trying to determine crpix values.
    xp, yp = skycoord_to_pixel(corners, wcs_final, origin=1)

    if auto_rotate:
        # Use shapely to represent the points and find the minimum rotated
        # rectangle
        from shapely.geometry import MultiPoint

        mp = MultiPoint(list(zip(xp, yp, strict=True)))

        # The following returns a list of rectangle vertices - in fact there
        # are 5 coordinates because shapely represents it as a closed polygon
        # with the same first/last vertex.
        xr, yr = mp.minimum_rotated_rectangle.exterior.coords.xy
        xr, yr = xr[:4], yr[:4]

        # The order of the vertices is not guaranteed to be constant so we
        # take the vertices with the two smallest y values (which, for a
        # rectangle, guarantees that the vertices are neighboring)
        order = np.argsort(yr)
        x1, y1, x2, y2 = xr[order[0]], yr[order[0]], xr[order[1]], yr[order[1]]

        # Determine angle between two of the vertices. It doesn't matter which
        # ones they are, we just want to know how far from being straight the
        # rectangle is.
        angle = np.arctan2(y2 - y1, x2 - x1)

        # Determine the smallest angle that would cause the rectangle to be
        # lined up with the axes.
        angle = angle % (np.pi / 2)
        if angle > np.pi / 4:
            angle -= np.pi / 2

        # Set rotation matrix (use PC instead of CROTA2 since PC is the
        # recommended approach)
        pc = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])
        wcs_final.wcs.pc = pc

        # Recompute pixel coordinates (more accurate than simply rotating xp, yp)
        xp, yp = skycoord_to_pixel(corners, wcs_final, origin=1)

    # Find the full range of values
    xmin = xp.min()
    xmax = xp.max()
    ymin = yp.min()
    ymax = yp.max()

    # Update crpix so that the lower range falls on the bottom and left. We add
    # 0.5 because in the final image the bottom left corner should be at (0.5,
    # 0.5) not (1, 1).
    wcs_final.wcs.crpix = (1 - xmin) + 0.5, (1 - ymin) + 0.5

    # Return the final image shape too
    naxis1 = int(round(xmax - xmin))
    naxis2 = int(round(ymax - ymin))

    return wcs_final, (naxis2, naxis1)