import functools
import logging
import multiprocessing
import warnings
from multiprocessing import cpu_count
import gwcs
import numpy as np
from astropy.modeling.models import Identity, Scale, Shift
from astropy.stats import sigma_clipped_stats as scs
from astropy.utils.exceptions import AstropyUserWarning
from scipy.signal import find_peaks
from stcal.multiprocessing import compute_num_cores
from stcal.resample.utils import is_flux_density
from stdatamodels.jwst import datamodels
from stdatamodels.jwst.datamodels import dqflags
from jwst.adaptive_trace_model.bspline import bspline_fit
from jwst.assign_wcs.nirspec import nrs_ifu_wcs
from jwst.lib.pipe_utils import match_nans_and_flags
# Public names exported by a star-import of this module; every other
# definition here is an internal helper.
__all__ = [
    "fit_2d_spline_trace",
    "linear_oversample",
    "fit_all_regions",
    "oversample_flux",
    "fit_and_oversample",
]

# Logger for this module, named after the module itself.
log = logging.getLogger(__name__)
def _get_weights_for_fit_scale(ratio, model_fit):
"""
Get weights for scaling the spline model to the fit data.
Sets weight to zero for outliers and invalid data points.
Parameters
----------
ratio : ndarray
Ratio between the fit data and the model evaluated at the data points.
model_fit : ndarray
The spline model evaluated at the data points.
Returns
-------
weights : ndarray
Array matching the ``ratio`` shape, containing weights for each
ratio data point.
"""
# Weights start off proportional to flux of the model
weights = model_fit.copy()
# Weights are zero for any pixel where the data was NaN
weights[~np.isfinite(ratio)] = 0
# Weights are zero for any pixel where the data or model was negative
weights[(ratio < 0) | (model_fit < 0)] = 0
# Identify the 5 largest weight points
order = np.argsort(weights)
largest_5 = (weights >= weights[order[-5]]) & (np.isfinite(ratio))
# Sigma-clipped mean and rms of these 5 ratios
mean, _, rms = scs(ratio[largest_5])
# Bad if over 2 sigma away
bad = np.abs(mean - ratio) > (2 * rms)
weights[bad] = 0
# Normalize weights
weights /= np.nansum(weights)
return weights
[docs]
def fit_2d_spline_trace(
    flux,
    alpha,
    fit_scale=None,
    lrange=50,
    col_index=None,
    space_ratio=1.2,
    sigma_low=2.5,
    sigma_high=2.5,
    fit_iter=3,
    require_ngood=None,
    spline_bkpt=None,
    auto_ngood_factor=0.5,
    auto_bkpt_factor=2.0,
):
    """
    Create a trace model from spline fits to a single slit/slice image.

    Image must be oriented so that wavelengths are along x-axis. Each
    column is fit separately, with a window to include nearby data.
    If the fit for a column fails, the most recent successful fit (if any)
    is reused for that column.

    Parameters
    ----------
    flux : ndarray
        Input 2D flux image to fit.
    alpha : ndarray
        Alpha coordinates for input flux.
    fit_scale : ndarray, optional
        Array of scale values to apply to the input flux before fitting.
        The flux is divided by these values.
    lrange : int, optional
        Local column range for data to include in the fit, to the
        left and right of each input column. For column ``i``, columns
        ``i - lrange`` through ``i + lrange - 1`` are used, clipped to the
        image.
    col_index : iterable or None, optional
        Iterable or generator that produces column index values to fit.
        If provided, columns will be fit in the order specified.
        If not provided, columns will be fit left to right.
    space_ratio : float, optional
        Maximum spacing ratio to allow fitting to continue. If
        the tenth-largest spacing in the local alpha values passed to the
        spline fit is larger than the knot spacing by this ratio, the fit
        is not attempted for that column.
    sigma_low : float, optional
        Low sigma threshold for iterative spline fit.
    sigma_high : float, optional
        High sigma threshold for iterative spline fit.
    fit_iter : int, optional
        Maximum number of iterations for spline fit.
    require_ngood : int or None, optional
        A column is fit only if it contains more than this number of
        pixels with finite flux and alpha values.
    spline_bkpt : int or None, optional
        Number of spline breakpoints (knots).
    auto_ngood_factor : float, optional
        If ``require_ngood`` is not provided, set it to this factor times
        the native spacing range in the input slit/slice.
    auto_bkpt_factor : float, optional
        If ``spline_bkpt`` is not provided, set it to this factor times
        the native spacing range in the input slit/slice.

    Returns
    -------
    splines : dict
        Keys are column index numbers, values are dicts. Each dict contains a
        normalized ``model`` with `~scipy.interpolate.BSpline` value, a ``scale``
        with a float value to scale to the data, and ``bounds`` with a 2-tuple of
        floating point values, corresponding to the lower and upper bounds of the
        alpha coordinates used in the fit. If a spline model could not be fit,
        or no valid pixel of the column lies within the model bounds, the
        column index number is not present.
    """
    # Set reasonable spline breakpoints and ngood if not provided
    if spline_bkpt is None or require_ngood is None:
        native_n_alpha = (np.nanmax(alpha) - np.nanmin(alpha)) / _native_dalpha(alpha)
        if spline_bkpt is None:
            spline_bkpt = int(auto_bkpt_factor * native_n_alpha)
            log.debug(f"Set spline_bkpt to {spline_bkpt}")
        if require_ngood is None:
            require_ngood = int(auto_ngood_factor * native_n_alpha)
            log.debug(f"Set require_ngood to {require_ngood}")

    # Fallback spline model and its bounds, taken from the last good fit
    spline_model_save = None
    spline_lobound_save = None
    spline_hibound_save = None

    # Set up the column fitting order if not provided
    xsize = flux.shape[-1]
    if col_index is None:
        col_index = range(xsize)

    # Scale the flux for fitting; zero/NaN scales give non-finite values,
    # which are filtered out before each fit
    if fit_scale is not None:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            scaled_flux = flux / fit_scale
    else:
        scaled_flux = flux

    # Options shared by every bspline_fit call
    bspline_kwargs = {
        "wrapsig_low": sigma_low,
        "wrapsig_high": sigma_high,
        "wrapiter": fit_iter,
        "space_ratio": space_ratio,
        "verbose": False,
    }

    splines = {}
    for i in col_index:
        col_flux = flux[:, i]
        col_alpha = alpha[:, i]
        ngood = np.sum(np.isfinite(col_flux) & np.isfinite(col_alpha))
        if ngood <= require_ngood:
            continue

        # Get local alpha and flux values for fitting
        lstart = max(i - lrange, 0)
        lstop = min(i + lrange, xsize)
        local_alpha = alpha[:, lstart:lstop]
        local_data = scaled_flux[:, lstart:lstop]

        # Trim to finite values, then sort by alpha
        finite_values = np.isfinite(local_alpha) & np.isfinite(local_data)
        local_alpha = local_alpha[finite_values]
        local_data = local_data[finite_values]
        idx = np.argsort(local_alpha)
        local_alpha = local_alpha[idx]
        local_data = local_data[idx]

        # Fit a bspline to the local data
        try:
            bspline = bspline_fit(local_alpha, local_data, nbkpts=spline_bkpt, **bspline_kwargs)

            # Debugging plot, retained for reference (matplotlib is not a
            # dependency of this package):
            #
            # if (i > 0) and (i % 100 == 0):
            #     from matplotlib import pyplot as plt
            #     temp_dalpha = np.abs(np.nanmedian(np.diff(alpha, axis=0)))
            #     temp_alpha = np.arange(
            #         np.nanmin(local_alpha), np.nanmax(local_alpha), temp_dalpha / 50
            #     )
            #     plt.plot(local_alpha, local_data, "x", label="Nearby Cols")
            #     plt.plot(alpha[:, i], scaled_flux[:, i], "s", zorder=50, label=f"Column {i}")
            #     plt.plot(temp_alpha, bspline(temp_alpha), label="Spline Fit")
            #     plt.legend()
            #     plt.ylim(np.nanmin(bspline(temp_alpha)) - 0.1,
            #              np.nanmax(bspline(temp_alpha)) + 0.1)
            #     plt.show()

            # If the fit failed (returned None) and no saved fit is available,
            # try again with slightly fewer breakpoints to resolve
            # occasional numerical issues.
            if bspline is None and spline_model_save is None and spline_bkpt > 3:
                bspline = bspline_fit(
                    local_alpha, local_data, nbkpts=spline_bkpt - 3, **bspline_kwargs
                )

            # If bspline is still None, use the saved fit if available
            if bspline is None and spline_model_save is not None:
                spline_model = spline_model_save
                spline_lobound = spline_lobound_save
                spline_hibound = spline_hibound_save
            else:
                spline_model = bspline  # may be valid fit or None
                # Raises ValueError (handled below) if no local data remain
                spline_lobound = np.nanmin(local_alpha)
                spline_hibound = np.nanmax(local_alpha)
        except (ValueError, RuntimeError) as err:
            log.warning(f"Spline fit failed at column {i}: {str(err)}")
            spline_model = spline_model_save  # may be valid fit or None
            spline_lobound = spline_lobound_save
            spline_hibound = spline_hibound_save

        if spline_model is None:
            continue

        # Remember this model and its bounds as the fallback
        spline_model_save = spline_model
        spline_lobound_save = spline_lobound
        spline_hibound_save = spline_hibound

        # Evaluate the model at the valid column pixels within its bounds
        in_bounds = (
            np.isfinite(col_alpha) & (col_alpha >= spline_lobound) & (col_alpha <= spline_hibound)
        )
        col_alpha = col_alpha[in_bounds]
        col_flux = col_flux[in_bounds]
        if col_alpha.size == 0:
            # A fallback model's bounds may not cover this column at all;
            # no scale can be determined, so leave the column out.
            continue
        col_fit = spline_model(col_alpha)

        # Normalization is the weighted mean data/model ratio; weights are
        # based on the model so that outliers can be rejected
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            ratio = col_flux / col_fit
        weights = _get_weights_for_fit_scale(ratio, col_fit)
        wmeanratio = np.nansum(ratio * weights)

        splines[i] = {
            "model": spline_model,
            "scale": wmeanratio,
            "bounds": (spline_lobound, spline_hibound),
        }
    return splines
