Source code for jwst.wfss_contam.wfss_contam

"""Top-level module for WFSS contamination correction."""

import copy
import logging
import multiprocessing

import numpy as np
from stcal.multiprocessing import compute_num_cores
from stdatamodels.jwst import datamodels
from stdatamodels.jwst.transforms.models import (
    NIRCAMBackwardGrismDispersion,
    NIRISSBackwardGrismDispersion,
)

from jwst.lib.catalog_utils import read_source_catalog
from jwst.wfss_contam.observations import Observation
from jwst.wfss_contam.sens1d import get_photom_data
from jwst.wfss_contam.wavefit import SlitFitError, apply_basis_coeffs, fit_slit_by_basis_images

log = logging.getLogger(__name__)

__all__ = ["contam_corr"]


class UnmatchedSlitIDError(Exception):
    """Exception raised when a slit ID is not found in the list of simulated slits."""

    pass


class _LegendreFluxModel:
    """
    Picklable callable that evaluates the k-th Legendre polynomial.

    Callables need to be picklable so that they can be used in multiprocessing contexts.
    The wavelength argument is mapped from ``[wmin, wmax]`` to ``[-1, 1]`` before
    evaluation.
    """

    def __init__(self, k, wmin, wmax):
        self.k = k
        self.wmin = wmin
        self.wmax = wmax
        # Coefficient vector with 1 at position k and 0 elsewhere.
        self._coeffs = np.zeros(k + 1)
        self._coeffs[k] = 1.0

    def __call__(self, x):
        x_norm = 2.0 * (x - self.wmin) / (self.wmax - self.wmin) - 1.0
        return np.polynomial.legendre.legval(x_norm, self._coeffs)


def _find_matching_simul_slit(slit, simul_slit_sids, simul_slit_orders):
    """
    Find the index of the matching simulated slit in the list of simulated slits.

    Parameters
    ----------
    slit : `~stdatamodels.jwst.datamodels.SlitModel`
        Source slit model
    simul_slit_sids : list
        List of source IDs for simulated slits
    simul_slit_orders : list
        List of spectral orders for simulated slits

    Returns
    -------
    good_idx : int
        Index of the matching simulated slit in the list of simulated slits
    """
    sid = slit.source_id
    order = slit.meta.wcsinfo.spectral_order
    good = (np.array(simul_slit_sids) == sid) * (np.array(simul_slit_orders) == order)
    if not any(good):
        raise UnmatchedSlitIDError(
            f"Source ID {sid} order {order} requested by input slit model "
            "but not found in simulated slits. "
            "Setting contamination correction to zero for that slit."
        )
    return np.where(good)[0][0]


def _cut_frame_to_match_slit(contam, slit):
    """
    Cut out the contamination image to match the extent of the source slit.

    Parameters
    ----------
    contam : 2D array
        Contamination image for the full grism exposure
    slit : `~stdatamodels.jwst.datamodels.SlitModel`
        Source slit model

    Returns
    -------
    cutout : 2D array
        Contamination image cutout that matches the extent of the source slit
    """
    x1 = slit.xstart - 1  # convert FITS 1-indexed to 0-indexed for array access
    y1 = slit.ystart - 1  # convert FITS 1-indexed to 0-indexed for array access
    xf = x1 + slit.xsize
    yf = y1 + slit.ysize

    # zero-pad the contamination image if the slit extends beyond the contamination image
    # fixes an off-by-one bug when sources extend to the edge of the contamination image
    if xf > contam.shape[1]:
        contam = np.pad(contam, ((0, 0), (0, xf - contam.shape[1])), mode="constant")
    if yf > contam.shape[0]:
        contam = np.pad(contam, ((0, yf - contam.shape[0]), (0, 0)), mode="constant")

    return contam[y1 : y1 + slit.ysize, x1 : x1 + slit.xsize]


class SlitOverlapError(Exception):
    """Exception raised when there is no overlap between data and model for a slit."""

    pass


