# Source code for jwst.wfss_contam.wavefit

"""Fit a spectral shape to a WFSS dispersed source."""

import logging

import numpy as np

# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Names exported by ``from ... import *`` (the module's public API).
__all__ = ["apply_basis_coeffs", "fit_slit_by_basis_images"]


class SlitFitError(Exception):
    """Signal that a spectral fit to a slit could not be carried out."""


def _get_basis_images(simul_slit):
    """
    Collect the flat simulated image and all ``fluxmodel_N`` polynomial basis images.

    Parameters
    ----------
    simul_slit : `~stdatamodels.jwst.datamodels.SlitModel`
        Simulated slit with attributes ``data`` (the constant/flat term) and optional
        ``fluxmodel_1``, ``fluxmodel_2``, ... attributes for higher-degree terms.

    Returns
    -------
    basis : list of ndarray
        ``[data, fluxmodel_1, fluxmodel_2, ...]`` as numpy arrays.
    """
    basis = [np.asarray(simul_slit.data)]
    k = 1
    while True:
        mc = getattr(simul_slit, f"fluxmodel_{k}", None)
        if mc is None:
            break
        basis.append(np.asarray(mc))
        k += 1
    return basis


def fit_slit_by_basis_images(observed_slit, simul_slit, l2_alpha=0.0, rejection_threshold=0.1):
    """
    Fit a linear combination of dispersed basis images to the observed slit.

    The constant (degree-0) term is ``simul_slit.data`` (the flat-spectrum simulation).
    Higher-degree terms are taken from the ``fluxmodel_1``, ``fluxmodel_2``, ... attributes
    of ``simul_slit``. These are the grism-frame images produced by passing polynomial flux
    models through ``disperse()``.

    The fit solves::

        observed ≈ c_0 * data + c_1 * fluxmodel_1 + c_2 * fluxmodel_2 + ...

    via inverse-variance-weighted least squares on valid pixels, using the ``err`` array of
    the observed slit as pixel uncertainties. When ``l2_alpha > 0``, L2 regularization is
    applied to the weighted normal equations.

    A pixel is valid, and used in the fit, only if all of the following hold:

    - the observed data and every basis image are finite there;
    - the flat simulation ``data`` is nonzero;
    - the lowest DQ bit (value 1) is not set, when a ``dq`` array is present;
    - the error is finite and positive, when an ``err`` array is present.

    Parameters
    ----------
    observed_slit : `~stdatamodels.jwst.datamodels.SlitModel`
        Observed 2-D spectral cutout. Its ``source_id`` and
        ``meta.wcsinfo.spectral_order`` are read for diagnostic logging.
    simul_slit : `~stdatamodels.jwst.datamodels.SlitModel`
        Simulated slit with ``data`` and optional ``fluxmodel_N`` attributes.
    l2_alpha : float, optional
        L2 regularization strength. Added to the diagonal of the weighted normal-equation
        matrix as ``alpha * I`` before solving, which penalizes large coefficients. A value
        of ``0`` (the default) turns off regularization. Typical useful values are in the
        range ``1e-3`` - ``1e1``.
    rejection_threshold : float or None, optional
        If the fitted constant term coefficient ``c_0`` deviates from 1 by more than this
        amount, or any fitted coefficient is not finite, the fit is rejected and `None` is
        returned. This fit rejection is necessary to avoid fits "blowing up" when a source
        is located in nonzero (pseudo-)background, either from a nearby bright source or
        because the background subtraction was imperfect. If None, no fits will be
        rejected.

    Returns
    -------
    coeffs : ndarray or None
        Best-fit coefficients ``[c_0, c_1, ...]``, or `None` if the fit was rejected
        (see ``rejection_threshold``).

    Raises
    ------
    SlitFitError
        If there are fewer valid pixels than basis terms, or if an ``err`` array is present
        but no otherwise-valid pixel has a finite positive error.
    """
    basis = _get_basis_images(simul_slit)
    obs_data = np.asarray(observed_slit.data)
    n_terms = len(basis)

    # Every basis image, not just the flat term, must be finite on a pixel. Otherwise a
    # NaN/inf in a higher-degree fluxmodel would propagate into the solution.
    mask = np.isfinite(obs_data) & (basis[0] != 0)
    for image in basis:
        mask &= np.isfinite(image)

    dq = getattr(observed_slit, "dq", None)
    if dq is not None:
        # Drop pixels with the lowest DQ bit set (value 1; presumably JWST DO_NOT_USE).
        mask &= (np.asarray(dq) & 1) == 0

    n_valid = int(mask.sum())
    if n_valid < n_terms:
        raise SlitFitError(
            f"Only {n_valid} valid pixel(s) available for a {n_terms}-term linear fit "
            f"(need at least {n_terms})."
        )

    # Build inverse-variance weights from the error array.
    err_arr = getattr(observed_slit, "err", None)
    if err_arr is not None:
        err_arr = np.asarray(err_arr)
        with np.errstate(invalid="ignore"):
            good_err = np.isfinite(err_arr) & (err_arr > 0)
        if not np.any(good_err & mask):
            raise SlitFitError(
                "No valid pixels have finite positive error values; cannot compute fit weights."
            )
        # Pixels without a usable error would get zero weight and contribute nothing to
        # the fit. Removing them from the mask gives the same solution, but makes n_valid
        # (and the guard below) count only pixels that actually constrain the fit.
        mask &= good_err
        n_valid = int(mask.sum())
        if n_valid < n_terms:
            raise SlitFitError(
                f"Only {n_valid} valid pixel(s) with finite positive errors available for a "
                f"{n_terms}-term linear fit (need at least {n_terms})."
            )
        # sqrt(1 / err**2) == 1 / err for err > 0; this form avoids overflow of err**2
        # for very small errors.
        w_sqrt = 1.0 / err_arr[mask]
    else:
        w_sqrt = np.ones(n_valid)

    design_matrix = np.column_stack([image[mask] for image in basis])

    # Weighted normal equations: (A^T W A) c = A^T W b, solved via sqrt(W)-scaled rows.
    aw = design_matrix * w_sqrt[:, np.newaxis]
    bw = obs_data[mask] * w_sqrt
    if l2_alpha == 0.0:
        coeffs, *_ = np.linalg.lstsq(aw, bw, rcond=None)
    else:
        ata = aw.T @ aw
        atb = aw.T @ bw
        coeffs = np.linalg.solve(ata + l2_alpha * np.eye(n_terms), atb)

    if rejection_threshold is not None:
        # A NaN c_0 would slip through the threshold comparison below
        # (abs(nan - 1) > t is False), so reject non-finite fits explicitly.
        if not np.all(np.isfinite(coeffs)):
            log.debug("Fitted coefficients %s are not all finite; rejecting fit.", coeffs)
            return None
        if np.abs(coeffs[0] - 1) > rejection_threshold:
            log.debug(
                "Fitted constant term c_0=%.3g is far from 1; rejecting fit.", coeffs[0]
            )
            return None

    # Log some fit diagnostics for the source.
    log.debug(
        "source_id=%s order=%s valid_pixels/total=%d/%d coeffs=%s",
        observed_slit.source_id,
        observed_slit.meta.wcsinfo.spectral_order,
        n_valid,
        obs_data.size,
        np.array2string(coeffs, precision=4, suppress_small=True),
    )
    return coeffs


def apply_basis_coeffs(simul_slit, coeffs):
    """
    Build the fitted slit image from basis-image coefficients.

    Each basis image of ``simul_slit`` (``data``, ``fluxmodel_1``, ...) is scaled by the
    matching coefficient and the products are added together.

    Parameters
    ----------
    simul_slit : `~stdatamodels.jwst.datamodels.SlitModel`
        Simulated slit with ``data`` and optional ``fluxmodel_N`` attributes.
    coeffs : ndarray
        Coefficients from `fit_slit_by_basis_images`, one per basis image.

    Returns
    -------
    ndarray
        Fitted slit image.

    Raises
    ------
    ValueError
        If the number of coefficients differs from the number of basis images
        (raised by ``zip(..., strict=True)``).
    """
    images = _get_basis_images(simul_slit)
    # Start from 0 and rebind (not +=) so the result matches built-in sum() exactly,
    # including dtype promotion between terms.
    total = 0
    for coeff, image in zip(coeffs, images, strict=True):
        total = total + coeff * image
    return total