def _reindex(xmin, xmax, scale=2.0):
"""
Convert pixel positions on the old grid to oversampled positions.
For example, with oversample scale = 2, [0, 1, 2] goes to
old_x = [-0.25, 0.25, 0.75, 1.25, 1.75, 2.25], for
new_x = [0, 1, 2, 3, 4, 5].
With oversample scale = 3, [0, 1, 2] goes to
old_x = [-0.33, 0, 0.33, 0.67, 1, 1.33, 1.67, 2, 2.33], for
new_x = [0, 1, 2, 3, 4, 5, 6, 7, 8].
Parameters
----------
xmin : int
Minimum index.
xmax : int
Maximum index.
scale : float, optional
Oversample scaling factor.
Returns
-------
new_x : ndarray of int
Array of indices in the new grid.
old_x : ndarray of float
Array of coordinates in the old grid, corresponding to the new indices.
"""
# Indices in the new array
new_x = np.arange(xmin * scale, (xmax + 1) * scale, dtype=np.int32)
# Indices in the old array, scaled for new pixel spacing
# Also offset to center new coordinates on the old
old_x = new_x / scale - (scale - 1) / (scale * 2)
return new_x, old_x
def _is_compact_source(
alpha_slice, alpha_ptsource, native_dalpha, spline_bkpt, pad=3, require_npt=50
):
"""
Determine which pixels within a slice contain a compact source.
Parameters
----------
alpha_slice : ndarray
Alpha coordinates for the output slice. If oversampling is performed,
these should be the oversampled coordinates.
alpha_ptsource : ndarray
Array of alpha values for modeled flux that met the slope limit threshold.
native_dalpha : float
Approximate native pixel size in alpha, along the columns.
spline_bkpt : int
The number of breakpoints used in the spline modeling.
pad : int, optional
The number of pixels near peak data to include the spline fit for in
the output array.
require_npt : int, optional
The minimum required number of high-slope data points to consider any
pixels to be compact.
Returns
-------
is_compact : ndarray
Boolean array matching the shape of ``alpha_slice``, where True
indicates a pixel containing a compact source.
"""
is_compact = np.full(alpha_slice.shape, False)
# If there is not enough high slope data or no spline models were found,
# just return False for all data
if len(alpha_ptsource) < require_npt or spline_bkpt is None:
return is_compact
# Bin the alpha coordinates for the high slope locations
avec = np.arange(spline_bkpt) * native_dalpha / 2 - (native_dalpha * spline_bkpt / 4)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning)
hist, edges = np.histogram(
alpha_ptsource,
bins=spline_bkpt,
range=(-native_dalpha * spline_bkpt / 4, native_dalpha * spline_bkpt / 4),
density=True,
)
hist = hist / np.nanmax(hist)
# Require peaks above some threshold
peak_indices, _ = find_peaks(hist, height=0.2)
amask = avec[peak_indices]
# Flag regions near the compact source, with some padding
for value in amask:
indx = (alpha_slice > value - pad * native_dalpha) & (
alpha_slice <= value + pad * native_dalpha
)
is_compact[indx] = True
return is_compact
def _native_dalpha(alpha):
"""
Compute the native spatial pixel spacing for a spectral region.
Parameters
----------
alpha : ndarray
2D array containing spatial coordinates. Horizontal dispersion
is assumed.
Returns
-------
float
The median pixel spacing.
"""
return np.abs(np.nanmedian(np.diff(alpha, axis=0)))
def _trace_image(shape, spline_models, region_map, alpha, slope_limit=0.1, pad=3):
"""
Evaluate spline models at all pixels to generate a trace image.
The trace image will be NaN wherever a spline model was not fit and
wherever the source is not compact enough for the spline model
to be appropriate. The ``slope_limit`` parameter controls the decision
for compact source regions.
Parameters
----------
shape : tuple of int
Data shape for the output image.
spline_models : dict
Spline models to evaluate.
region_map : ndarray
2D image matching shape, mapping valid region numbers.
alpha : ndarray
Alpha coordinates for all pixels marked as valid regions.
slope_limit : float, optional
The slope limit in the normalized model fits above which the spline
model is considered appropriate. Lower values will use spline fits
for fainter sources. If less than or equal to zero, the spline fits
will always be used.
pad : int, optional
The number of pixels near peak data to include the spline fit for in
the output array.
Returns
-------
trace_used : ndarray
2D image containing the scaled spline data fit evaluated at
the input alpha coordinates for compact source regions. Values are
NaN where no spline model was available and where the source
was below the slope limit.
full_trace: ndarray
2D image containing the scaled spline data fit evaluated at
all pixels. Values are NaN where no spline model was available.
"""
trace_used = np.full(shape, np.nan, dtype=np.float32)
full_trace = np.full(shape, np.nan, dtype=np.float32)
alpha_slice = np.full(shape, np.nan, dtype=np.float32)
trace_slice = np.full(shape, np.nan, dtype=np.float32)
spline_bkpt = None
for reg_num in spline_models:
splines = spline_models[reg_num]
alpha_slice[:] = np.nan
trace_slice[:] = np.nan
indx = region_map == reg_num
alpha_slice[indx] = alpha[indx]
# Define a list that will hold all alpha values for this
# slice where the slope is high
alpha_ptsource = []
# loop over columns
for i in range(shape[-1]):
if i not in splines:
continue
spline = splines[i]["model"]
scale = splines[i]["scale"]
lobound, hibound = splines[i]["bounds"]
# Evaluate the spline model for relevant data
col_alpha = alpha_slice[:, i]
valid_alpha = np.isfinite(col_alpha) & (col_alpha >= lobound) & (col_alpha <= hibound)
col_fit = spline(col_alpha[valid_alpha])
# Set the edges to NaN to avoid edge effects
col_fit[0] = np.nan
col_fit[-1] = np.nan
scaled_fit = scale * col_fit
trace_slice[:, i][valid_alpha] = scaled_fit
# Get the slope of the model fit prior to scaling
model_slope = np.abs(np.diff(col_fit, prepend=0))
# Ensure boundaries don't look weird
model_slope[0] = 0
model_slope[-1] = 0
highslope = (np.where(model_slope > slope_limit))[0]
alpha_ptsource.append(col_alpha[valid_alpha][highslope])
# Get the number of spline breakpoints used from the first real model
if spline_bkpt is None:
spline_bkpt = len(np.unique(spline.t)) - 1
full_trace[indx] = trace_slice[indx]
if slope_limit <= 0:
# Always use the spline fit in this case
trace_used[indx] = trace_slice[indx]
else:
if len(alpha_ptsource) > 0:
alpha_ptsource = np.concatenate(alpha_ptsource)
compact_locations = _is_compact_source(
alpha_slice, alpha_ptsource, _native_dalpha(alpha_slice), spline_bkpt, pad
)
trace_used[compact_locations] = trace_slice[compact_locations]
total_used = np.sum(compact_locations)
log.debug(
f"Using {total_used}/{np.sum(indx)} pixels from "
f"the spline model for region {reg_num}"
)
return trace_used, full_trace
def _linear_interp(col_y, col_flux, y_interp, edge_limit=0, preserve_nan=True):
"""
Perform a linear interpolation at one column.
Parameters
----------
col_y : ndarray
Y values in the original data for the column.
col_flux : ndarray
Flux values in the original data for the column.
y_interp : ndarray
Y values to interpolate to.
edge_limit : int, optional
If greater than zero, this many pixels at the edges of
the interpolated values will be set to NaN.
preserve_nan : bool, optional
If True, NaNs in the input will be preserved in the output.
Returns
-------
interpolated_flux : ndarray
Interpolated flux array.
"""
valid_data = np.isfinite(col_flux)
valid_y = np.isfinite(col_y)
valid_interp = valid_y & valid_data
interpolated_flux = np.interp(y_interp, col_y[valid_interp], col_flux[valid_interp])
if edge_limit >= 1:
interpolated_flux[0:edge_limit] = np.nan
interpolated_flux[-edge_limit:] = np.nan
# Check for NaNs in the input: they should be preserved in the output
if preserve_nan:
closest_pix = np.round(y_interp).astype(int)
is_nan = ~np.isfinite(col_flux[closest_pix])
interpolated_flux[is_nan] = np.nan
return interpolated_flux
[docs]
def linear_oversample(
    data, region_map, oversample_factor, require_ngood, edge_limit=0, preserve_nan=True
):
    """
    Oversample the input data with a linear interpolation.

    Each column of each region in the region map is interpolated
    independently along the row axis onto an oversampled row grid.

    Parameters
    ----------
    data : ndarray
        Original data to oversample.
    region_map : ndarray of int
        Map containing the slice or slit number for valid regions.
        Values are >0 for pixels in valid regions, 0 otherwise.
    oversample_factor : float
        Scaling factor to oversample by.
    require_ngood : int
        A column of a region is interpolated only if it has more than this
        number of finite data values.
    edge_limit : int, optional
        If greater than zero, this many pixels at the edges of
        the interpolated values will be set to NaN.
    preserve_nan : bool, optional
        If True, NaNs in the input will be preserved in the output.

    Returns
    -------
    os_data : ndarray
        The oversampled data array (float32); NaN wherever nothing was
        interpolated.
    """
    n_rows, n_cols = data.shape
    row_index = np.indices((n_rows, n_cols))[0]

    os_data = np.full(
        (int(np.ceil(n_rows * oversample_factor)), n_cols), np.nan, dtype=np.float32
    )

    # Per-region work arrays: NaN outside the current region
    region_data = np.full_like(data, np.nan)
    region_rows = np.full_like(data, np.nan)

    for region in np.unique(region_map[region_map > 0]):
        in_region = region_map == region
        region_data.fill(np.nan)
        region_rows.fill(np.nan)
        region_data[in_region] = data[in_region]
        region_rows[in_region] = row_index[in_region]

        for col in range(n_cols):
            col_flux = region_data[:, col]
            if np.count_nonzero(np.isfinite(col_flux)) <= require_ngood:
                continue

            col_y = region_rows[:, col]
            finite_y = col_y[np.isfinite(col_y)]
            new_rows, old_rows = _reindex(
                int(finite_y.min()), int(finite_y.max()), scale=oversample_factor
            )
            os_data[new_rows, col] = _linear_interp(
                col_y, col_flux, old_rows, edge_limit=edge_limit, preserve_nan=preserve_nan
            )

    return os_data