def match_backplane_prefer_first(slit0, slit1):
    """
    Reshape slit1 to the backplane of slit0.

    Parameters
    ----------
    slit0 : `~stdatamodels.jwst.datamodels.SlitModel`
        Slit model for the first slit, which is used as reference.
    slit1 : `~stdatamodels.jwst.datamodels.SlitModel`
        Slit model for the second slit, which is reshaped to match slit0.

    Returns
    -------
    slit1 : `~stdatamodels.jwst.datamodels.SlitModel`
        Reshaped slit model.
    """
    data0 = slit0.data
    data1 = slit1.data

    x1 = slit1.xstart - slit0.xstart
    y1 = slit1.ystart - slit0.ystart
    backplane1 = np.zeros_like(data0)

    i0 = max([y1, 0])
    i1 = min([y1 + data1.shape[0], data0.shape[0], data1.shape[0]])
    j0 = max([x1, 0])
    j1 = min([x1 + data1.shape[1], data0.shape[1], data1.shape[1]])
    if i0 >= i1 or j0 >= j1:
        raise SlitOverlapError(
            f"No overlap region between data and model for slit {slit0.source_id}, "
            f"order {slit0.meta.wcsinfo.spectral_order}. "
            "setting contamination correction to zero for that slit."
        )
    di = i0 - y1  # offset into data1's own row axis
    dj = j0 - x1  # offset into data1's own col axis
    backplane1[i0:i1, j0:j1] = data1[di : di + (i1 - i0), dj : dj + (j1 - j0)]
    slit1.data = backplane1

    # also update fluxmodel attributes if present
    k = 1
    while True:
        attr = f"fluxmodel_{k}"
        mc = getattr(slit1, attr, None)
        if mc is None:
            break
        mc_bp = np.zeros_like(data0)
        mc_bp[i0:i1, j0:j1] = np.asarray(mc)[di : di + (i1 - i0), dj : dj + (j1 - j0)]
        setattr(slit1, attr, mc_bp)
        k += 1

    slit1.xstart = slit0.xstart
    slit1.ystart = slit0.ystart
    slit1.xsize = slit0.xsize
    slit1.ysize = slit0.ysize

    return slit1


def _validate_orders_against_reference(orders, spec_orders):
    """
    Compare user-requested spectral orders with the orders defined in the reference file.

    Parameters
    ----------
    orders : list[int]
        List of user-requested spectral orders.
    spec_orders : list[int]
        List of spectral orders defined in the reference file.

    Returns
    -------
    np.ndarray[int]
        List of spectral orders constrained to the user-specified ones
        that are also defined in the reference file.
    """
    spec_orders = np.array(spec_orders, dtype=int)
    if orders is None:
        return spec_orders
    orders = np.array(orders, dtype=int)
    good_orders = np.isin(orders, spec_orders, assume_unique=True)
    if (len(good_orders) == 0) or (not np.any(good_orders)):
        log.error(
            f"None of the requested spectral orders {orders} are defined "
            "in the wavelength range reference file. "
            f"Expected orders are: {spec_orders}. "
        )
        return []
    if not np.all(good_orders):
        log.warning(
            f"Not all requested spectral orders {orders} are defined in the "
            f"wavelength range reference file. Defined orders are: {spec_orders}. "
            "Skipping undefined orders."
        )
    return orders[good_orders]


def _validate_orders_against_transform(wcs, spec_orders):
    """
    Ensure the requested spectral orders are defined in the WCS transforms.

    Parameters
    ----------
    wcs : gwcs.wcs.WCS
        The input MultiSlitModel's WCS object.
    spec_orders : list[int]
        The list of requested spectral orders.

    Returns
    -------
    list
        List of spectral orders that are defined in the WCS transform.
    """
    sky_to_grism = wcs.backward_transform
    good_orders = spec_orders.copy()
    for model in sky_to_grism:
        if isinstance(model, (NIRCAMBackwardGrismDispersion, NIRISSBackwardGrismDispersion)):
            # Get the orders defined in the transform
            orders = np.sort(model.orders)
            is_good_order = [order in orders for order in spec_orders]
            if not any(is_good_order):
                log.error(
                    f"None of the requested spectral orders {spec_orders} are defined "
                    "in the WCS transform. "
                    f"Defined orders are: {orders}. "
                )
                return []
            if not all(is_good_order):
                log.warning(
                    f"Not all requested spectral orders {spec_orders} are "
                    f"defined in the WCS transform. Defined orders are: {orders}. "
                    "Skipping undefined orders."
                )
            good_orders = [order for order in spec_orders if order in orders]
            # There will be only one transform of this type in the wcs
            break
    return np.sort(good_orders)


