import logging
import multiprocessing as mp
import warnings
import numpy as np
from astropy.modeling.mappings import Mapping
from scipy import sparse
from jwst.lib.winclip import get_clipped_pixels
from jwst.wfss_contam.sens1d import create_1d_sens
log = logging.getLogger(__name__)
__all__ = ["disperse"]
def _determine_native_wl_spacing(
    x0_sky,
    y0_sky,
    sky_to_imgxy,
    imgxy_to_grismxy,
    order,
    wmin,
    wmax,
    oversample_factor=2,
):
    """
    Build a wavelength grid fine enough to sample the dispersed trace.

    The end points of the trace at ``wmin`` and ``wmax`` are located in the
    grism frame. The wavelength change per pixel derived from them, divided
    by ``oversample_factor``, sets the step of the returned grid.

    Parameters
    ----------
    x0_sky : float or ndarray
        RA of the input pixel position in direct image and segmentation map
    y0_sky : float or ndarray
        Dec of the input pixel position in direct image and segmentation map
    sky_to_imgxy : astropy model
        Transform from sky to image coordinates
    imgxy_to_grismxy : astropy model
        Transform from image to grism coordinates
    order : int
        Spectral order number
    wmin : float
        Minimum wavelength for dispersed spectra
    wmax : float
        Maximum wavelength for dispersed spectra
    oversample_factor : int, optional
        Factor by which to oversample the wavelength grid

    Returns
    -------
    lambdas : ndarray
        Evenly spaced wavelengths from ``wmin`` to ``wmax`` (both included),
        always containing at least three points.

    Notes
    -----
    It was found that the native wavelength spacing varies by a few percent or less
    across the detector for both NIRCam and NIRISS. Several positions may be passed
    at once, in which case the median spacing is used, but a single x0, y0 pair is
    typically sufficient.
    """
    # Position(s) in the direct-image frame. A constant 1 is passed in the
    # wavelength slot; only the x/y outputs of the transform are kept.
    x_img, y_img, _, _ = sky_to_imgxy(x0_sky, y0_sky, 1, order)

    # Trace end points in the grism frame at the two bounding wavelengths.
    x_at_wmin, y_at_wmin = imgxy_to_grismxy(x_img, y_img, wmin, order)
    x_at_wmax, y_at_wmax = imgxy_to_grismxy(x_img, y_img, wmax, order)
    trace_dx = x_at_wmax - x_at_wmin
    trace_dy = y_at_wmax - y_at_wmin

    # Wavelength change per pixel, then the oversampled step (median over
    # all supplied positions).
    # NOTE(review): the denominator is the signed difference dy - dx, not a
    # trace length; the two agree in magnitude only when the trace runs
    # along a single detector axis -- confirm this is intended.
    wl_range = wmax - wmin
    wl_per_pixel = np.abs(wl_range / (trace_dy - trace_dx))
    step = np.median(wl_per_pixel / oversample_factor)

    # Keep at least three samples: the sensitivity curve is often poorly
    # defined at its edges. This floor is typically reached only for order 0,
    # where the step can be large or poorly defined.
    n_samples = int(np.ceil(wl_range / step))
    if n_samples < 3:
        n_samples = 3
    return np.linspace(wmin, wmax, n_samples)