def _crossdisp_profile(data_slice, err_slice, alpha_slice):
"""
Collapse a spectral region along wavelengths to make a cross-dispersion profile.
Parameters
----------
data_slice : ndarray
2D flux array for the full spectral region.
err_slice : ndarray
2D error array for the full spectral region.
alpha_slice : ndarray
2D spatial coordinate array for the full spectral region.
Returns
-------
alpha_xdisp : ndarray
1D spatial coordinates.
flux_xdisp : ndarray
1D flux values, median-combined across wavelengths.
err_xdisp : ndarray
1D error values, median-combined across wavelengths.
snr_xdisp : ndarray
1D signal-to-noise ratio values, median-combined across wavelengths.
"""
valid = (
np.isfinite(data_slice)
& np.isfinite(err_slice)
& np.isfinite(alpha_slice)
& (err_slice > 0)
)
if not np.any(valid):
return None, None, None, None
# Compute some SNR and noise statistics collapsed along wavelength
snr_slice = np.full_like(data_slice, np.nan)
snr_slice[valid] = data_slice[valid] / err_slice[valid]
step = _native_dalpha(alpha_slice) / 2.0
# Bin errors and SNR by alpha values
alpha_xdisp = np.arange(np.nanmin(alpha_slice), np.nanmax(alpha_slice), step)
flux_xdisp = np.full_like(alpha_xdisp, np.nan)
err_xdisp = np.full_like(alpha_xdisp, np.nan)
snr_xdisp = np.full_like(alpha_xdisp, np.nan)
for kk in range(len(alpha_xdisp)):
indx = (
(alpha_slice >= alpha_xdisp[kk] - step / 2.0)
& (alpha_slice < alpha_xdisp[kk] + step / 2.0)
& np.isfinite(snr_slice)
)
if not np.any(indx):
continue
flux_xdisp[kk] = np.nanmedian(data_slice[indx])
err_xdisp[kk] = np.nanmedian(err_slice[indx])
snr_xdisp[kk] = np.nanmedian(snr_slice[indx])
return alpha_xdisp, flux_xdisp, err_xdisp, snr_xdisp
def _trim_edges(data_slice, alpha_slice, alpha_xdisp, err_xdisp, snr_xdisp):
"""
Set bad edge values to NaN.
Bad pixel values are identified by spatial coordinates (alpha),
for which the collapsed signal-to-noise ratio (SNR) is low
and the collapsed error value is high.
Parameters
----------
data_slice : ndarray
2D flux array for the full spectral region; updated in place.
alpha_slice : ndarray
2D spatial coordinate array for the full spectral region.
alpha_xdisp : ndarray
1D spatial coordinates, collapsed across wavelengths.
err_xdisp : ndarray
1D error values, median-combined across wavelengths.
snr_xdisp : ndarray
1D SNR values, median-combined across wavelengths.
"""
# Bad edges are where SNR is low but ERR is high: set them to NaN
valid = np.isfinite(snr_xdisp)
if not np.any(valid):
return
err_mean, _, err_rms = scs(err_xdisp[valid])
bad = (np.abs(snr_xdisp) < 5) & (err_xdisp > err_mean + 5 * err_rms)
if not np.any(bad) or np.all(bad):
return
# Drop data below the largest negative bad alpha
bad_alpha = alpha_xdisp[bad]
test_bad = bad_alpha < 0
if np.any(test_bad):
indx = alpha_slice <= np.max(bad_alpha[test_bad])
data_slice[indx] = np.nan
# Drop data above the smallest positive bad alpha
test_bad = bad_alpha > 0
if np.any(test_bad):
indx = alpha_slice >= np.min(bad_alpha[test_bad])
data_slice[indx] = np.nan
def _threshold_test(flux_xdisp, snr_xdisp, region_number, peak_threshold, snr_threshold):
"""
Determine if the peak flux or SNR is higher than a given threshold.
If both ``peak_threshold`` and ``snr_threshold`` are None, the
return value is always True. Otherwise, the maximum ``flux_xdisp``
is compared to the ``peak_threshold`` if provided and the
the maximum ``snr_xdisp`` is compared to the ``snr_threshold`` if provided.
True is returned if either condition is met.
Parameters
----------
flux_xdisp : ndarray
1D flux values, median-combined across wavelengths.
snr_xdisp : ndarray
1D signal-to-noise ratio values, median-combined across wavelengths.
region_number : int
Region number, used to select the correct peak threshold value.
peak_threshold : dict or None
Dictionary of peak flux threshold values, by region number.
snr_threshold : float or None
SNR threshold value.
Returns
-------
bool
True if either peak flux or SNR are greater than the provided
threshold.
"""
if peak_threshold is None and snr_threshold is None:
return True
peak_over_threshold = (
peak_threshold is not None
and flux_xdisp is not None
and np.nanmax(flux_xdisp) > peak_threshold[region_number]
)
snr_over_threshold = (
snr_threshold is not None and snr_xdisp is not None and np.nanmax(snr_xdisp) > snr_threshold
)
return peak_over_threshold or snr_over_threshold
def _fit_one_region(
    flux,
    error,
    alpha,
    region_map,
    region_number,
    peak_threshold=None,
    snr_threshold=None,
    **fit_kwargs,
):
    """
    Fit a trace model to a single region in the flux image.

    Called from `fit_all_regions`, optionally parallelized via multiprocessing.
    The region is copied out of the full image and collapsed along
    wavelength to check its signal level. Noisy edges are trimmed and the
    signal level is checked again. The region is then fit column by column
    with `fit_2d_spline_trace`.

    Parameters
    ----------
    flux : ndarray
        The flux image to fit.
    error : ndarray
        The error image associated with the flux.
    alpha : ndarray
        Alpha coordinates for all flux values.
    region_map : ndarray of int
        Map containing the slice or slit number for valid regions.
        Values are >0 for pixels in valid regions, 0 otherwise.
    region_number : int
        Index number for the single region to be fit in this invocation.
    peak_threshold : dict or None, optional
        Flux threshold values for each valid region in the region map,
        compared against the peak of the region's cross-dispersion flux
        profile (median-combined along wavelength).
    snr_threshold : float or None, optional
        Signal-to-noise ratio (SNR) threshold value, compared against the
        peak of the region's median-combined SNR profile.
    **fit_kwargs
        Keyword arguments to pass to the fitting routine (see `fit_2d_spline_trace`).

    Notes
    -----
    If any threshold is given, a fit is attempted only when the peak flux
    or the peak SNR exceeds its threshold, both before and after edge
    trimming.

    Returns
    -------
    splines : dict
        Dict containing a spline model, scale, and bounds for each column index in the region.
        If a spline model could not be fit, the column index number is not present.
    """
    in_region = region_map == region_number

    def _isolate(image):
        # Copy of ``image`` restricted to this region, NaN elsewhere
        isolated = np.full_like(flux, np.nan)
        isolated[in_region] = image[in_region]
        return isolated

    data_slice = _isolate(flux)
    err_slice = _isolate(error)
    alpha_slice = _isolate(alpha)

    no_data_msg = "No data over threshold; not fitting splines."

    def _has_signal(profile):
        # ``profile`` is the (alpha, flux, err, snr) tuple from _crossdisp_profile
        _, peak_flux, _, peak_snr = profile
        return _threshold_test(peak_flux, peak_snr, region_number, peak_threshold, snr_threshold)

    profile = _crossdisp_profile(data_slice, err_slice, alpha_slice)
    if not _has_signal(profile):
        log.debug(no_data_msg)
        return {}

    if profile[0] is not None:
        # Trim noisy edges in place, then rebuild and re-check the profile
        alpha_xdisp, _, err_xdisp, snr_xdisp = profile
        _trim_edges(data_slice, alpha_slice, alpha_xdisp, err_xdisp, snr_xdisp)
        profile = _crossdisp_profile(data_slice, err_slice, alpha_slice)
        if not _has_signal(profile):
            log.debug(no_data_msg)
            return {}

    # Column sums used to normalize the data before fitting; with strong
    # negative nods present, only positive values are summed
    snr_xdisp = profile[3]
    negative_nod_threshold = -5.0
    if snr_xdisp is not None and np.nanmin(snr_xdisp) < negative_nod_threshold:
        log.debug("Found significant negative data; summing positive only for normalization.")
        runsum = np.sum(data_slice, where=(data_slice > 0), axis=0)
    else:
        runsum = np.nansum(data_slice, axis=0)

    return fit_2d_spline_trace(data_slice, alpha_slice, fit_scale=runsum, **fit_kwargs)