def _find_min_relresp(sens_waves, sens_response):
    """
    Find the minimum relative response in the sensitivity response.

    Helper function is necessary instead of just nanmin because reference file
    sometimes has zero-valued wavelength/response pairs.

    Parameters
    ----------
    sens_waves : np.ndarray
        Wavelengths corresponding to the sensitivity response.
    sens_response : np.ndarray
        Sensitivity response values.

    Returns
    -------
    float
        Minimum relative response value.
    """
    good = (sens_waves > 0) & (sens_response > 0) & np.isfinite(sens_waves)
    return np.nanmin(sens_response[good])


def _build_simulated_image_from_slits(simulated_slits, shape):
    """
    Reconstruct the full-frame simulated image from the simulated slits.

    This is needed instead of just using ``obs.simulated_image`` because when
    spectral fitting is requested, the simulated slits get modified, and the
    modifications need to end up in the simul file.

    Parameters
    ----------
    simulated_slits : `~stdatamodels.jwst.datamodels.MultiSlitModel`
        The simulated slits.
    shape : tuple of int
        ``(nrows, ncols)`` of the full detector frame.

    Returns
    -------
    full_image : ndarray
        Full-frame simulated image.
    """
    full_image = np.zeros(shape, dtype=float)
    nrows, ncols = shape
    for slit in simulated_slits.slits:
        x0 = slit.xstart - 1  # convert FITS 1-indexed to 0-indexed for array access
        y0 = slit.ystart - 1  # convert FITS 1-indexed to 0-indexed for array access
        # Clip to frame boundaries in case a slit overflows
        x1 = min(x0 + slit.xsize, ncols)
        y1 = min(y0 + slit.ysize, nrows)
        full_image[y0:y1, x0:x1] += slit.data[: y1 - y0, : x1 - x0]
    return full_image


def _apply_magnitude_limit(
    order, source_catalog, sens_wave, sens_response, magnitude_limit, min_relresp_order1
):
    """
    Rescale the magnitude limit based on the sensitivity response for a given spectral order.

    Parameters
    ----------
    order : int
        Spectral order for which the magnitude limit is applied.
    source_catalog : astropy.table.Table
        The source catalog containing source IDs and isophotal AB magnitudes.
    sens_wave : np.ndarray
        The wavelengths corresponding to the sensitivity response.
    sens_response : np.ndarray
        The sensitivity response for the order.
    magnitude_limit : float
        The isophotal AB magnitude limit for sources to be included in the contamination correction.
    min_relresp_order1 : float
        Minimum relative response for order 1, used to scale the magnitude limit.

    Returns
    -------
    list
        List of source IDs that meet the magnitude limit criteria.
    """
    if order in [0, 1]:
        # Magnitude limit is set according to order 1 sensitivity response
        # and order 0 is a special case because it's not dispersed
        order_mag_limit = magnitude_limit
    else:
        # Scale the magnitude limit according to the order sensitivity response
        order_sens_factor = min_relresp_order1 / _find_min_relresp(sens_wave, sens_response)
        order_mag_diff = -2.5 * np.log10(order_sens_factor)
        order_mag_limit = magnitude_limit - order_mag_diff

    # Select sources that are brighter than the magnitude limit
    good_sources = source_catalog[source_catalog["isophotal_abmag"] < order_mag_limit]
    if len(good_sources) == 0:
        return None
    log.info(
        f"Applying magnitude limit of {order_mag_limit:.1f} to order {order}. "
        f"Sources selected: {len(good_sources)}"
    )
    return good_sources["label"].tolist()