def _disperse_onto_grism(x0_sky, y0_sky, sky_to_imgxy, imgxy_to_grismxy, lambdas, order):
    """
    Locate every input pixel in the grism image at every requested wavelength.

    Parameters
    ----------
    x0_sky : ndarray
        RA of the input pixel position in direct image and segmentation map
    y0_sky : ndarray
        Dec of the input pixel position in direct image and segmentation map
    sky_to_imgxy : astropy model
        Transform from sky to image coordinates
    imgxy_to_grismxy : astropy model
        Transform from image to grism coordinates
    lambdas : ndarray
        Wavelengths at which to compute dispersed pixel values
    order : int
        Spectral order number

    Returns
    -------
    x0s : ndarray
        X coordinates of dispersed pixels in the grism image,
        shape (n_lam, n_pixels)
    y0s : ndarray
        Y coordinates of dispersed pixels in the grism image,
        shape (n_lam, n_pixels)
    lambdas : ndarray
        Wavelength of each dispersed pixel, shape (n_lam, n_pixels)
    """
    n_lam = len(lambdas)

    # The image-frame position does not depend on wavelength, so stack one
    # copy of the sky coordinates per wavelength: shape (n_lam, n_pixels).
    ra_grid = x0_sky[np.newaxis, :].repeat(n_lam, axis=0)
    dec_grid = y0_sky[np.newaxis, :].repeat(n_lam, axis=0)
    x_img, y_img, _, _ = sky_to_imgxy(ra_grid, dec_grid, lambdas, order)
    # Drop the stacked sky grids before allocating the next large arrays.
    del ra_grid, dec_grid

    # Expand the wavelengths to the same (n_lam, n_pixels) shape, then move
    # to the grism frame.
    n_pixels = x_img.shape[1]
    lambda_grid = lambdas[:, np.newaxis].repeat(n_pixels, axis=1)
    x_grism, y_grism = imgxy_to_grismxy(x_img, y_img, lambda_grid, order)
    return x_grism, y_grism, lambda_grid
def _collect_outputs_by_source(xs, ys, counts, source_ids_per_pixel, model_counts=None):
    """
    Split the flat list of dispersed pixels into one cutout image per source.

    Parameters
    ----------
    xs : ndarray
        X coordinates of dispersed pixels
    ys : ndarray
        Y coordinates of dispersed pixels
    counts : ndarray
        Count rates of dispersed pixels
    source_ids_per_pixel : int array
        Source IDs of the dispersed pixels
    model_counts : list of ndarray, optional
        List of count rate arrays corresponding to input ``basis_models``

    Returns
    -------
    outputs_by_source : dict
        Maps each source ID to a dict with keys ``"bounds"``
        (``[minx, maxx, miny, maxy]``) and ``"image"`` (2-D cutout).
        When ``model_counts`` is non-empty, a ``"model_counts"`` key holds
        one cutout per entry of ``model_counts``.
    """
    has_models = model_counts is not None and len(model_counts) > 0

    # Bring the pixels of each source together. The inputs cannot be assumed
    # to be ordered by source after get_clipped_pixels.
    grouping = np.argsort(source_ids_per_pixel)
    grouped_ids = source_ids_per_pixel[grouping]
    grouped_x = xs[grouping]
    grouped_y = ys[grouping]
    grouped_counts = counts[grouping]
    if has_models:
        grouped_models = [mc[grouping] for mc in model_counts]

    # Start/stop offsets of each source's contiguous run of pixels.
    source_ids, starts = np.unique(grouped_ids, return_index=True)
    stops = np.append(starts[1:], len(grouped_x))

    # Bounding box of every source, computed for all sources at once.
    x_lo = np.minimum.reduceat(grouped_x, starts)
    x_hi = np.maximum.reduceat(grouped_x, starts)
    y_lo = np.minimum.reduceat(grouped_y, starts)
    y_hi = np.maximum.reduceat(grouped_y, starts)

    # Build each cutout; the bounds are stored so the full dispersed image
    # can be reconstructed later.
    outputs_by_source = {}
    for k, sid in enumerate(source_ids):
        run = slice(starts[k], stops[k])
        run_x = grouped_x[run]
        run_y = grouped_y[run]
        bounds = [int(x_lo[k]), int(x_hi[k]), int(y_lo[k]), int(y_hi[k])]

        entry = {
            "bounds": bounds,
            "image": _build_dispersed_image_of_source(
                run_x, run_y, grouped_counts[run], bounds
            ),
        }
        if has_models:
            entry["model_counts"] = [
                _build_dispersed_image_of_source(run_x, run_y, gm[run], bounds)
                for gm in grouped_models
            ]
        outputs_by_source[sid] = entry

    return outputs_by_source