[docs]
def fit_all_regions(flux, error, alpha, region_map, maximum_cores="none", **fit_kwargs):
    """
    Fit a trace model to all regions in the flux image.

    Each positive region number in ``region_map`` is fit independently,
    either sequentially or in a "spawn" multiprocessing pool.

    Parameters
    ----------
    flux : ndarray
        The flux image to fit.
    error : ndarray
        The error image associated with the flux.
    alpha : ndarray
        Alpha coordinates for all flux values.
    region_map : ndarray of int
        Map containing the slice or slit number for valid regions.
        Values are >0 for pixels in valid regions, 0 otherwise.
    maximum_cores : str
        Number of cores to use for multiprocessing. If set to 'none' (the default),
        then no multiprocessing will be done. The other allowable values are 'quarter',
        'half', 'all', and string integers. This is the fraction of available or
        the explicit number of cores to use for multiprocessing.
    **fit_kwargs
        Keyword arguments to pass to the fitting routine (see `fit_2d_spline_trace`).

    Returns
    -------
    spline_models : dict
        Keys are region numbers, values are dicts containing a spline model,
        scale, and bounds for each column index in the region. If a spline model
        could not be fit, the column index number is not present.
    """
    region_numbers = np.unique(region_map[region_map > 0])
    number_slices = compute_num_cores(maximum_cores, len(region_numbers), cpu_count())

    if number_slices == 1:
        # Single threaded computation
        log.debug("Running single-process calculation")
        announce_each = len(region_numbers) > 1
        spline_models = {}
        for reg_num in region_numbers:
            if announce_each:
                log.info("Fitting slice %s", reg_num)
            spline_models[reg_num] = _fit_one_region(
                flux, error, alpha, region_map, reg_num, **fit_kwargs
            )
        return spline_models

    # Parallelized computation
    log.info(f"Multiprocessing on {number_slices} cores")
    # Bind every input except the region number; starmap cannot forward
    # **fit_kwargs by itself
    worker = functools.partial(_fit_one_region, flux, error, alpha, region_map, **fit_kwargs)
    pool = multiprocessing.get_context("spawn").Pool(processes=number_slices)
    try:
        results = pool.starmap(worker, [(reg_num,) for reg_num in region_numbers])
    finally:
        pool.close()
        pool.join()
    return dict(zip(region_numbers, results, strict=True))
[docs]
def oversample_flux(
    flux,
    alpha,
    region_map,
    spline_models,
    oversample_factor,
    alpha_os,
    require_ngood=10,
    slope_limit=0.1,
    psf_optimal=False,
    trim_ends=False,
    pad=3,
):
    """
    Oversample a flux image from spline models fit to the data.

    For each column in each slice or slit in the region map:

    1. Check if there are enough valid data points to proceed.
    2. Compute oversampled coordinates corresponding to the input column.
    3. Linearly interpolate flux values onto the oversampled column.
    4. If a spline fit is available, evaluate it for the original column
       coordinates.
    5. Construct a residual between the spline fit and the original column
       data, then linearly interpolate the residual onto the oversampled
       column.
    6. Compute the slope of each column pixel as the absolute difference
       between the unscaled spline model at that pixel and its immediate
       neighbor.
    7. Evaluate the spline model at the oversampled coordinates.

    The oversampled flux starts from the linear interpolation. Wherever the
    applied spline array is finite, the spline flux replaces it and, unless
    ``psf_optimal`` is set, the interpolated residual is added back where the
    residual is also finite. When ``slope_limit`` is positive, the applied
    spline array is restricted to locations near alpha values whose model
    slope exceeds ``slope_limit`` (with ``pad`` pixels of padding); otherwise
    the full spline evaluation is applied.

    Parameters
    ----------
    flux : ndarray
        The flux image to fit.
    alpha : ndarray
        Alpha coordinates for all flux values.
    region_map : ndarray of int
        Map containing the slice or slit number for valid regions.
        Values are >0 for pixels in valid regions, 0 otherwise.
    spline_models : dict
        Keys are region numbers, values are dicts containing a spline model,
        scale, and bounds for each column index in the region. If a spline model
        could not be fit, the column index number is not present.
    oversample_factor : float
        Scaling factor to oversample by.
    alpha_os : ndarray
        Alpha coordinates for the oversampled array, used to evaluate spline models
        at every pixel.
    require_ngood : int, optional
        Column validity threshold for interpolation: a column within a region
        is skipped unless it has strictly more than ``require_ngood`` finite
        flux values.
    slope_limit : float, optional
        The slope limit in the normalized model fits above which the spline
        model is considered appropriate. Lower values will use spline fits
        for fainter sources. If less than or equal to zero, the spline fits
        will always be used.
    psf_optimal : bool, optional
        If True, residual corrections to the spline model are not included
        in the oversampled flux.
    trim_ends : bool, optional
        If True, the edges of the evaluated spline fit will be set to NaN.
    pad : int, optional
        The number of pixels near peak data to include the spline fit for in
        the output array.

    Returns
    -------
    flux_os : ndarray
        The oversampled flux array, containing contributions from the evaluated
        spline models, linear interpolations, and residual corrections.
    trace_used : ndarray
        A trace model, generated from the spline models evaluated at
        pixels containing a compact source. When ``slope_limit <= 0`` this is
        the same array object as ``full_trace``.
    full_trace : ndarray
        A trace model, generated from the spline models evaluated at
        every pixel.
    linear_flux : ndarray
        The flux linearly interpolated onto the oversampled grid.
    residual_flux : ndarray
        Residuals between the spline modeled data and the original flux,
        linearly interpolated onto the oversampled grid.
    """
    ysize, xsize = flux.shape
    # Row (Y) index of every native pixel, same shape as flux
    _, basey = np.meshgrid(np.arange(xsize), np.arange(ysize))
    # Oversampled flux array (linear and bspline to compare)
    os_shape = (int(np.ceil(ysize * oversample_factor)), xsize)
    flux_os_linear = np.full(os_shape, np.nan)  # Linear interpolation
    flux_os_bspline_full = np.full(os_shape, np.nan)  # All bspline models
    flux_os_bspline_use = np.full(os_shape, np.nan)  # Actual bspline array applied
    flux_os_residual = np.full(os_shape, np.nan)  # Residual corrections
    # Arrays to reset with NaNs for each slice
    data_slice = np.full_like(flux, np.nan)
    alpha_slice = np.full_like(flux, np.nan)
    basey_slice = np.full_like(flux, np.nan)
    alpha_os_slice = np.full(os_shape, np.nan)
    reset_arrays = [data_slice, basey_slice, alpha_slice, alpha_os_slice]
    # Edge limit for trimming ends: number of oversampled pixels blanked at
    # each end of a column. int() truncates, so factors below 1 give 0 and
    # the ``edge_limit >= 1`` guards below then skip all trimming.
    edge_limit = int(oversample_factor)
    region_numbers = np.unique(region_map[region_map > 0])
    # Breakpoint count is taken from the first available spline model (in any
    # region) and reused for every later compact-source check.
    spline_bkpt = None
    for reg_num in region_numbers:
        # Reset holding arrays to NaN
        for reset_array in reset_arrays:
            reset_array[:] = np.nan
        # Copy the relevant data for this slice into the holding arrays
        indx = region_map == reg_num
        data_slice[indx] = flux[indx]
        alpha_slice[indx] = alpha[indx]
        basey_slice[indx] = basey[indx]
        # Define a list that will hold all alpha values for this slice
        # where the slope is high
        alpha_ptsource = []
        for ii in range(xsize):
            # Are there sufficient values in this column to do anything?
            # (strictly more than require_ngood finite flux values needed)
            valid_data = np.isfinite(data_slice[:, ii])
            ngood = np.sum(valid_data)
            if ngood <= require_ngood:
                continue
            # Get the relevant data for this column
            col_y = basey_slice[:, ii]
            col_alpha = alpha_slice[:, ii]
            col_flux = data_slice[:, ii]
            # newy is the resampled Y pixel indices in the expanded detector frame
            # oldy is the resampled Y pixel indices in the original detector frame
            valid_y = np.isfinite(col_y)
            newy, oldy = _reindex(
                int(col_y[valid_y].min()), int(col_y[valid_y].max()), scale=oversample_factor
            )
            # Default approach is to do linear interpolation
            flux_os_linear[newy, ii] = _linear_interp(col_y, col_flux, oldy, edge_limit=edge_limit)
            # Check for a spline fit for this column
            if reg_num not in spline_models or ii not in spline_models[reg_num]:
                continue
            spline_model = spline_models[reg_num][ii]["model"]
            spline_scale = spline_models[reg_num][ii]["scale"]
            spline_lobound = spline_models[reg_num][ii]["bounds"][0]
            spline_hibound = spline_models[reg_num][ii]["bounds"][1]
            # Get the number of spline breakpoints used from the first real model
            if spline_bkpt is None:
                spline_bkpt = len(np.unique(spline_model.t)) - 1
            # Get valid input locations and evaluate the spline
            # (only alpha values inside the model's fitted bounds are used)
            valid_alpha = (
                np.isfinite(col_alpha)
                & (col_alpha >= spline_lobound)
                & (col_alpha <= spline_hibound)
            )
            col_fit = spline_model(col_alpha[valid_alpha])
            scaled_fit = col_fit * spline_scale
            # Construct the residual between spline fit and original data
            # then oversample it to output frame by linear interpolation
            residual = (col_flux[valid_alpha] - scaled_fit).astype(np.float32)
            y_interp = col_y[valid_alpha]
            valid_interp = np.isfinite(y_interp) & np.isfinite(residual)
            # NOTE(review): np.interp raises if valid_interp selects nothing;
            # presumably the spline bounds always cover some finite data here.
            interpval = np.interp(oldy, y_interp[valid_interp], residual[valid_interp])
            if edge_limit >= 1:
                interpval[0:edge_limit] = np.nan
                interpval[-edge_limit:] = np.nan
            flux_os_residual[newy, ii] = interpval
            # What was the slope of the model fit prior to scaling?
            model_slope = np.abs(np.diff(col_fit, prepend=0))
            # Ensure boundaries don't look weird (prepend=0 makes the first
            # element equal |col_fit[0]|, which is not a real slope)
            if edge_limit >= 1:
                model_slope[0:edge_limit] = 0
                model_slope[-edge_limit:] = 0
            # Add to our list of alpha values where the slope can be high for this slice
            highslope = (np.where(model_slope > slope_limit))[0]
            alpha_ptsource.append(col_alpha[valid_alpha][highslope])
            # Store the oversampled alpha values to check against later
            inbounds = np.where(
                (alpha_os[newy, ii] >= spline_lobound) & (alpha_os[newy, ii] <= spline_hibound)
            )
            alpha_os_slice[newy[inbounds], ii] = alpha_os[newy[inbounds], ii]
            # Evaluate the bspline at the oversampled alpha for this column
            oversampled_fit = spline_model(alpha_os[newy[inbounds], ii]) * spline_scale
            if trim_ends and edge_limit >= 1:
                oversampled_fit[0:edge_limit] = np.nan
                oversampled_fit[-edge_limit:] = np.nan
            flux_os_bspline_full[newy[inbounds], ii] = oversampled_fit
        # Now that our initial loop along the slice is done, we have a spline model everywhere
        # Now look at our list of alpha values where model slopes were high to figure out
        # where traces are and we actually want to use the spline model
        if slope_limit <= 0:
            # Always use the spline fit in this case.
            # This rebinds (aliases) the output to the full-trace array itself.
            flux_os_bspline_use = flux_os_bspline_full
        else:
            if len(alpha_ptsource) > 0:
                alpha_ptsource = np.concatenate(alpha_ptsource)
                # compact_locations is used as an index into the oversampled
                # arrays; presumably a boolean mask of shape os_shape.
                compact_locations = _is_compact_source(
                    alpha_os_slice, alpha_ptsource, _native_dalpha(alpha_slice), spline_bkpt, pad
                )
                flux_os_bspline_use[compact_locations] = flux_os_bspline_full[compact_locations]
                total_used = np.sum(compact_locations)
                # Note: numerator counts oversampled pixels, denominator native pixels
                log.debug(
                    f"Using {total_used}/{np.sum(indx)} pixels "
                    f"from the spline model for region {reg_num}"
                )
    # Insert the bspline interpolated values into the final combined oversampled array,
    # starting from the linearly interpolated array
    flux_os = flux_os_linear.copy()
    indx = np.where(np.isfinite(flux_os_bspline_use))
    flux_os[indx] = flux_os_bspline_use[indx]
    # Unless we're doing a specific psf optimal extraction, add in the residual fit
    if not psf_optimal:
        log.info("Applying complex scene corrections.")
        # DRL- conflicted about this indx array
        # Using only where flux_os_bspline_use is finite will trim the slice edges a bit
        # because the spline can extend slightly beyond the linear interpolation which can
        # be bad for sources on the edge. But requiring the residual to also be finite
        # can result in bad performance when the residual correction was really NEEDED
        # on the edge.
        indx = np.where(np.isfinite(flux_os_bspline_use) & np.isfinite(flux_os_residual))
        flux_os[indx] += flux_os_residual[indx]
    return flux_os, flux_os_bspline_use, flux_os_bspline_full, flux_os_linear, flux_os_residual