def _match_simulated_slits(output_model, obs):
    """
    Match each observed slit to its simulated counterpart.

    Reprojects the simulated slit onto the same backplane as the observed slit.
    If the slit ID is not found in the simulated slits, or if there is no spatial overlap between
    the observed slit and the simulated slit, the slit is skipped and represented as ``None``
    in the returned list.

    Parameters
    ----------
    output_model : `~stdatamodels.jwst.datamodels.MultiSlitModel`
        The observed slits to match against.
    obs : `~jwst.wfss_contam.observations.Observation`
        The observation object containing the dispersed simulated slits.

    Returns
    -------
    matched_flat_simuls : list of `~stdatamodels.jwst.datamodels.SlitModel`
        Simulated slit reprojected onto each observed slit's backplane,
        or ``None`` where no match was found.
    good_idxs : list of int or None
        Index into ``obs.simulated_slits.slits`` for each observed slit,
        or ``None`` where no match was found.
    """
    simul_slit_sids = [slit.source_id for slit in obs.simulated_slits.slits]
    simul_slit_orders = [slit.meta.wcsinfo.spectral_order for slit in obs.simulated_slits.slits]

    matched_flat_simuls = []
    good_idxs = []
    for slit in output_model.slits:
        try:
            good_idx = _find_matching_simul_slit(slit, simul_slit_sids, simul_slit_orders)
            matched_flat = copy.deepcopy(obs.simulated_slits.slits[good_idx])
            matched_flat = match_backplane_prefer_first(slit, matched_flat)
            matched_flat_simuls.append(matched_flat)
            good_idxs.append(good_idx)
        except (UnmatchedSlitIDError, SlitOverlapError) as e:
            log.warning(e)
            matched_flat_simuls.append(None)
            good_idxs.append(None)

    return matched_flat_simuls, good_idxs


def _fit_spectral_shape(
    observed_slit,
    simul_slit,
    simul_slit_backplane_unmatched,
    polyfit_degree,
    l2_alpha=0.0,
    rejection_threshold=0.1,
):
    """
    Fit a polynomial spectral shape to one slit and apply the result in-place.

    Parameters
    ----------
    observed_slit : `~stdatamodels.jwst.datamodels.SlitModel`
        Observed slit whose ``.data`` is used as the target for the fit.
        If ``n_iterations > 1``, this holds the contamination-corrected data
        from the previous iteration.
    simul_slit : `~stdatamodels.jwst.datamodels.SlitModel`
        Simulated slit, already backplane-matched to ``observed_slit``.
        Its ``.data`` attribute is updated in-place with the spectrally fitted result.
    simul_slit_backplane_unmatched : `~stdatamodels.jwst.datamodels.SlitModel`
        The simulation in ``obs.simulated_slits`` without backplane matching applied.
        This is tracked independently from ``simul_slit`` because the extraction
        of the observed slit from extract_2d often misses flux from the extended PSF wings,
        whereas the simulation in this step disperses the whole segment, including those wings.
        Its ``.data`` is updated in-place so the full-frame reconstruction
        reflects the fitted spectral shape.
    polyfit_degree : int or None
        Degree of the polynomial spectral model.  ``None`` means no fitting.
    l2_alpha : float, optional
        L2 regularisation strength passed to `~jwst.wfss_contam.wavefit.fit_slit_by_basis_images`.
    rejection_threshold : float, optional
        Threshold for rejecting fits based on the fitted constant term coefficient, passed to
        `~jwst.wfss_contam.wavefit.fit_slit_by_basis_images`.

    Returns
    -------
    status : bool
        ``True`` if the fit was successful, ``False`` if the fit failed and the
        original flat-spectrum simulation is used.
    """
    if polyfit_degree is None or getattr(simul_slit, "fluxmodel_1", None) is None:
        return False

    log.debug(
        f"Fitting polynomial of degree {polyfit_degree} to the simulated slit "
        f"for source ID {observed_slit.source_id}, "
        f"order {observed_slit.meta.wcsinfo.spectral_order}"
    )
    try:
        coeffs = fit_slit_by_basis_images(
            observed_slit, simul_slit, l2_alpha=l2_alpha, rejection_threshold=rejection_threshold
        )
        if coeffs is None:
            return False
        simul_slit.data = apply_basis_coeffs(simul_slit, coeffs)
        simul_slit_backplane_unmatched.data = apply_basis_coeffs(
            simul_slit_backplane_unmatched, coeffs
        )
    except SlitFitError as e:
        log.debug(
            f"Polynomial fitting failed for slit with source ID {observed_slit.source_id}, "
            f"order {observed_slit.meta.wcsinfo.spectral_order}: {e}. "
            "Using the original simulated slit without fitting."
        )
        return False
    else:
        return True


