# Licensed under a 3-clause BSD style license - see LICENSE.rst """ This module is to contain an improved bounding box. """ from __future__ import annotations import abc import copy import warnings from typing import TYPE_CHECKING, NamedTuple import numpy as np from astropy.units import Quantity from astropy.utils import isiterable from astropy.utils.compat import COPY_IF_NEEDED if TYPE_CHECKING: from collections.abc import Callable from typing import Any, Self from astropy.units import UnitBase __all__ = ["ModelBoundingBox", "CompoundBoundingBox"] class _BaseInterval(NamedTuple): lower: float upper: float class _Interval(_BaseInterval): """ A single input's bounding box interval. Parameters ---------- lower : float The lower bound of the interval upper : float The upper bound of the interval Methods ------- validate : Constructs a valid interval outside : Determine which parts of an input array are outside the interval. domain : Constructs a discretization of the points inside the interval. """ def __repr__(self): return f"Interval(lower={self.lower}, upper={self.upper})" def copy(self): return copy.deepcopy(self) @staticmethod def _validate_shape(interval): """Validate the shape of an interval representation.""" MESSAGE = """An interval must be some sort of sequence of length 2""" try: shape = np.shape(interval) except TypeError: try: # np.shape does not work with lists of Quantities if len(interval) == 1: interval = interval[0] shape = np.shape([b.to_value() for b in interval]) except (ValueError, TypeError, AttributeError): raise ValueError(MESSAGE) valid_shape = shape in ((2,), (1, 2), (2, 0)) if not valid_shape: valid_shape = ( len(shape) > 0 and shape[0] == 2 and all(isinstance(b, np.ndarray) for b in interval) ) if not isiterable(interval) or not valid_shape: raise ValueError(MESSAGE) @classmethod def _validate_bounds(cls, lower, upper): """Validate the bounds are reasonable and construct an interval from them.""" if (np.asanyarray(lower) > np.asanyarray(upper)).all(): warnings.warn( f"Invalid interval: upper bound {upper} " f"is strictly less than lower bound {lower}.", RuntimeWarning, ) return cls(lower, upper) @classmethod def validate(cls, interval): """ Construct and validate an interval. Parameters ---------- interval : iterable A representation of the interval. Returns ------- A validated interval. """ cls._validate_shape(interval) if len(interval) == 1: interval = tuple(interval[0]) else: interval = tuple(interval) return cls._validate_bounds(interval[0], interval[1]) def outside(self, _input: np.ndarray): """ Parameters ---------- _input : np.ndarray The evaluation input in the form of an array. Returns ------- Boolean array indicating which parts of _input are outside the interval: True -> position outside interval False -> position inside interval """ return np.logical_or(_input < self.lower, _input > self.upper) def domain(self, resolution): return np.arange(self.lower, self.upper + resolution, resolution) # The interval where all ignored inputs can be found. _ignored_interval = _Interval.validate((-np.inf, np.inf)) def get_index(model, key) -> int: """ Get the input index corresponding to the given key. Can pass in either: the string name of the input or the input index itself. """ if isinstance(key, str): if key in model.inputs: index = model.inputs.index(key) else: raise ValueError(f"'{key}' is not one of the inputs: {model.inputs}.") elif np.issubdtype(type(key), np.integer): if 0 <= key < len(model.inputs): index = key else: raise IndexError( f"Integer key: {key} must be non-negative and < {len(model.inputs)}." ) else: raise ValueError(f"Key value: {key} must be string or integer.") return index def get_name(model, index: int): """Get the input name corresponding to the input index.""" return model.inputs[index] class _BoundingDomain(abc.ABC): """ Base class for ModelBoundingBox and CompoundBoundingBox. This is where all the `~astropy.modeling.core.Model` evaluation code for evaluating with a bounding box is because it is common to both types of bounding box. Parameters ---------- model : `~astropy.modeling.Model` The Model this bounding domain is for. prepare_inputs : Generates the necessary input information so that model can be evaluated only for input points entirely inside bounding_box. This needs to be implemented by a subclass. Note that most of the implementation is in ModelBoundingBox. prepare_outputs : Fills the output values in for any input points outside the bounding_box. evaluate : Performs a complete model evaluation while enforcing the bounds on the inputs and returns a complete output. """ def __init__(self, model, ignored: list[int] | None = None, order: str = "C"): self._model = model self._ignored = self._validate_ignored(ignored) self._order = self._get_order(order) @property def model(self): return self._model @property def order(self) -> str: return self._order @property def ignored(self) -> list[int]: return self._ignored def _get_order(self, order: str | None = None) -> str: """ Get if bounding_box is C/python ordered or Fortran/mathematically ordered. """ if order is None: order = self._order if order not in ("C", "F"): raise ValueError( "order must be either 'C' (C/python order) or " f"'F' (Fortran/mathematical order), got: {order}." ) return order def _get_index(self, key) -> int: """ Get the input index corresponding to the given key. Can pass in either: the string name of the input or the input index itself. """ return get_index(self._model, key) def _get_name(self, index: int): """Get the input name corresponding to the input index.""" return get_name(self._model, index) @property def ignored_inputs(self) -> list[str]: return [self._get_name(index) for index in self._ignored] def _validate_ignored(self, ignored: list) -> list[int]: if ignored is None: return [] else: return [self._get_index(key) for key in ignored] def __call__(self, *args, **kwargs): raise NotImplementedError( "This bounding box is fixed by the model and does not have " "adjustable parameters." ) @abc.abstractmethod def fix_inputs(self, model, fixed_inputs: dict): """ Fix the bounding_box for a `fix_inputs` compound model. Parameters ---------- model : `~astropy.modeling.Model` The new model for which this will be a bounding_box fixed_inputs : dict Dictionary of inputs which have been fixed by this bounding box. """ raise NotImplementedError("This should be implemented by a child class.") @abc.abstractmethod def prepare_inputs(self, input_shape, inputs) -> tuple[Any, Any, Any]: """ Get prepare the inputs with respect to the bounding box. Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into inputs : list List of all the model inputs Returns ------- valid_inputs : list The inputs reduced to just those inputs which are all inside their respective bounding box intervals valid_index : array_like array of all indices inside the bounding box all_out: bool if all of the inputs are outside the bounding_box """ raise NotImplementedError("This has not been implemented for BoundingDomain.") @staticmethod def _base_output(input_shape, fill_value): """ Create a baseline output, assuming that the entire input is outside the bounding box. Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box Returns ------- An array of the correct shape containing all fill_value """ return np.zeros(input_shape) + fill_value def _all_out_output(self, input_shape, fill_value): """ Create output if all inputs are outside the domain. Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box Returns ------- A full set of outputs for case that all inputs are outside domain. """ return [ self._base_output(input_shape, fill_value) for _ in range(self._model.n_outputs) ], None def _modify_output(self, valid_output, valid_index, input_shape, fill_value): """ For a single output fill in all the parts corresponding to inputs outside the bounding box. Parameters ---------- valid_output : numpy array The output from the model corresponding to inputs inside the bounding box valid_index : numpy array array of all indices of inputs inside the bounding box input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box Returns ------- An output array with all the indices corresponding to inputs outside the bounding box filled in by fill_value """ output = self._base_output(input_shape, fill_value) if not output.shape: output = np.array(valid_output) else: output[valid_index] = valid_output if np.isscalar(valid_output): output = output.item(0) return output def _prepare_outputs(self, valid_outputs, valid_index, input_shape, fill_value): """ Fill in all the outputs of the model corresponding to inputs outside the bounding_box. Parameters ---------- valid_outputs : list of numpy array The list of outputs from the model corresponding to inputs inside the bounding box valid_index : numpy array array of all indices of inputs inside the bounding box input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box Returns ------- List of filled in output arrays. """ outputs = [] for valid_output in valid_outputs: outputs.append( self._modify_output(valid_output, valid_index, input_shape, fill_value) ) return outputs def prepare_outputs(self, valid_outputs, valid_index, input_shape, fill_value): """ Fill in all the outputs of the model corresponding to inputs outside the bounding_box, adjusting any single output model so that its output becomes a list of containing that output. Parameters ---------- valid_outputs : list The list of outputs from the model corresponding to inputs inside the bounding box valid_index : array_like array of all indices of inputs inside the bounding box input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box """ if self._model.n_outputs == 1: valid_outputs = [valid_outputs] return self._prepare_outputs( valid_outputs, valid_index, input_shape, fill_value ) @staticmethod def _get_valid_outputs_unit(valid_outputs, with_units: bool) -> UnitBase | None: """ Get the unit for outputs if one is required. Parameters ---------- valid_outputs : list of numpy array The list of outputs from the model corresponding to inputs inside the bounding box with_units : bool whether or not a unit is required """ if with_units: return getattr(valid_outputs, "unit", None) def _evaluate_model( self, evaluate: Callable, valid_inputs, valid_index, input_shape, fill_value, with_units: bool, ): """ Evaluate the model using the given evaluate routine. Parameters ---------- evaluate : Callable callable which takes in the valid inputs to evaluate model valid_inputs : list of numpy arrays The inputs reduced to just those inputs which are all inside their respective bounding box intervals valid_index : numpy array array of all indices inside the bounding box input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box with_units : bool whether or not a unit is required Returns ------- outputs : list containing filled in output values valid_outputs_unit : the unit that will be attached to the outputs """ valid_outputs = evaluate(valid_inputs) valid_outputs_unit = self._get_valid_outputs_unit(valid_outputs, with_units) return ( self.prepare_outputs(valid_outputs, valid_index, input_shape, fill_value), valid_outputs_unit, ) def _evaluate( self, evaluate: Callable, inputs, input_shape, fill_value, with_units: bool ): """Evaluate model with steps: prepare_inputs -> evaluate -> prepare_outputs. Parameters ---------- evaluate : Callable callable which takes in the valid inputs to evaluate model valid_inputs : list of numpy arrays The inputs reduced to just those inputs which are all inside their respective bounding box intervals valid_index : numpy array array of all indices inside the bounding box input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box with_units : bool whether or not a unit is required Returns ------- outputs : list containing filled in output values valid_outputs_unit : the unit that will be attached to the outputs """ valid_inputs, valid_index, all_out = self.prepare_inputs(input_shape, inputs) if all_out: return self._all_out_output(input_shape, fill_value) else: return self._evaluate_model( evaluate, valid_inputs, valid_index, input_shape, fill_value, with_units ) @staticmethod def _set_outputs_unit(outputs, valid_outputs_unit): """ Set the units on the outputs prepare_inputs -> evaluate -> prepare_outputs -> set output units. Parameters ---------- outputs : list containing filled in output values valid_outputs_unit : the unit that will be attached to the outputs Returns ------- List containing filled in output values and units """ if valid_outputs_unit is not None: return Quantity( outputs, valid_outputs_unit, copy=COPY_IF_NEEDED, subok=True ) return outputs def evaluate(self, evaluate: Callable, inputs, fill_value): """ Perform full model evaluation steps: prepare_inputs -> evaluate -> prepare_outputs -> set output units. Parameters ---------- evaluate : callable callable which takes in the valid inputs to evaluate model valid_inputs : list The inputs reduced to just those inputs which are all inside their respective bounding box intervals valid_index : array_like array of all indices inside the bounding box fill_value : float The value which will be assigned to inputs which are outside the bounding box """ input_shape = self._model.input_shape(inputs) # NOTE: CompoundModel does not currently support units during # evaluation for bounding_box so this feature is turned off # for CompoundModel(s). outputs, valid_outputs_unit = self._evaluate( evaluate, inputs, input_shape, fill_value, self._model.bbox_with_units ) return tuple(self._set_outputs_unit(outputs, valid_outputs_unit)) class ModelBoundingBox(_BoundingDomain): """ A model's bounding box. Parameters ---------- intervals : dict A dictionary containing all the intervals for each model input keys -> input index values -> interval for that index model : `~astropy.modeling.Model` The Model this bounding_box is for. ignored : list A list containing all the inputs (index) which will not be checked for whether or not their elements are in/out of an interval. order : optional, str The ordering that is assumed for the tuple representation of this bounding_box. Options: 'C': C/Python order, e.g. z, y, x. (default), 'F': Fortran/mathematical notation order, e.g. x, y, z. """ def __init__( self, intervals: dict[int, _Interval], model, ignored: list[int] | None = None, order: str = "C", ): super().__init__(model, ignored, order) self._intervals = {} if intervals != () and intervals != {}: self._validate(intervals, order=order) def copy(self, ignored=None): intervals = { index: interval.copy() for index, interval in self._intervals.items() } if ignored is None: ignored = self._ignored.copy() return ModelBoundingBox( intervals, self._model, ignored=ignored, order=self._order ) @property def intervals(self) -> dict[int, _Interval]: """Return bounding_box labeled using input positions.""" return self._intervals @property def named_intervals(self) -> dict[str, _Interval]: """Return bounding_box labeled using input names.""" return {self._get_name(index): bbox for index, bbox in self._intervals.items()} def __repr__(self): parts = ["ModelBoundingBox(", " intervals={"] for name, interval in self.named_intervals.items(): parts.append(f" {name}: {interval}") parts.append(" }") if len(self._ignored) > 0: parts.append(f" ignored={self.ignored_inputs}") parts.append( f" model={self._model.__class__.__name__}(inputs={self._model.inputs})" ) parts.append(f" order='{self._order}'") parts.append(")") return "\n".join(parts) def __len__(self): return len(self._intervals) def __contains__(self, key): try: return self._get_index(key) in self._intervals or self._ignored except (IndexError, ValueError): return False def has_interval(self, key): return self._get_index(key) in self._intervals def __getitem__(self, key): """Get bounding_box entries by either input name or input index.""" index = self._get_index(key) if index in self._ignored: return _ignored_interval else: return self._intervals[self._get_index(key)] def bounding_box( self, order: str | None = None ) -> tuple[float, float] | tuple[tuple[float, float], ...]: """ Return the old tuple of tuples representation of the bounding_box order='C' corresponds to the old bounding_box ordering order='F' corresponds to the gwcs bounding_box ordering. """ if len(self._intervals) == 1: return tuple(next(iter(self._intervals.values()))) else: order = self._get_order(order) inputs = self._model.inputs if order == "C": inputs = inputs[::-1] bbox = tuple(tuple(self[input_name]) for input_name in inputs) if len(bbox) == 1: bbox = bbox[0] return bbox def __eq__(self, value): """Note equality can be either with old representation or new one.""" if isinstance(value, tuple): return self.bounding_box() == value elif isinstance(value, ModelBoundingBox): return (self.intervals == value.intervals) and ( self.ignored == value.ignored ) else: return False def __setitem__(self, key, value): """Validate and store interval under key (input index or input name).""" index = self._get_index(key) if index in self._ignored: self._ignored.remove(index) self._intervals[index] = _Interval.validate(value) def __delitem__(self, key): """Delete stored interval.""" index = self._get_index(key) if index in self._ignored: raise RuntimeError(f"Cannot delete ignored input: {key}!") del self._intervals[index] self._ignored.append(index) def _validate_dict(self, bounding_box: dict): """Validate passing dictionary of intervals and setting them.""" for key, value in bounding_box.items(): self[key] = value @property def _available_input_index(self): model_input_index = [self._get_index(_input) for _input in self._model.inputs] return [_input for _input in model_input_index if _input not in self._ignored] def _validate_sequence(self, bounding_box, order: str | None = None): """ Validate passing tuple of tuples representation (or related) and setting them. """ order = self._get_order(order) if order == "C": # If bounding_box is C/python ordered, it needs to be reversed # to be in Fortran/mathematical/input order. bounding_box = bounding_box[::-1] for index, value in enumerate(bounding_box): self[self._available_input_index[index]] = value @property def _n_inputs(self) -> int: n_inputs = self._model.n_inputs - len(self._ignored) if n_inputs > 0: return n_inputs else: return 0 def _validate_iterable(self, bounding_box, order: str | None = None): """Validate and set any iterable representation.""" if len(bounding_box) != self._n_inputs: raise ValueError( f"Found {len(bounding_box)} intervals, " f"but must have exactly {self._n_inputs}." ) if isinstance(bounding_box, dict): self._validate_dict(bounding_box) else: self._validate_sequence(bounding_box, order) def _validate(self, bounding_box, order: str | None = None): """Validate and set any representation.""" if self._n_inputs == 1 and not isinstance(bounding_box, dict): self[self._available_input_index[0]] = bounding_box else: self._validate_iterable(bounding_box, order) @classmethod def validate( cls, model, bounding_box, ignored: list | None = None, order: str = "C", _preserve_ignore: bool = False, **kwargs, ) -> Self: """ Construct a valid bounding box for a model. Parameters ---------- model : `~astropy.modeling.Model` The model for which this will be a bounding_box bounding_box : dict, tuple A possible representation of the bounding box order : optional, str The order that a tuple representation will be assumed to be Default: 'C' """ if isinstance(bounding_box, ModelBoundingBox): order = bounding_box.order if _preserve_ignore: ignored = bounding_box.ignored bounding_box = bounding_box.named_intervals new = cls({}, model, ignored=ignored, order=order) new._validate(bounding_box) return new def fix_inputs(self, model, fixed_inputs: dict, _keep_ignored=False) -> Self: """ Fix the bounding_box for a `fix_inputs` compound model. Parameters ---------- model : `~astropy.modeling.Model` The new model for which this will be a bounding_box fixed_inputs : dict Dictionary of inputs which have been fixed by this bounding box. keep_ignored : bool Keep the ignored inputs of the bounding box (internal argument only) """ new = self.copy() for _input in fixed_inputs.keys(): del new[_input] if _keep_ignored: ignored = new.ignored else: ignored = None return ModelBoundingBox.validate( model, new.named_intervals, ignored=ignored, order=new._order ) @property def dimension(self): return len(self) def domain(self, resolution, order: str | None = None) -> list[np.ndarray]: inputs = self._model.inputs order = self._get_order(order) if order == "C": inputs = inputs[::-1] return [self[input_name].domain(resolution) for input_name in inputs] def _outside(self, input_shape, inputs): """ Get all the input positions which are outside the bounding_box, so that the corresponding outputs can be filled with the fill value (default NaN). Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into inputs : list List of all the model inputs Returns ------- outside_index : bool-numpy array True -> position outside bounding_box False -> position inside bounding_box all_out : bool if all of the inputs are outside the bounding_box """ all_out = False outside_index = np.zeros(input_shape, dtype=bool) for index, _input in enumerate(inputs): _input = np.asanyarray(_input) outside = np.broadcast_to(self[index].outside(_input), input_shape) outside_index[outside] = True if outside_index.all(): all_out = True break return outside_index, all_out def _valid_index(self, input_shape, inputs): """ Get the indices of all the inputs inside the bounding_box. Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into inputs : list List of all the model inputs Returns ------- valid_index : numpy array array of all indices inside the bounding box all_out : bool if all of the inputs are outside the bounding_box """ outside_index, all_out = self._outside(input_shape, inputs) valid_index = np.atleast_1d(np.logical_not(outside_index)).nonzero() if len(valid_index[0]) == 0: all_out = True return valid_index, all_out def prepare_inputs(self, input_shape, inputs) -> tuple[Any, Any, Any]: """ Get prepare the inputs with respect to the bounding box. Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into inputs : list List of all the model inputs Returns ------- valid_inputs : list The inputs reduced to just those inputs which are all inside their respective bounding box intervals valid_index : array_like array of all indices inside the bounding box all_out: bool if all of the inputs are outside the bounding_box """ valid_index, all_out = self._valid_index(input_shape, inputs) valid_inputs = [] if not all_out: for _input in inputs: if input_shape: valid_input = np.broadcast_to(np.atleast_1d(_input), input_shape)[ valid_index ] if np.isscalar(_input): valid_input = valid_input.item(0) valid_inputs.append(valid_input) else: valid_inputs.append(_input) return tuple(valid_inputs), valid_index, all_out class _BaseSelectorArgument(NamedTuple): index: int ignore: bool class _SelectorArgument(_BaseSelectorArgument): """ Contains a single CompoundBoundingBox slicing input. Parameters ---------- index : int The index of the input in the input list ignore : bool Whether or not this input will be ignored by the bounding box. Methods ------- validate : Returns a valid SelectorArgument for a given model. get_selector : Returns the value of the input for use in finding the correct bounding_box. get_fixed_value : Gets the slicing value from a fix_inputs set of values. """ def __new__(cls, index, ignore): self = super().__new__(cls, index, ignore) return self @classmethod def validate(cls, model, argument, ignored: bool = True) -> Self: """ Construct a valid selector argument for a CompoundBoundingBox. Parameters ---------- model : `~astropy.modeling.Model` The model for which this will be an argument for. argument : int or str A representation of which evaluation input to use ignored : optional, bool Whether or not to ignore this argument in the ModelBoundingBox. Returns ------- Validated selector_argument """ return cls(get_index(model, argument), ignored) def get_selector(self, *inputs): """ Get the selector value corresponding to this argument. Parameters ---------- *inputs : All the processed model evaluation inputs. """ _selector = inputs[self.index] if isiterable(_selector): if len(_selector) == 1: return _selector[0] else: return tuple(_selector) return _selector def name(self, model) -> str: """ Get the name of the input described by this selector argument. Parameters ---------- model : `~astropy.modeling.Model` The Model this selector argument is for. """ return get_name(model, self.index) def pretty_repr(self, model): """ Get a pretty-print representation of this object. Parameters ---------- model : `~astropy.modeling.Model` The Model this selector argument is for. """ return f"Argument(name='{self.name(model)}', ignore={self.ignore})" def get_fixed_value(self, model, values: dict): """ Gets the value fixed input corresponding to this argument. Parameters ---------- model : `~astropy.modeling.Model` The Model this selector argument is for. values : dict Dictionary of fixed inputs. """ if self.index in values: return values[self.index] else: if self.name(model) in values: return values[self.name(model)] else: raise RuntimeError( f"{self.pretty_repr(model)} was not found in {values}" ) def is_argument(self, model, argument) -> bool: """ Determine if passed argument is described by this selector argument. Parameters ---------- model : `~astropy.modeling.Model` The Model this selector argument is for. argument : int or str A representation of which evaluation input is being used """ return self.index == get_index(model, argument) def named_tuple(self, model): """ Get a tuple representation of this argument using the input name from the model. Parameters ---------- model : `~astropy.modeling.Model` The Model this selector argument is for. """ return (self.name(model), self.ignore) class _SelectorArguments(tuple): """ Contains the CompoundBoundingBox slicing description. Parameters ---------- input_ : The SelectorArgument values Methods ------- validate : Returns a valid SelectorArguments for its model. get_selector : Returns the selector a set of inputs corresponds to. is_selector : Determines if a selector is correctly formatted for this CompoundBoundingBox. get_fixed_value : Gets the selector from a fix_inputs set of values. """ _kept_ignore = None def __new__( cls, input_: tuple[_SelectorArgument], kept_ignore: list | None = None ) -> Self: self = super().__new__(cls, input_) if kept_ignore is None: self._kept_ignore = [] else: self._kept_ignore = kept_ignore return self def pretty_repr(self, model): """ Get a pretty-print representation of this object. Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. """ parts = ["SelectorArguments("] for argument in self: parts.append(f" {argument.pretty_repr(model)}") parts.append(")") return "\n".join(parts) @property def ignore(self): """Get the list of ignored inputs.""" ignore = [argument.index for argument in self if argument.ignore] ignore.extend(self._kept_ignore) return ignore @property def kept_ignore(self): """The arguments to persist in ignoring.""" return self._kept_ignore @classmethod def validate(cls, model, arguments, kept_ignore: list | None = None) -> Self: """ Construct a valid Selector description for a CompoundBoundingBox. Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. arguments : The individual argument information kept_ignore : Arguments to persist as ignored """ inputs = [] for argument in arguments: _input = _SelectorArgument.validate(model, *argument) if _input.index in [this.index for this in inputs]: raise ValueError( f"Input: '{get_name(model, _input.index)}' has been repeated." ) inputs.append(_input) if len(inputs) == 0: raise ValueError("There must be at least one selector argument.") if isinstance(arguments, _SelectorArguments): if kept_ignore is None: kept_ignore = [] kept_ignore.extend(arguments.kept_ignore) return cls(tuple(inputs), kept_ignore) def get_selector(self, *inputs): """ Get the selector corresponding to these inputs. Parameters ---------- *inputs : All the processed model evaluation inputs. """ return tuple(argument.get_selector(*inputs) for argument in self) def is_selector(self, _selector): """ Determine if this is a reasonable selector. Parameters ---------- _selector : tuple The selector to check """ return isinstance(_selector, tuple) and len(_selector) == len(self) def get_fixed_values(self, model, values: dict): """ Gets the value fixed input corresponding to this argument. Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. values : dict Dictionary of fixed inputs. """ return tuple(argument.get_fixed_value(model, values) for argument in self) def is_argument(self, model, argument) -> bool: """ Determine if passed argument is one of the selector arguments. Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. argument : int or str A representation of which evaluation input is being used """ return any(selector_arg.is_argument(model, argument) for selector_arg in self) def selector_index(self, model, argument): """ Get the index of the argument passed in the selector tuples. Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. argument : int or str A representation of which argument is being used """ for index, selector_arg in enumerate(self): if selector_arg.is_argument(model, argument): return index raise ValueError(f"{argument} does not correspond to any selector argument.") def reduce(self, model, argument): """ Reduce the selector arguments by the argument given. Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. argument : int or str A representation of which argument is being used """ arguments = list(self) kept_ignore = [arguments.pop(self.selector_index(model, argument)).index] kept_ignore.extend(self._kept_ignore) return _SelectorArguments.validate(model, tuple(arguments), kept_ignore) def add_ignore(self, model, argument): """ Add argument to the kept_ignore list. Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. argument : int or str A representation of which argument is being used """ if self.is_argument(model, argument): raise ValueError( f"{argument}: is a selector argument and cannot be ignored." ) kept_ignore = [get_index(model, argument)] return _SelectorArguments.validate(model, self, kept_ignore) def named_tuple(self, model): """ Get a tuple of selector argument tuples using input names. Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. """ return tuple(selector_arg.named_tuple(model) for selector_arg in self) class CompoundBoundingBox(_BoundingDomain): """ A model's compound bounding box. Parameters ---------- bounding_boxes : dict A dictionary containing all the ModelBoundingBoxes that are possible keys -> _selector (extracted from model inputs) values -> ModelBoundingBox model : `~astropy.modeling.Model` The Model this compound bounding_box is for. selector_args : _SelectorArguments A description of how to extract the selectors from model inputs. create_selector : optional A method which takes in the selector and the model to return a valid bounding corresponding to that selector. This can be used to construct new bounding_boxes for previously undefined selectors. These new boxes are then stored for future lookups. order : optional, str The ordering that is assumed for the tuple representation of the bounding_boxes. """ def __init__( self, bounding_boxes: dict[Any, ModelBoundingBox], model, selector_args: _SelectorArguments, create_selector: Callable | None = None, ignored: list[int] | None = None, order: str = "C", ): super().__init__(model, ignored, order) self._create_selector = create_selector self._selector_args = _SelectorArguments.validate(model, selector_args) self._bounding_boxes = {} self._validate(bounding_boxes) def copy(self): bounding_boxes = { selector: bbox.copy(self.selector_args.ignore) for selector, bbox in self._bounding_boxes.items() } return CompoundBoundingBox( bounding_boxes, self._model, selector_args=self._selector_args, create_selector=copy.deepcopy(self._create_selector), order=self._order, ) def __repr__(self): parts = ["CompoundBoundingBox(", " bounding_boxes={"] # bounding_boxes for _selector, bbox in self._bounding_boxes.items(): bbox_repr = bbox.__repr__().split("\n") parts.append(f" {_selector} = {bbox_repr.pop(0)}") for part in bbox_repr: parts.append(f" {part}") parts.append(" }") # selector_args selector_args_repr = self.selector_args.pretty_repr(self._model).split("\n") parts.append(f" selector_args = {selector_args_repr.pop(0)}") for part in selector_args_repr: parts.append(f" {part}") parts.append(")") return "\n".join(parts) @property def bounding_boxes(self) -> dict[Any, ModelBoundingBox]: return self._bounding_boxes @property def selector_args(self) -> _SelectorArguments: return self._selector_args @selector_args.setter def selector_args(self, value): self._selector_args = _SelectorArguments.validate(self._model, value) warnings.warn( "Overriding selector_args may cause problems you should re-validate " "the compound bounding box before use!", RuntimeWarning, ) @property def named_selector_tuple(self) -> tuple: return self._selector_args.named_tuple(self._model) @property def create_selector(self): return self._create_selector @staticmethod def _get_selector_key(key): if isiterable(key): return tuple(key) else: return (key,) def __setitem__(self, key, value): _selector = self._get_selector_key(key) if not self.selector_args.is_selector(_selector): raise ValueError(f"{_selector} is not a selector!") ignored = self.selector_args.ignore + self.ignored self._bounding_boxes[_selector] = ModelBoundingBox.validate( self._model, value, ignored, order=self._order ) def _validate(self, bounding_boxes: dict): for _selector, bounding_box in bounding_boxes.items(): self[_selector] = bounding_box def __eq__(self, value): if isinstance(value, CompoundBoundingBox): return ( self.bounding_boxes == value.bounding_boxes and self.selector_args == value.selector_args and self.create_selector == value.create_selector ) else: return False @classmethod def validate( cls, model, bounding_box: dict, selector_args=None, create_selector=None, ignored: list | None = None, order: str = "C", _preserve_ignore: bool = False, **kwarg, ) -> Self: """ Construct a valid compound bounding box for a model. Parameters ---------- model : `~astropy.modeling.Model` The model for which this will be a bounding_box bounding_box : dict Dictionary of possible bounding_box representations selector_args : optional Description of the selector arguments create_selector : optional, callable Method for generating new selectors order : optional, str The order that a tuple representation will be assumed to be Default: 'C' """ if isinstance(bounding_box, CompoundBoundingBox): if selector_args is None: selector_args = bounding_box.selector_args if create_selector is None: create_selector = bounding_box.create_selector order = bounding_box.order if _preserve_ignore: ignored = bounding_box.ignored bounding_box = bounding_box.bounding_boxes if selector_args is None: raise ValueError( "Selector arguments must be provided " "(can be passed as part of bounding_box argument)" ) return cls( bounding_box, model, selector_args, create_selector=create_selector, ignored=ignored, order=order, ) def __contains__(self, key): return key in self._bounding_boxes def _create_bounding_box(self, _selector): self[_selector] = self._create_selector(_selector, model=self._model) return self[_selector] def __getitem__(self, key): _selector = self._get_selector_key(key) if _selector in self: return self._bounding_boxes[_selector] elif self._create_selector is not None: return self._create_bounding_box(_selector) else: raise RuntimeError(f"No bounding box is defined for selector: {_selector}.") def _select_bounding_box(self, inputs) -> ModelBoundingBox: _selector = self.selector_args.get_selector(*inputs) return self[_selector] def prepare_inputs(self, input_shape, inputs) -> tuple[Any, Any, Any]: """ Get prepare the inputs with respect to the bounding box. Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into inputs : list List of all the model inputs Returns ------- valid_inputs : list The inputs reduced to just those inputs which are all inside their respective bounding box intervals valid_index : array_like array of all indices inside the bounding box all_out: bool if all of the inputs are outside the bounding_box """ bounding_box = self._select_bounding_box(inputs) return bounding_box.prepare_inputs(input_shape, inputs) def _matching_bounding_boxes(self, argument, value) -> dict[Any, ModelBoundingBox]: selector_index = self.selector_args.selector_index(self._model, argument) matching = {} for selector_key, bbox in self._bounding_boxes.items(): if selector_key[selector_index] == value: new_selector_key = list(selector_key) new_selector_key.pop(selector_index) if bbox.has_interval(argument): new_bbox = bbox.fix_inputs( self._model, {argument: value}, _keep_ignored=True ) else: new_bbox = bbox.copy() matching[tuple(new_selector_key)] = new_bbox if len(matching) == 0: raise ValueError( f"Attempting to fix input {argument}, but there are no " f"bounding boxes for argument value {value}." ) return matching def _fix_input_selector_arg(self, argument, value): matching_bounding_boxes = self._matching_bounding_boxes(argument, value) if len(self.selector_args) == 1: return matching_bounding_boxes[()] else: return CompoundBoundingBox( matching_bounding_boxes, self._model, self.selector_args.reduce(self._model, argument), ) def _fix_input_bbox_arg(self, argument, value): bounding_boxes = {} for selector_key, bbox in self._bounding_boxes.items(): bounding_boxes[selector_key] = bbox.fix_inputs( self._model, {argument: value}, _keep_ignored=True ) return CompoundBoundingBox( bounding_boxes, self._model, self.selector_args.add_ignore(self._model, argument), ) def fix_inputs(self, model, fixed_inputs: dict) -> Self: """ Fix the bounding_box for a `fix_inputs` compound model. Parameters ---------- model : `~astropy.modeling.Model` The new model for which this will be a bounding_box fixed_inputs : dict Dictionary of inputs which have been fixed by this bounding box. """ fixed_input_keys = list(fixed_inputs.keys()) argument = fixed_input_keys.pop() value = fixed_inputs[argument] if self.selector_args.is_argument(self._model, argument): bbox = self._fix_input_selector_arg(argument, value) else: bbox = self._fix_input_bbox_arg(argument, value) if len(fixed_input_keys) > 0: new_fixed_inputs = fixed_inputs.copy() del new_fixed_inputs[argument] bbox = bbox.fix_inputs(model, new_fixed_inputs) if isinstance(bbox, CompoundBoundingBox): selector_args = bbox.named_selector_tuple bbox_dict = bbox elif isinstance(bbox, ModelBoundingBox): selector_args = None bbox_dict = bbox.named_intervals return bbox.__class__.validate( model, bbox_dict, order=bbox.order, selector_args=selector_args )