def _set_fit_kwargs(mode, detector, grating, xsize):
    """
    Choose detector- and mode-specific parameters for the spline fits.

    Parameters
    ----------
    mode : str
        Fitting mode ("NRS_IFU", "NRS_SLIT", "NRS_MOS", "MIR_MRS",
        "MIR_LRS_SLIT", or "MIR_LRS_SLITLESS").
    detector : str
        Detector name. Must start with "NRS" or "MIR".
    grating : str or None
        Grating name. Only "PRISM" (case-insensitive) changes the NIRSpec
        settings.
    xsize : int
        Size of the data along the dispersion axis. Used to build the
        order in which columns are fit.

    Returns
    -------
    fit_kwargs : dict
        Keyword arguments for ``fit_all_regions``: lrange, col_index,
        space_ratio, sigma_low, sigma_high, fit_iter, spline_bkpt,
        require_ngood, auto_bkpt_factor, auto_ngood_factor.

    Raises
    ------
    ValueError
        If the detector is neither NIRSpec nor MIRI.
    """
    # Empirical defaults; instrument branches below override what they need.
    sigma_low = sigma_high = 2.5
    fit_iter = 3
    spline_bkpt = require_ngood = None
    auto_bkpt_factor = auto_ngood_factor = None

    # Fitting orders shared by several branches
    right_to_left = range(xsize - 1, -1, -1)

    if detector.startswith("NRS"):
        # Spacing ratio of 1.6 was tuned by inspection, since sampling
        # degrades progressively for NIRSpec detectors.
        space_ratio = 1.6
        # Breakpoints and minimum good-pixel counts are derived automatically
        # from the native pixel spacing.
        if str(grating).upper() == "PRISM":
            # PRISM changes quickly with wavelength and has strong PSF structure
            lrange, auto_bkpt_factor, auto_ngood_factor = 10, 1.0, 0.25
        else:
            lrange, auto_bkpt_factor, auto_ngood_factor = 50, 2.0, 0.5
        # Start where the trace tilt relative to the pixel grid is greatest:
        # the left edge for NRS1, the right edge for NRS2.
        col_index = range(0, xsize, 1) if detector == "NRS1" else right_to_left
    elif detector.startswith("MIR"):
        space_ratio = 1.2
        require_ngood = 8
        if mode == "MIR_MRS":
            lrange = 50
            spline_bkpt = 36
            # Fit left edge -> middle, then right edge -> middle, so the
            # central columns never get too far from the last good fit.
            half = xsize // 2
            col_index = np.concatenate([np.arange(0, half + 1), np.arange(xsize - 1, half, -1)])
        else:
            # MIRI LRS (slit or slitless): fit from right to left
            lrange = 5
            sigma_low = sigma_high = 3.0
            if mode == "MIR_LRS_SLITLESS":
                fit_iter, spline_bkpt = 2, 60
            else:
                spline_bkpt = 40
            col_index = right_to_left
    else:
        raise ValueError("Unknown detector")

    fit_kwargs = {
        "lrange": lrange,
        "col_index": col_index,
        "space_ratio": space_ratio,
        "sigma_low": sigma_low,
        "sigma_high": sigma_high,
        "fit_iter": fit_iter,
        "spline_bkpt": spline_bkpt,
        "require_ngood": require_ngood,
        "auto_bkpt_factor": auto_bkpt_factor,
        "auto_ngood_factor": auto_ngood_factor,
    }

    # Record the chosen parameters
    log.debug(
        f"Spline fit parameters for {mode}, detector={detector} xsize={xsize} grating={grating}:"
    )
    for name, setting in fit_kwargs.items():
        log.debug(f" {name}: {setting}")
    return fit_kwargs