def _build_dispersed_image_of_source(x, y, flux, bounds):
    """
    Rasterize a flat list of pixels into a 2-D grism cutout of one source.

    Parameters
    ----------
    x : ndarray
        X coordinates of pixels in the grism image
    y : ndarray
        Y coordinates of pixels in the grism image
    flux : ndarray
        Fluxes of pixels in the grism image
    bounds : list
        Pre-computed [minx, maxx, miny, maxy] bounds for the source.

    Returns
    -------
    a : ndarray
        2-D dispersed image of the source, of shape
        ``(maxy - miny + 1, maxx - minx + 1)``.
    """
    x_lo, x_hi, y_lo, y_hi = bounds
    cutout_shape = (y_hi - y_lo + 1, x_hi - x_lo + 1)

    # Shift to cutout-local indices (row = y, column = x).
    rows = y - y_lo
    cols = x - x_lo

    # Entries that share a (row, col) are added together when the COO
    # matrix is converted to a dense array.
    cutout = sparse.coo_matrix((flux, (rows, cols)), shape=cutout_shape)
    return cutout.toarray()
[docs]
def disperse(
    xs,
    ys,
    fluxes,
    source_ids_per_pixel,
    order,
    wmin,
    wmax,
    sens_waves,
    sens_resp,
    direct_image_wcs,
    grism_wcs,
    naxis,
    oversample_factor=2,
    basis_models=None,
):
    """
    Compute the dispersed image pixel values from the direct image.

    Each direct-image pixel is projected onto the sky, dispersed onto the
    grism frame over a wavelength grid spanning ``wmin`` to ``wmax``,
    discretized onto the output pixel grid, converted from flux to count rate
    using the sensitivity curve, and accumulated into one image per source.

    Parameters
    ----------
    xs : ndarray
        Flat array of X coordinates of pixels in the direct image
    ys : ndarray
        Flat array of Y coordinates of pixels in the direct image
    fluxes : ndarray
        Fluxes of the pixels in the direct image corresponding to xs, ys,
        in units of MJy/sr. ``len(fluxes)`` is taken as the number of input
        pixels and the array is indexed once per input pixel, i.e. it is
        treated as a flat array aligned with ``xs`` / ``ys``.
        NOTE(review): a 2-D (n_bands, n_pixels) layout is not handled by this
        indexing -- confirm what callers pass.
    source_ids_per_pixel : int array
        Source IDs of the input pixels in the segmentation map
    order : int
        Spectral order number
    wmin : float
        Minimum wavelength for dispersed spectra
    wmax : float
        Maximum wavelength for dispersed spectra
    sens_waves : float array
        Wavelength array from photom reference file. Expected unit is micron.
    sens_resp : float array
        Response (flux calibration) array from photom reference file.
        Expected units are (micron) * (MJy / sr) / (ADU/s).
    direct_image_wcs : WCS object
        WCS object for the direct image and segmentation map
    grism_wcs : WCS object
        WCS object for the grism image
    naxis : tuple
        Dimensions of the grism image (naxis[0], naxis[1]). ``naxis[0]`` is
        compared against dispersed x coordinates and ``naxis[1]`` against
        dispersed y coordinates.
    oversample_factor : int, optional
        Factor by which to oversample the wavelength grid
    basis_models : list[Callable], optional
        Flux distributions to evaluate at each wavelength. Typically these will be
        single polynomial orders, e.g. ``[lambda x: x, lambda x: x**2, ...]``, the
        coefficients of which are linearly fit later.

    Returns
    -------
    outputs_by_source : dict or None
        Dictionary containing dispersed images and bounds for each source ID
        in the specified spectral order. ``None`` is returned when there are
        no input pixels, or when none of the dispersed pixels fall within the
        grism image frame.
    """
    # Nothing to disperse. Without this guard the first-pixel lookup used to
    # derive the wavelength grid below would raise IndexError.
    if len(xs) == 0:
        return None

    n_input_sources = np.unique(source_ids_per_pixel).size
    log.debug(
        f"{mp.current_process()} dispersing {n_input_sources} "
        f"sources in order {order} with total number of pixels: {len(xs)}"
    )

    # Work with pixel centers: shift by half of the (unit) pixel size.
    width = 1.0
    height = 1.0
    x0 = xs + 0.5 * width
    y0 = ys + 0.5 * height
    del xs, ys

    # Set up the transforms we need from the input WCS objects
    sky_to_imgxy = grism_wcs.get_transform("world", "detector")
    imgxy_to_grismxy = grism_wcs.get_transform("detector", "grism_detector")

    # We only need the x,y outputs of imgxy_to_grismxy.
    # Making the number of outputs dynamic handles legacy WCS objects that did not pass
    # the x0, y0, and order through the transform unmodified like the current version does.
    n_outputs = len(imgxy_to_grismxy.outputs)
    imgxy_to_grismxy = imgxy_to_grismxy | Mapping((0, 1), n_inputs=n_outputs)

    # Find RA/Dec of the input pixel position in direct image
    x0_sky, y0_sky = direct_image_wcs(x0, y0, with_bounding_box=False)
    del x0, y0

    # native spacing does not change much over the detector, so just put in one x0, y0
    lambdas = _determine_native_wl_spacing(
        x0_sky[0],
        y0_sky[0],
        sky_to_imgxy,
        imgxy_to_grismxy,
        order,
        wmin,
        wmax,
        oversample_factor=oversample_factor,
    )
    # The grid is evenly spaced, so the first interval is the step everywhere.
    dlam = lambdas[1] - lambdas[0]
    n_pix = len(fluxes)

    x0s, y0s, lambdas = _disperse_onto_grism(
        x0_sky,
        y0_sky,
        sky_to_imgxy,
        imgxy_to_grismxy,
        lambdas,
        order,
    )
    del x0_sky, y0_sky

    # If none of the dispersed pixel indexes are within the image frame,
    # return a null result without wasting time doing other computations
    if x0s.min() >= naxis[0] or x0s.max() < 0 or y0s.min() >= naxis[1] or y0s.max() < 0:
        return None

    # Discretize x and y coordinates to integer pixel values, keeping track of the
    # fractional area that each pixel contributes to the final grism image.
    # The resulting x, y coordinate pairs are non-unique: there are multiple wavelengths
    # that contribute to each pixel.
    padding = 1
    xs, ys, areas, index = get_clipped_pixels(x0s, y0s, padding, naxis[0], naxis[1], width, height)
    del x0s, y0s

    # ``index`` addresses the flattened (n_lam, n_pixels) grids. Only lambdas
    # varies along the wavelength axis; fluxes and source IDs are
    # wavelength-independent, so ``index % n_pix`` recovers the input pixel
    # column directly, without first tiling those arrays to full size.
    lambdas = np.take(lambdas, index)
    pixel_index = index % n_pix
    fluxes = fluxes[pixel_index]
    source_ids_per_pixel = source_ids_per_pixel[pixel_index]
    del pixel_index

    # Evaluate basis models on the 1-D lambda array.
    # Even after np.take this is element-wise, so it is still full resolution.
    model_f = [] if basis_models is None else [flam(lambdas) for flam in basis_models]

    # compute 1D sensitivity array corresponding to list of wavelengths
    sens, no_cal = create_1d_sens(lambdas, sens_waves, sens_resp)

    # Compute countrates for dispersed pixels.
    # The input direct image data is already photometrically calibrated,
    # so we need to basically apply a reverse flux calibration here.
    # Divide out the response values to convert from MJy/sr to DN/s.
    # Note that the photom reference files are constructed with per-wavelength units,
    # so oversampling is accounted for by the spacing of dlam.
    # Division warnings are silenced here; the affected entries are zeroed
    # just below via ``no_cal``.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", category=RuntimeWarning, message="divide by zero|invalid value"
        )
        counts = fluxes * areas * dlam / sens
        # Also convert basis models to counts.
        model_counts = [fluxes * f * areas * dlam / sens for f in model_f]

    # set to zero where no flux cal info available
    counts[no_cal] = 0.0
    for model_counts_i in model_counts:
        model_counts_i[no_cal] = 0.0
    del fluxes, areas, sens, dlam, no_cal, lambdas, index, model_f

    outputs_by_source = _collect_outputs_by_source(
        xs, ys, counts, source_ids_per_pixel, model_counts
    )
    del xs, ys, counts, source_ids_per_pixel

    n_out = len(outputs_by_source)
    log.debug(
        f"{mp.current_process()} finished order {order} with {n_out} "
        f"sources that overlap with the output frame "
        f"(out of {n_input_sources} input sources)"
    )
    return outputs_by_source