[docs] def contam_corr( input_model, waverange, photom, max_cores, orders=None, magnitude_limit=None, max_pixels_per_chunk=5e4, oversample_factor=2, polyfit_degree=None, n_iterations=1, l2_alpha=0.1, rejection_threshold=0.1, ): """ Correct contamination in WFSS spectral cutouts. Parameters ---------- input_model : `~stdatamodels.jwst.datamodels.MultiSlitModel` Input data model containing 2D spectral cutouts. May be modified by processing: make a copy before calling this function, if needed. waverange : `~stdatamodels.jwst.datamodels.WavelengthrangeModel` Wavelength range reference file model photom : `~stdatamodels.jwst.datamodels.NrcWfssPhotomModel` or \ `~stdatamodels.jwst.datamodels.NisWfssPhotomModel` Photom (flux cal) reference file model max_cores : str or int Number of cores to use for multiprocessing. If set to 'none' (the default), then no multiprocessing will be done. The other allowable string values are 'quarter', 'half', and 'all', which indicate the fraction of cores to use for multi-proc. The total number of cores includes the SMT cores (Hyper Threading for Intel). If an integer is provided, it will be the exact number of cores used. orders : list, optional List of spectral orders to process. If None, all orders defined in the wavelengthrange file will be processed. magnitude_limit : float, optional Isophotal AB magnitude limit for sources to be included in the contamination correction. The magnitude limit is applied per spectral order, where the orders are scaled relative to order 0 based on their photometric response as read from the photom reference file. This means that generally fewer sources will be dispersed in higher orders. If None, no magnitude limit is applied and all sources are included. max_pixels_per_chunk : int, optional Maximum number of pixels to disperse simultaneously. oversample_factor : int, optional Wavelength oversampling factor. polyfit_degree : int, optional Degree of polynomial fit to spectral shape. If None (the default), do not attempt polynomial fitting and just use the flat-spectrum simulated slit. n_iterations : int, optional Number of times to iterate the contamination correction. On each iteration the polynomial fit is re-run using the contamination-corrected spectrum from the previous iteration, yielding a progressively better estimate of each source's true spectral shape (and therefore a better contamination estimate for its neighbors). Requires ``polyfit_degree`` to be set; if ``polyfit_degree`` is None this parameter is ignored and a single iteration is performed. l2_alpha : float, optional L2 regularization strength for the polynomial spectral fit, passed to `~jwst.wfss_contam.wavefit.fit_slit_by_basis_images`. rejection_threshold : float, optional Threshold for rejecting fits based on the fitted constant term coefficient, passed to `~jwst.wfss_contam.wavefit.fit_slit_by_basis_images`. Returns ------- output_model : `~stdatamodels.jwst.datamodels.MultiSlitModel` A copy of the input_model that has been corrected simul_model : `~stdatamodels.jwst.datamodels.ImageModel` Full-frame simulated image of the grism exposure contam_model : `~stdatamodels.jwst.datamodels.MultiSlitModel` Contamination estimate images for each source slit """ max_available_cores = multiprocessing.cpu_count() # don't worry about case where nchunks < ncpus; just set nchunks large for now ncpus = compute_num_cores(max_cores, 1e10, max_available_cores) # Get the segmentation map and direct image for this grism exposure seg_model = datamodels.open(input_model.meta.segmentation_map) direct_file = input_model.meta.direct_image log.debug(f"Direct image ={direct_file}") with datamodels.open(direct_file) as direct_model: direct_image = direct_model.data direct_image_wcs = direct_model.meta.wcs # Get the grism WCS object from the first cutout in the input model. # This WCS is used to transform from direct image to grism frame for all sources # in the segmentation map. # The "detector" to "grism_detector" and "world" to "detector" transforms are identical # for all slits, so just use the first one. The "grism_detector" to "grism_slit" # transform is not used by the step. grism_wcs = input_model.slits[0].meta.wcs # Find out how many spectral orders are defined based on the # array of order values in the Wavelengthrange ref file, # then constrain the orders to the user-specified ones spec_orders = np.asarray(waverange.order) spec_orders = _validate_orders_against_reference(orders, spec_orders) spec_orders = _validate_orders_against_transform(grism_wcs, spec_orders) if len(spec_orders) == 0: log.error("No valid spectral orders found. Step will be SKIPPED.") return input_model, None, None, None log.info(f"Spectral orders requested = {[int(x) for x in spec_orders]}") # Get the FILTER and PUPIL wheel positions, for use later filter_kwd = input_model.meta.instrument.filter pupil_kwd = input_model.meta.instrument.pupil # NOTE: The NIRCam WFSS mode uses filters that are in the FILTER wheel # with gratings in the PUPIL wheel. NIRISS WFSS mode, however, is just # the opposite. It has gratings in the FILTER wheel and filters in the # PUPIL wheel. So when processing NIRISS grism exposures the name of # filter needs to come from the PUPIL keyword value. if input_model.meta.instrument.name == "NIRISS": filter_name = pupil_kwd else: filter_name = filter_kwd # Read the source catalog to perform magnitude-based source selection later # mag limit will be scaled according to order 1 sensitivity if magnitude_limit is not None: source_catalog = read_source_catalog(input_model.meta.source_catalog) order1_wave_response, order1_sens_response = get_photom_data( photom, filter_kwd, pupil_kwd, order=1 ) min_relresp_order1 = _find_min_relresp(order1_wave_response, order1_sens_response) # set up observation object to disperse obs = Observation( direct_image, seg_model.data, grism_wcs, direct_image_wcs, boundaries=[0, 2047, 0, 2047], max_cpu=ncpus, max_pixels_per_chunk=max_pixels_per_chunk, oversample_factor=oversample_factor, ) seg_model.close() no_sources = True for order in spec_orders: # Load lists of wavelength ranges and flux cal info wavelength_range = waverange.get_wfss_wavelength_range(filter_name, [order]) wmin = wavelength_range[order][0] wmax = wavelength_range[order][1] log.debug(f"wmin={wmin}, wmax={wmax} for order {order}") sens_waves, sens_response = get_photom_data(photom, filter_kwd, pupil_kwd, order) # Build Legendre basis flux models for disperse() if polynomial fitting is requested. # wmin/wmax are order-specific, so construction must happen inside the order loop. # The constant term (k=0, i.e. slit.data) is always included; start at degree 1. basis_models = None if polyfit_degree is not None: basis_models = [_LegendreFluxModel(k, wmin, wmax) for k in range(1, polyfit_degree + 1)] # Constrain the source IDs to those that are below the magnitude limit selected_ids = None if magnitude_limit is not None: good_ids = _apply_magnitude_limit( order, source_catalog, sens_waves, sens_response, magnitude_limit, min_relresp_order1, ) if good_ids is None: log.info( f"No sources meet the magnitude limit of {magnitude_limit} for order {order}. " "Skipping contamination correction for this order." ) continue selected_ids = good_ids no_sources = False # Compute the dispersion for all sources in this order log.info(f"Creating full simulated grism image for order {order}") obs.disperse_order( order, wmin, wmax, sens_waves, sens_response, selected_ids, basis_models=basis_models, ) if no_sources: log.error( f"No sources found that met the magnitude limit {magnitude_limit}. Step will be SKIPPED" ) return input_model, None, None, None # Initialize output multislitmodel output_model = datamodels.MultiSlitModel() # Copy over matching slits. # Note that this makes a reference to input slits, not a deep copy, # so the input data may be modified by this function. The input data is # copied in the calling step, as needed. good_slits = [slit for slit in input_model.slits if slit.source_id in obs.source_ids] output_model.slits.extend(good_slits) contam_model = datamodels.MultiSlitModel() contam_model.update(input_model, only="PRIMARY") simul_slits = datamodels.MultiSlitModel() simul_slits.update(input_model, only="PRIMARY") # Hold onto original input data so iterative corrections always start from the same baseline. original_data = [np.array(slit.data) for slit in output_model.slits] if n_iterations > 1 and polyfit_degree is None: log.warning( "n_iterations > 1 has no effect when polyfit_degree is None " "(there is no spectral fit to iterate). Only one iteration will be performed." ) n_iterations = 1 # Match simulated slits to observed slits matched_flat_simuls, good_idxs = _match_simulated_slits(output_model, obs) # Hold onto the flat-spectrum full-resolution data for obs.simulated_slits so we can # reset it at the start of each iteration before applying new spectral fits. flat_full_data = [np.array(s.data) for s in obs.simulated_slits.slits] # Save the flat-spectrum .data for each matched slit so we can reset it each iteration # without deepcopying the whole SlitModel. _fit_spectral_shape only reassigns .data # (never modifies fluxmodel_N), so restoring just the data array is sufficient. flat_matched_data = [np.array(s.data) if s is not None else None for s in matched_flat_simuls] # Iterate: each pass re-fits spectral shapes using the contamination-corrected # data from the previous pass, giving progressively better contamination estimates. if polyfit_degree is not None: log.info( f"Using polyfit_degree={polyfit_degree} " f"for spectral fitting over {n_iterations} iterations" ) # check that background subtraction did not fail if input_model.meta.cal_step.bkg_subtract != "COMPLETE": log.warning( f"Background subtraction step status is {input_model.meta.cal_step.bkg_subtract}. " "A good background subtraction is necessary for models to match observed data " "well enough for spectral fitting to succeed. Fitting will be attempted, " "but failures may be expected." ) per_slit_simuls = [] contam_cuts = [] for iteration in range(n_iterations): if n_iterations > 1: log.info(f"Contamination correction iteration {iteration + 1} of {n_iterations}") # Reset obs.simulated_slits to flat-spectrum before applying new spectral fits. for s, d in zip(obs.simulated_slits.slits, flat_full_data, strict=True): s.data = d.copy() per_slit_simuls = [] success = 0 for i, slit in enumerate(output_model.slits): matched_flat = matched_flat_simuls[i] if matched_flat is None: per_slit_simuls.append(None) continue # Reset to flat-spectrum data; avoids a deepcopy each iteration because # _fit_spectral_shape only reassigns .data and never modifies fluxmodel_N. matched_flat.data = flat_matched_data[i].copy() # slit.data holds the contamination-corrected data from the previous iteration # (or the original input on the first pass). success += _fit_spectral_shape( slit, matched_flat, obs.simulated_slits.slits[good_idxs[i]], polyfit_degree, l2_alpha=l2_alpha, rejection_threshold=rejection_threshold, ) per_slit_simuls.append(matched_flat) log.info( f"Spectral fitting successful for {success} out of {len(output_model.slits)} slits " f"in iteration {iteration + 1}. Turn on debug logging for details of failures." ) # Rebuild full-frame simulation from the simulated slits. # This accounts for the fitted flux if a matched observation was found and the fit # was successful. Otherwise this is just the flat simulation. # It always includes all simulated sources even if unmatched to observed slits, e.g. # because extract_2d was run with a small value of wfss_nbright. simul_data = _build_simulated_image_from_slits( obs.simulated_slits, obs.simulated_image.shape ) # Compute per-slit contamination and update corrected data for the next iteration. contam_cuts = [] for i, (slit, this_simul) in enumerate( zip(output_model.slits, per_slit_simuls, strict=True) ): if this_simul is None: contam_cut = np.zeros_like(slit.data) else: simul_all_cut = _cut_frame_to_match_slit(simul_data, slit) contam_cut = simul_all_cut - this_simul.data # Always subtract from the original input, not the previous iteration's output, # so errors do not accumulate across iterations. slit.data = original_data[i] - contam_cut contam_cuts.append(contam_cut) if success == 0: log.warning( f"No successful spectral fits in iteration {iteration + 1}. " "Will not continue iterating. Ensure that the background is well subtracted, and " "consider reducing polyfit_degree or increasing l2_alpha to improve fit stability." ) break # Build output contam_model and simul_slits from the final iteration's results. log.info("Creating contamination image for each individual source") for i in range(len(per_slit_simuls)): this_obs = output_model.slits[i] this_simul = per_slit_simuls[i] contam_cut = contam_cuts[i] if this_simul is not None: # turn it into a SlitModel so we can use model.update this_simul = datamodels.SlitModel(this_simul.instance) this_simul.update(this_obs) simul_slits.slits.append(this_simul) contam_slit = datamodels.SlitModel() contam_slit.update(this_obs) contam_slit.data = contam_cut contam_model.slits.append(contam_slit) simul_model = datamodels.ImageModel(data=simul_data) simul_model.update(input_model, only="PRIMARY") output_model.update(input_model, only="PRIMARY") output_model.meta.cal_step.wfss_contam = "COMPLETE" return output_model, simul_model, contam_model, simul_slits