def _set_oversample_kwargs(mode, detector):
"""
Set optional parameters for oversampling by detector.
Parameters
----------
mode : str
Fitting mode (e.g. "NRS_IFU", "NRS_SLIT").
detector : str
Detector name.
Returns
-------
oversample_kwargs : dict
Optional parameter settings to pass to the ``oversample_flux``
function.
Raises
------
ValueError
If the input detector is not supported.
"""
require_ngood = 3
if detector.startswith("NRS"):
# Padding to add near point sources
if mode == "NRS_IFU":
pad = 2
else:
pad = 1
# Trimming ends of the interpolation can help with bad extrapolations
trim_ends = True
elif detector.startswith("MIR"):
# Padding to add near point sources
pad = 3
# Trimming ends is bad for MIRI, where dithers place point sources near the ends
trim_ends = False
else:
raise ValueError("Unknown detector")
oversample_kwargs = {"pad": pad, "trim_ends": trim_ends, "require_ngood": require_ngood}
return oversample_kwargs
def _get_alpha_nrs_ifu(ifu_wcs, xsize, ysize):
    """
    Compute NIRSpec IFU alpha coordinates on the native detector grid.

    Parameters
    ----------
    ifu_wcs : list of `~gwcs.WCS`
        One WCS object per IFU slice.
    xsize : int
        Number of columns in the data array.
    ysize : int
        Number of rows in the data array.

    Returns
    -------
    alpha_orig : ndarray
        Array of shape (ysize, xsize). Pixels inside a slice bounding box hold
        the negated "slicer" alpha, so alpha increases along +Y. All other
        pixels are NaN.
    """
    alpha_orig = np.full((ysize, xsize), np.nan)
    for slice_wcs in ifu_wcs:
        grid_x, grid_y = gwcs.wcstools.grid_from_bounding_box(slice_wcs.bounding_box)
        _, slicer_alpha, _ = slice_wcs.transform("detector", "slicer", grid_x, grid_y)
        rows = grid_y.astype(int)
        cols = grid_x.astype(int)
        # Negate so alpha runs in the same direction as increasing Y
        alpha_orig[rows, cols] = -slicer_alpha
    return alpha_orig
def _get_alpha_nrs_slit(wcs, xsize, ysize):
    """
    Compute NIRSpec slit alpha coordinates on the native detector grid.

    Parameters
    ----------
    wcs : `~gwcs.WCS`
        WCS object for the slit.
    xsize : int
        Number of columns in the data array.
    ysize : int
        Number of rows in the data array.

    Returns
    -------
    alpha_orig : ndarray
        Array of shape (ysize, xsize). Pixels inside the WCS bounding box hold
        the negated "slit_frame" alpha, so alpha increases along +Y. All other
        pixels are NaN.
    """
    grid_x, grid_y = gwcs.wcstools.grid_from_bounding_box(wcs.bounding_box)
    rows = grid_y.astype(int)
    cols = grid_x.astype(int)
    _, slit_alpha, _ = wcs.transform("detector", "slit_frame", grid_x, grid_y)
    alpha_orig = np.full((ysize, xsize), np.nan)
    # Negate so alpha runs in the same direction as increasing Y
    alpha_orig[rows, cols] = -slit_alpha
    return alpha_orig
def _get_alpha_mir_mrs(wcs, xsize, ysize):
"""
Get alpha coordinates for MIRI MRS corresponding to the original data array.
Parameters
----------
wcs : `~gwcs.WCS`
WCS object.
xsize : int
X-size for the data array.
ysize : int
Y-size for the data array.
Returns
-------
alpha_orig : ndarray
Alpha coordinates for the data array, with shape (ysize, xsize).
"""
x, y = np.meshgrid(np.arange(xsize), np.arange(ysize))
det2ab = wcs.get_transform("detector", "alpha_beta")
alpha_orig, _, _ = det2ab(x, y)
return alpha_orig
def _get_alpha_mir_lrs(wcs, xsize, ysize):
    """
    Compute MIRI LRS alpha coordinates on the native detector grid.

    Parameters
    ----------
    wcs : `~gwcs.WCS`
        WCS object providing a "detector" -> "alpha_beta" transform.
    xsize : int
        Number of columns in the data array.
    ysize : int
        Number of rows in the data array.

    Returns
    -------
    alpha_orig : ndarray
        Array of shape (ysize, xsize). Pixels inside the WCS bounding box hold
        alpha as returned by the transform (no sign flip). All other pixels
        are NaN.
    """
    grid_x, grid_y = gwcs.wcstools.grid_from_bounding_box(wcs.bounding_box)
    rows = grid_y.astype(int)
    cols = grid_x.astype(int)
    lrs_alpha, _, _ = wcs.transform("detector", "alpha_beta", grid_x, grid_y)
    alpha_orig = np.full((ysize, xsize), np.nan)
    alpha_orig[rows, cols] = lrs_alpha
    return alpha_orig
def _get_oversampled_coords_nrs_ifu(ifu_wcs, x_os, y_os):
"""
Get alpha coordinates for NIRSpec IFU corresponding to the oversampled data array.
Parameters
----------
ifu_wcs : list of `~gwcs.WCS`
List of WCS objects, one per slice.
x_os : int
X-size for the oversampled data array.
y_os : int
Y-size for the oversampled data array.
Returns
-------
alpha_os : ndarray
Alpha coordinates for the data array, with shape (y_os, x_os).
wave_os : ndarray
Wavelength coordinates for the data array, with shape (y_os, x_os),
in um.
"""
os_shape = x_os.shape
alpha_os = np.full(os_shape, np.nan)
wave_os = np.full(os_shape, np.nan)
for slice_wcs in ifu_wcs:
bbox = slice_wcs.bounding_box
x_in_bounds = (x_os >= bbox[0][0]) & (x_os <= bbox[0][1])
y_in_bounds = (y_os >= bbox[1][0]) & (y_os <= bbox[1][1])
_, alpha, lam = slice_wcs.transform(
"detector",
"slicer",
x_os[x_in_bounds & y_in_bounds],
y_os[x_in_bounds & y_in_bounds],
)
alpha_os[x_in_bounds & y_in_bounds] = -alpha
# Store wavelength, convert to um
wave_os[x_in_bounds & y_in_bounds] = lam * 1e6
return alpha_os, wave_os
def _inflate_error(error_array, extname, oversample_factor):
"""
Inflate error or variance arrays to account for oversampling.
Errors are increased by a factor dependent on the oversampling ratio
in order to account for the covariance introduced by the oversampling.
The inflation factor was determined empirically for IFU data
by comparing the reported error of single-spaxel spectra and aperture-summed
spectra, following ``cube_build`` on an oversampled image.
Empirically, based on the RMS of the aperture-summed spectrum in a line-free
region of a stellar spectrum, the true SNR does not change much (< 4%) between
N=1 and N=2/3/4. In contrast the reported SNR increases by an amount well fit
by X = 0.23N + 0.77. I.e., X=1 for N=1, and X=1.46 for N=3. This does not account
for variations in individual pixels, but to first order, inflating by this X
factor when the oversampling is performed will produce data cubes in which the
SNR is mostly preserved accurately. Per-pixel errors in the oversampled
product are not accurately reported by the inflated errors, but the oversampled
product should be considered primarily an intermediate data product; the
errors in the resampled cube are more important.
Parameters
----------
error_array : ndarray
Error or variance image to inflate. Updated in place.
extname : {"err", "var_rnoise", "var_poisson", "var_flat"}
Extension name.
oversample_factor : float
The oversampling factor used.
"""
inflation_factor = 0.23 * oversample_factor + 0.77
if str(extname).lower().startswith("var"):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "overflow encountered", category=RuntimeWarning)
error_array *= inflation_factor**2
else:
error_array *= inflation_factor
def _update_wcs_nrs_ifu(wcs, map_pixels):
    """
    Update a NIRSpec IFU WCS to include the oversampling transform.

    Parameters
    ----------
    wcs : `~gwcs.WCS`
        The WCS object, including transforms for all slices.
        May be either coordinate-based or slice-based.
    map_pixels : `~astropy.modeling.models.Model`
        Model that transforms from oversampled pixels to original detector
        pixels, to be prepended to the WCS pipeline.

    Returns
    -------
    wcs : `~gwcs.WCS`
        The updated WCS. If the input WCS was coordinate-based,
        then the new transform is prepended to the existing "coordinates"
        transform, and the input WCS object itself is modified and returned.
        If it was slice-based, a new WCS pipeline is created
        with "coordinates" as the input frame, containing the new transform.

    Notes
    -----
    Both branches loop over a hard-coded 30 slices. The coordinate-based
    branch addresses the selector with keys 1..30, while the slice-based
    branch indexes the bounding box with keys 0..29.
    """
    if "coordinates" in wcs.available_frames:
        # coordinate-based WCS: update the existing transform with the new mapping
        first_transform = wcs.pipeline[0].transform
        wcs.pipeline[0].transform = map_pixels | first_transform
        # Keep the original step's name and input/output labels on the new compound
        wcs.pipeline[0].transform.name = first_transform.name
        wcs.pipeline[0].transform.inputs = first_transform.inputs
        wcs.pipeline[0].transform.outputs = first_transform.outputs
        # update bounding box limits in place
        # (each slice's detector-pixel limits are mapped back to oversampled pixels)
        det2slicer_selector = wcs.pipeline[1].transform.selector
        for slnum in range(30):
            bb = det2slicer_selector[slnum + 1].bounding_box
            bb[0], bb[1] = map_pixels.inverse(bb[0], bb[1])
    else:
        # slice-based WCS
        # Add a pass-through input for the slice name; this rebinds the local
        # name to a new compound model with inputs (x, y, name).
        map_pixels &= Identity(1)
        map_pixels.name = "coord2det"
        map_pixels.inputs = ("x", "y", "name")
        map_pixels.outputs = ("x", "y", "name")
        # Bounding box is read from the input WCS before the new pipeline is built
        bbox = wcs.bounding_box
        frame = gwcs.coordinate_frames.Frame2D(name="coordinates", axes_order=(0, 1))
        wcs = gwcs.WCS([(frame, map_pixels), *wcs.pipeline])
        # update bounding box limits
        # NOTE(review): bbox[slnum] entries are mutated in place; whether that
        # also alters the input WCS depends on gwcs returning live objects.
        for slnum in range(30):
            bb = bbox[slnum]
            bb[0], bb[1], _ = map_pixels.inverse(bb[0], bb[1], slnum)
        wcs.bounding_box = bbox
    return wcs
def _update_wcs(wcs, map_pixels):
    """
    Prepend an oversampling transform to a slit-like or MIRI MRS WCS.

    Appropriate for the MIRI MRS WCS or for slit-like WCS objects
    produced by ``extract_2d``.

    Parameters
    ----------
    wcs : `~gwcs.WCS`
        The WCS object, including transforms for all slits or slices.
    map_pixels : `~astropy.modeling.models.Model`
        Model mapping oversampled pixels to original detector pixels.
        Its name and input/output labels are overwritten here.

    Returns
    -------
    wcs : `~gwcs.WCS`
        A new WCS whose first step goes from a "coordinates" frame through
        ``map_pixels``, followed by the input WCS pipeline. If the input WCS
        has a bounding box, its limits are converted to oversampled pixels
        (updating the input's bounding box object in place) and set on the
        new WCS.
    """
    # Label the mapping as the coordinates -> detector step
    map_pixels.name = "coord2det"
    map_pixels.inputs = ("x", "y")
    map_pixels.outputs = ("x", "y")

    coord_frame = gwcs.coordinate_frames.Frame2D(name="coordinates", axes_order=(0, 1))
    new_wcs = gwcs.WCS([(coord_frame, map_pixels), *wcs.pipeline])

    old_bbox = wcs.bounding_box
    if old_bbox is None:
        return new_wcs

    # Convert detector-pixel limits into oversampled-pixel limits
    old_bbox[0], old_bbox[1] = map_pixels.inverse(old_bbox[0], old_bbox[1])
    new_wcs.bounding_box = (old_bbox[0], old_bbox[1])
    return new_wcs
def _intermediate_models(model, data_arrays):
    """
    Wrap intermediate data arrays in datamodels that share the input metadata.

    Parameters
    ----------
    model : `~stdatamodels.jwst.datamodels.IFUImageModel` or \
            `~stdatamodels.jwst.datamodels.SlitModel`
        The input datamodel. Metadata will be copied from it.
    data_arrays : list of (ndarray or None)
        Data arrays to save. A None entry produces None in the output at
        the same position.

    Returns
    -------
    new_models : list of `~stdatamodels.jwst.datamodels.IFUImageModel` or \
                 `~stdatamodels.jwst.datamodels.SlitModel`, or None
        One entry per input array, in the same order. Models are
        IFUImageModel when ``model`` is an IFUImageModel, otherwise
        SlitModel. Each has ``err`` and ``meta.bunit_err`` cleared.
    """
    if isinstance(model, datamodels.IFUImageModel):
        model_type = datamodels.IFUImageModel
    else:
        model_type = datamodels.SlitModel

    def _wrap(data):
        # Build one model around ``data`` carrying the input model's metadata
        if data is None:
            return None
        wrapped = model_type(data=data)
        wrapped.update(model)
        # prevent empty error arrays
        wrapped.err = None
        wrapped.meta.bunit_err = None
        return wrapped

    return [_wrap(data) for data in data_arrays]
[docs]
def fit_and_oversample(
model,
fit_threshold=10.0,
slope_limit=0.1,
psf_optimal=False,
oversample_factor=1.0,
return_intermediate_models=False,
maximum_cores="none",
metadata_model=None,
):
"""
Fit a trace model and optionally oversample a spectral datamodel.
Parameters
----------
model : `~stdatamodels.jwst.datamodels.IFUImageModel` or \
`~stdatamodels.jwst.datamodels.SlitModel`
The input datamodel, updated in place.
fit_threshold : float, optional
The signal threshold sigma for attempting spline fits within a spectral region.
Lower values will create spline traces for more regions. If less than or
equal to 0, all regions will be fit.
slope_limit : float, optional
The normalized slope threshold for using the spline model in oversampled
data. Lower values will use the spline model for fainter sources. If less
than or equal to 0, the spline model will always be used.
psf_optimal : bool, optional
If True, residual corrections to the spline model are not included
in the oversampled flux. This option is generally appropriate for simple
isolated point sources only. If set, ``slope_limit`` and ``fit_threshold``
values are ignored and spline fits are attempted and used for all data.
oversample_factor : float, optional
If not 1.0, then the data will be oversampled by this factor.
return_intermediate_models : bool, optional
If True, additional image models will be returned, containing the full
spline model, the spline model as used for compact sources, the residual
model, and the linearly interpolated data.
maximum_cores : str, optional
Number of cores to use for multiprocessing. If set to 'none' (the default),
then no multiprocessing will be done. The other allowable values are 'quarter',
'half', 'all', and string integers. This is the fraction of available or
the explicit number of cores to use for multiprocessing.
metadata_model : `~stdatamodels.jwst.datamodels.MultiSlitModel`, optional
If the input is one slit from a multi-slit model, the containing model
may be passed to retrieve appropriate top-level metadata (e.g. detector,
exposure type, grating).
Returns
-------
model : `~stdatamodels.jwst.datamodels.IFUImageModel` or \
`~stdatamodels.jwst.datamodels.SlitModel`
The datamodel, updated with a trace image and optionally oversampled
arrays.
full_spline_model : `~stdatamodels.jwst.datamodels.IFUImageModel` or \
`~stdatamodels.jwst.datamodels.SlitModel`, optional
The spline model evaluated at all pixels. Returned only if
``return_intermediate_models`` is True.
source_spline_model : `~stdatamodels.jwst.datamodels.IFUImageModel` or \
`~stdatamodels.jwst.datamodels.SlitModel`, optional
The spline model evaluated at compact source locations only.
Returned only if ``return_intermediate_models`` is True.
linear_model : `~stdatamodels.jwst.datamodels.IFUImageModel` or \
`~stdatamodels.jwst.datamodels.SlitModel` or None, optional
All data linearly interpolated onto the oversampled grid
Returned only if ``return_intermediate_models`` is True. Will be None if
``oversample_factor`` is 1.0.
residual_model : `~stdatamodels.jwst.datamodels.IFUImageModel` or \
`~stdatamodels.jwst.datamodels.SlitModel` or None, optional
Residuals from the spline fit, linearly interpolated onto the oversampled grid
Returned only if ``return_intermediate_models`` is True. Will be None if
``oversample_factor`` is 1.0.
"""
# Check parameters
if psf_optimal:
log.info("Ignoring fit threshold and slope limit for psf_optimal=True")
fit_threshold = 0
slope_limit = 0
if metadata_model is not None:
model_meta = metadata_model.meta
else:
model_meta = model.meta
# Get input data coordinates
detector = str(model_meta.instrument.detector).upper()
exp_type = str(model_meta.exposure.type).upper()
grating = str(getattr(model_meta.instrument, "grating", "NONE")).upper()
if model.data.ndim == 3:
nint, ysize, xsize = model.data.shape
else:
nint = 1
ysize, xsize = model.data.shape
if detector.startswith("NRS"):
rotate = False
if isinstance(model, datamodels.IFUImageModel):
mode = "NRS_IFU"
wcs = nrs_ifu_wcs(model)
alpha_orig = _get_alpha_nrs_ifu(wcs, xsize, ysize)
# the region map is already stored in the datamodel
region_map = model.regions
else:
if exp_type == "NRS_MSASPEC":
mode = "NRS_MOS"
else:
mode = "NRS_SLIT"
wcs = model.meta.wcs
alpha_orig = _get_alpha_nrs_slit(wcs, xsize, ysize)
# Set the region map for the single slit from the valid coordinates
region_map = np.zeros((ysize, xsize), dtype=np.int32)
region_map[np.isfinite(alpha_orig)] = 1
elif detector.startswith("MIR"):
rotate = True
if isinstance(model, datamodels.IFUImageModel):
mode = "MIR_MRS"
wcs = model.meta.wcs
alpha_orig = _get_alpha_mir_mrs(wcs, xsize, ysize)
# Region map is stored in the transform
det2ab_transform = wcs.get_transform("detector", "alpha_beta")
region_map = det2ab_transform.label_mapper.mapper.copy()
else:
if exp_type == "MIR_LRS-SLITLESS":
mode = "MIR_LRS_SLITLESS"
elif exp_type == "MIR_WFSS":
raise ValueError("MIRI WFSS is not supported.")
else:
mode = "MIR_LRS_SLIT"
wcs = model.meta.wcs
alpha_orig = _get_alpha_mir_lrs(wcs, xsize, ysize)
# Set the region map for the single slit from the valid coordinates
region_map = np.zeros((ysize, xsize), dtype=np.int32)
region_map[np.isfinite(alpha_orig)] = 1
else:
raise ValueError("Unknown detector")
# For multiple integrations, fit the profile to the median image
if nint > 1:
# Also check for an input oversample factor:
# oversampling is not supported for multiple integrations
if oversample_factor != 1:
raise ValueError("Oversampling is not supported for TSO data.")
log.info("Fitting the spatial profile to the median image.")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning)
flux_orig = np.nanmedian(model.data, axis=0)
err_orig = np.nanmedian(model.err, axis=0)
else:
# Otherwise, just get the data and error arrays from the model
flux_orig = model.data
err_orig = model.err
# Rotate input data if needed
if rotate:
xsize, ysize = ysize, xsize
flux_orig = np.rot90(flux_orig)
err_orig = np.rot90(err_orig)
alpha_orig = np.rot90(alpha_orig)
region_map = np.rot90(region_map)
# Set thresholding for the bspline fitting
# Do some statistics on the overall cal file
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=AstropyUserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
overall_mean, _, overall_rms = scs(flux_orig[region_map > 0])
overall_mean = 0 if ~np.isfinite(overall_mean) else overall_mean
overall_rms = 0 if ~np.isfinite(overall_rms) else overall_rms
# Need to ensure that the median pixel value isn't negative, because that causes chaos
# Subtract off that constant
restore_mean = None
if overall_mean < 0:
restore_mean = overall_mean
flux_orig = flux_orig - overall_mean
overall_mean = 0
# Define a per-slice analysis threshold for IFU
# (must be brighter than some level above background)
peak_threshold = None
region_numbers = np.unique(region_map[region_map > 0])
if fit_threshold <= 0:
# In this case for any mode, all regions should be fit,
# so set both thresholds to None
fit_threshold = None
else:
if mode == "MIR_MRS":
# For MIRI MRS we need each channel to have its own threshold, particularly
# for Ch3/Ch4 since the sky is so much brighter in Ch4
peak_threshold = dict.fromkeys(region_numbers, np.nan)
for channel in [100, 200, 300, 400]:
ch_data = (region_map >= channel) & (region_map < channel + 100)
if not np.any(ch_data):
continue
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=AstropyUserWarning)
ch_mean, _, ch_rms = scs(flux_orig[ch_data])
for reg_num in region_numbers:
if channel <= reg_num < channel + 100:
peak_threshold[reg_num] = ch_mean + fit_threshold * ch_rms
elif mode == "NRS_IFU":
# For NIRSpec IFU data, all regions have the same threshold
threshold = overall_mean + fit_threshold * overall_rms
peak_threshold = dict.fromkeys(region_numbers, threshold)
# Fit spline models to all regions
fit_kwargs = _set_fit_kwargs(mode, detector, grating, xsize)
if peak_threshold is not None:
fit_kwargs["peak_threshold"] = peak_threshold
else:
fit_kwargs["snr_threshold"] = fit_threshold
spline_models = fit_all_regions(
flux_orig,
err_orig,
alpha_orig,
region_map,
maximum_cores=maximum_cores,
**fit_kwargs,
)
# If oversampling is not needed, evaluate the spline models to create the
# trace image, store it in the model, and return.
# In the future, it might be useful to update the SCI extension here for the
# psf_optimal=True case, even when oversample=1, but for now, we will leave
# data unmodified.
oversample_kwargs = _set_oversample_kwargs(mode, detector)
if oversample_factor == 1:
trace_used, full_trace = _trace_image(
flux_orig.shape,
spline_models,
region_map,
alpha_orig,
slope_limit=slope_limit,
pad=oversample_kwargs["pad"],
)
if rotate:
trace_used = np.rot90(trace_used, k=-1)
full_trace = np.rot90(full_trace, k=-1)
# Restore the overall mean level to the trace if needed
if restore_mean is not None:
trace_used += restore_mean
full_trace += restore_mean
model.trace_model = trace_used
if return_intermediate_models:
new_models = _intermediate_models(model, [full_trace, trace_used, None, None])
return model, *new_models
else:
return model
# Oversampled array size
os_shape = (int(np.ceil(ysize * oversample_factor)), xsize)
x_os = np.full(os_shape, np.nan)
y_os = np.full(os_shape, np.nan)
# Pre-compute coordinates for the new data size
log.info("Computing oversampled coordinates")
basex, basey = np.meshgrid(np.arange(xsize), np.arange(ysize))
newy, oldy = _reindex(0, ysize - 1, scale=oversample_factor)
y_os[:, :] = oldy[:, None]
x_os[:, :] = basex[oldy.astype(int), :]
if mode == "NRS_IFU":
alpha_os, wave_os = _get_oversampled_coords_nrs_ifu(wcs, x_os, y_os)
elif mode.startswith("NRS"):
_, alpha_os, wave_os = model.meta.wcs.transform("detector", "slit_frame", x_os, y_os)
alpha_os *= -1
wave_os *= 1e6
else:
# Because MIRI was rotated the indexing in the non-rotated frame,
# the input spatial coordinates need to be adjusted slightly
det2ab = model.meta.wcs.get_transform("detector", "alpha_beta")
alpha_os, _, _ = det2ab(ysize - y_os - 1, x_os)
# Get wavelengths from the full WCS pipeline for the same coordinates
# Necessary for MIRI LRS to get appropriate NaN values outside the slit region.
_, _, wave_os = model.meta.wcs(ysize - y_os - 1, x_os)
log.info("Oversampling the flux array from the fit trace model")
flux_os, trace_used, full_trace, linear, residual = oversample_flux(
flux_orig,
alpha_orig,
region_map,
spline_models,
oversample_factor,
alpha_os,
slope_limit=slope_limit,
psf_optimal=psf_optimal,
**oversample_kwargs,
)
log.info("Oversampling error and DQ arrays")
error_extensions = ["err", "var_rnoise", "var_poisson", "var_flat"]
errors = {}
for extname in error_extensions:
if model.hasattr(extname):
errors[extname] = getattr(model, extname)
if rotate:
errors[extname] = np.rot90(errors[extname])
dq = model.dq
if rotate:
dq = np.rot90(dq)
# Nearest pixel interpolation for the dq and regions array
closest_pix = (np.round(y_os).astype(int), np.round(x_os).astype(int))
dq_os = dq[*closest_pix]
regions_os = region_map[*closest_pix]
# Update the DQ image for pixels that used to be NaN, now replaced by spline interpolation.
# Remove the DO_NOT_USE flag, add FLUX_ESTIMATED
is_estimated = ~np.isnan(flux_os) & ((dq_os & dqflags.pixel["DO_NOT_USE"]) > 0)
dq_os[is_estimated] ^= dqflags.pixel["DO_NOT_USE"]
dq_os[is_estimated] |= dqflags.pixel["FLUX_ESTIMATED"]
# Simple linear oversample for the error arrays
errors_os = {}
for extname, error_array in errors.items():
error_os = linear_oversample(
error_array,
region_map,
oversample_factor,
oversample_kwargs["require_ngood"],
edge_limit=0,
preserve_nan=False,
)
# Restore NaNs from the input, except at the estimated locations
is_nan = ~np.isfinite(error_array[closest_pix])
error_os[is_nan & ~is_estimated] = np.nan
# Inflate the errors to account for oversampling covariance
_inflate_error(error_os, extname, oversample_factor)
errors_os[extname] = error_os
# Update the wcs for new pixel scale
scale_and_shift = Scale(1 / oversample_factor) | Shift(
-(oversample_factor - 1) / (oversample_factor * 2)
)
if mode == "NRS_IFU":
map_pixels = Identity(1) & scale_and_shift
model.meta.wcs = _update_wcs_nrs_ifu(model.meta.wcs, map_pixels)
elif mode.startswith("NRS"):
map_pixels = Identity(1) & scale_and_shift
model.meta.wcs = _update_wcs(model.meta.wcs, map_pixels)
else:
# MIRI
map_pixels = scale_and_shift & Identity(1)
model.meta.wcs = _update_wcs(model.meta.wcs, map_pixels)
# If needed, undo all of our rotations before passing back the arrays
if rotate:
flux_os = np.rot90(flux_os, k=-1)
dq_os = np.rot90(dq_os, k=-1)
wave_os = np.rot90(wave_os, k=-1)
regions_os = np.rot90(regions_os, k=-1)
trace_used = np.rot90(trace_used, k=-1)
full_trace = np.rot90(full_trace, k=-1)
linear = np.rot90(linear, k=-1)
residual = np.rot90(residual, k=-1)
for extname, error_array in errors_os.items():
errors_os[extname] = np.rot90(error_array, k=-1)
# Restore the overall mean level if needed
if restore_mean is not None:
flux_os += restore_mean
trace_used += restore_mean
full_trace += restore_mean
# If the data is in flux density units rather than surface brightness,
# we also need to correct for flux conservation
if is_flux_density(model.meta.bunit_data):
flux_os /= oversample_factor
trace_used /= oversample_factor
linear /= oversample_factor
residual /= oversample_factor
for extname, error_array in errors_os.items():
if extname == "err":
error_array /= oversample_factor
else:
error_array /= oversample_factor**2
# Update the model with the oversampled arrays
model.data = flux_os
model.dq = dq_os
model.wavelength = wave_os
model.trace_model = trace_used
for extname, error_array in errors_os.items():
setattr(model, extname, error_array)
if isinstance(model, datamodels.IFUImageModel):
model.regions = regions_os
# Update additional metadata: pixel area has changed in one dimension
if model.meta.photometry.pixelarea_steradians is not None:
model.meta.photometry.pixelarea_steradians /= oversample_factor
if model.meta.photometry.pixelarea_arcsecsq is not None:
model.meta.photometry.pixelarea_arcsecsq /= oversample_factor
# Remove some extra arrays if present: no longer needed
extras = [
"area",
"barshadow",
"flatfield_point",
"flatfield_uniform",
"pathloss_point",
"pathloss_uniform",
"photom_point",
"photom_uniform",
"zeroframe",
]
for name in extras:
if model.hasattr(name):
setattr(model, name, None)
# Make sure NaNs and DO_NOT_USE flags match in all extensions
match_nans_and_flags(model)
# Return intermediate models if needed
if return_intermediate_models:
new_models = _intermediate_models(model, [full_trace, trace_used, linear, residual])
return model, *new_models
else:
return model