# Licensed under a 3-clause BSD style license - see LICENSE.rst """ Tests that relate to evaluating models with quantity parameters """ import numpy as np import pytest from numpy.testing import assert_allclose from astropy import units as u from astropy.modeling.core import Model from astropy.modeling.models import Gaussian1D, Pix2Sky_TAN, Scale, Shift from astropy.tests.helper import assert_quantity_allclose from astropy.units import UnitsError # We start off by taking some simple cases where the units are defined by # whatever the model is initialized with, and we check that the model evaluation # returns quantities. MESSAGE = ( "{}: Units of input 'x', {}.*, could not be converted to required input units of" " {}.*" ) def test_evaluate_with_quantities(): """ Test evaluation of a single model with Quantity parameters that do not explicitly require units. """ # We create two models here - one with quantities, and one without. The one # without is used to create the reference values for comparison. g = Gaussian1D(1, 1, 0.1) gq = Gaussian1D(1 * u.J, 1 * u.m, 0.1 * u.m) # We first check that calling the Gaussian with quantities returns the # expected result assert_quantity_allclose(gq(1 * u.m), g(1) * u.J) # Units have to be specified for the Gaussian with quantities - if not, an # error is raised with pytest.raises(UnitsError, match=MESSAGE.format("Gaussian1D", "", "m ")): gq(1) # However, zero is a special case assert_quantity_allclose(gq(0), g(0) * u.J) # We can also evaluate models with equivalent units assert_allclose(gq(0.0005 * u.km).value, g(0.5)) # But not with incompatible units with pytest.raises(UnitsError, match=MESSAGE.format("Gaussian1D", "s", "m")): gq(3 * u.s) # We also can't evaluate the model without quantities with a quantity with pytest.raises( UnitsError, match=r"Can only apply 'subtract' function to dimensionless quantities .*", ): g(3 * u.m) # TODO: determine what error message should be here # assert exc.value.args[0] == ("Units of input 'x', m (length), could not be " # "converted to required dimensionless input") def test_evaluate_with_quantities_and_equivalencies(): """ We now make sure that equivalencies are correctly taken into account """ g = Gaussian1D(1 * u.Jy, 10 * u.nm, 2 * u.nm) # We aren't setting the equivalencies, so this won't work with pytest.raises(UnitsError, match=MESSAGE.format("Gaussian1D", "PHz", "nm")): g(30 * u.PHz) # But it should now work if we pass equivalencies when evaluating assert_quantity_allclose( g(30 * u.PHz, equivalencies={"x": u.spectral()}), g(9.993081933333332 * u.nm) ) class MyTestModel(Model): n_inputs = 2 n_outputs = 1 def evaluate(self, a, b): print("a", a) print("b", b) return a * b class TestInputUnits: def setup_method(self, method): self.model = MyTestModel() def test_evaluate(self): # We should be able to evaluate with anything assert_quantity_allclose(self.model(3, 5), 15) assert_quantity_allclose(self.model(4 * u.m, 5), 20 * u.m) assert_quantity_allclose(self.model(3 * u.deg, 5), 15 * u.deg) def test_input_units(self): self.model._input_units = {"x": u.deg} assert_quantity_allclose(self.model(3 * u.deg, 4), 12 * u.deg) assert_quantity_allclose(self.model(4 * u.rad, 2), 8 * u.rad) assert_quantity_allclose(self.model(4 * u.rad, 2 * u.s), 8 * u.rad * u.s) with pytest.raises(UnitsError, match=MESSAGE.format("MyTestModel", "s", "deg")): self.model(4 * u.s, 3) with pytest.raises(UnitsError, match=MESSAGE.format("MyTestModel", "", "deg")): self.model(3, 3) def test_input_units_allow_dimensionless(self): self.model._input_units = {"x": u.deg} self.model._input_units_allow_dimensionless = True assert_quantity_allclose(self.model(3 * u.deg, 4), 12 * u.deg) assert_quantity_allclose(self.model(4 * u.rad, 2), 8 * u.rad) with pytest.raises(UnitsError, match=MESSAGE.format("MyTestModel", "s", "deg")): self.model(4 * u.s, 3) assert_quantity_allclose(self.model(3, 3), 9) def test_input_units_strict(self): self.model._input_units = {"x": u.deg} self.model._input_units_strict = True assert_quantity_allclose(self.model(3 * u.deg, 4), 12 * u.deg) result = self.model(np.pi * u.rad, 2) assert_quantity_allclose(result, 360 * u.deg) assert result.unit is u.deg def test_input_units_equivalencies(self): self.model._input_units = {"x": u.micron} with pytest.raises( UnitsError, match=MESSAGE.format("MyTestModel", "PHz", "micron") ): self.model(3 * u.PHz, 3) self.model.input_units_equivalencies = {"x": u.spectral()} assert_quantity_allclose( self.model(3 * u.PHz, 3), 3 * (3 * u.PHz).to(u.micron, equivalencies=u.spectral()), ) def test_return_units(self): self.model._input_units = {"z": u.deg} self.model._return_units = {"z": u.rad} result = self.model(3 * u.deg, 4) assert_quantity_allclose(result, 12 * u.deg) assert result.unit is u.rad def test_return_units_scalar(self): # Check that return_units also works when giving a single unit since # there is only one output, so is unambiguous. self.model._input_units = {"x": u.deg} self.model._return_units = u.rad result = self.model(3 * u.deg, 4) assert_quantity_allclose(result, 12 * u.deg) assert result.unit is u.rad def test_and_input_units(): """ Test units to first model in chain. """ s1 = Shift(10 * u.deg) s2 = Shift(10 * u.deg) cs = s1 & s2 out = cs(10 * u.arcsecond, 20 * u.arcsecond) assert_quantity_allclose(out[0], 10 * u.deg + 10 * u.arcsec) assert_quantity_allclose(out[1], 10 * u.deg + 20 * u.arcsec) def test_plus_input_units(): """ Test units to first model in chain. """ s1 = Shift(10 * u.deg) s2 = Shift(10 * u.deg) cs = s1 + s2 out = cs(10 * u.arcsecond) assert_quantity_allclose(out, 20 * u.deg + 20 * u.arcsec) def test_compound_input_units(): """ Test units to first model in chain. """ s1 = Shift(10 * u.deg) s2 = Shift(10 * u.deg) cs = s1 | s2 out = cs(10 * u.arcsecond) assert_quantity_allclose(out, 20 * u.deg + 10 * u.arcsec) def test_compound_input_units_fail(): """ Test incompatible units to first model in chain. """ s1 = Shift(10 * u.deg) s2 = Shift(10 * u.deg) cs = s1 | s2 with pytest.raises(UnitsError, match=MESSAGE.format("Shift", "pix", "deg")): cs(10 * u.pix) def test_compound_incompatible_units_fail(): """ Test incompatible model units in chain. """ s1 = Shift(10 * u.pix) s2 = Shift(10 * u.deg) cs = s1 | s2 with pytest.raises(UnitsError, match=MESSAGE.format("Shift", "pix", "deg")): cs(10 * u.pix) def test_compound_pipe_equiv_call(): """ Check that equivalencies work when passed to evaluate, for a chained model (which has one input). """ s1 = Shift(10 * u.deg) s2 = Shift(10 * u.deg) cs = s1 | s2 out = cs(10 * u.pix, equivalencies={"x": u.pixel_scale(0.5 * u.deg / u.pix)}) assert_quantity_allclose(out, 25 * u.deg) def test_compound_and_equiv_call(): """ Check that equivalencies work when passed to evaluate, for a composite model with two inputs. """ s1 = Shift(10 * u.deg) s2 = Shift(10 * u.deg) cs = s1 & s2 out = cs( 10 * u.pix, 10 * u.pix, equivalencies={ "x0": u.pixel_scale(0.5 * u.deg / u.pix), "x1": u.pixel_scale(0.5 * u.deg / u.pix), }, ) assert_quantity_allclose(out[0], 15 * u.deg) assert_quantity_allclose(out[1], 15 * u.deg) def test_compound_input_units_equivalencies(): """ Test setting input_units_equivalencies on one of the models. """ s1 = Shift(10 * u.deg) s1.input_units_equivalencies = {"x": u.pixel_scale(0.5 * u.deg / u.pix)} s2 = Shift(10 * u.deg) sp = Shift(10 * u.pix) cs = s1 | s2 assert cs.input_units_equivalencies == {"x": u.pixel_scale(0.5 * u.deg / u.pix)} out = cs(10 * u.pix) assert_quantity_allclose(out, 25 * u.deg) cs = sp | s1 assert cs.input_units_equivalencies is None out = cs(10 * u.pix) assert_quantity_allclose(out, 20 * u.deg) cs = s1 & s2 assert cs.input_units_equivalencies == {"x0": u.pixel_scale(0.5 * u.deg / u.pix)} cs = cs.rename("TestModel") out = cs(20 * u.pix, 10 * u.deg) assert_quantity_allclose(out, 20 * u.deg) with pytest.raises(UnitsError, match=MESSAGE.format("Shift", "pix", "deg")): out = cs(20 * u.pix, 10 * u.pix) def test_compound_input_units_strict(): """ Test setting input_units_strict on one of the models. """ class ScaleDegrees(Scale): input_units = {"x": u.deg} s1 = ScaleDegrees(2) s2 = Scale(2) cs = s1 | s2 out = cs(10 * u.arcsec) assert_quantity_allclose(out, 40 * u.arcsec) assert out.unit is u.deg # important since this tests input_units_strict cs = s2 | s1 out = cs(10 * u.arcsec) assert_quantity_allclose(out, 40 * u.arcsec) assert out.unit is u.deg # important since this tests input_units_strict cs = s1 & s2 out = cs(10 * u.arcsec, 10 * u.arcsec) assert_quantity_allclose(out, 20 * u.arcsec) assert out[0].unit is u.deg assert out[1].unit is u.arcsec def test_compound_input_units_allow_dimensionless(): """ Test setting input_units_allow_dimensionless on one of the models. """ class ScaleDegrees(Scale): input_units = {"x": u.deg} s1 = ScaleDegrees(2) s1._input_units_allow_dimensionless = True s2 = Scale(2) cs = s1 | s2 cs = cs.rename("TestModel") out = cs(10) assert_quantity_allclose(out, 40 * u.one) out = cs(10 * u.arcsec) assert_quantity_allclose(out, 40 * u.arcsec) with pytest.raises(UnitsError, match=MESSAGE.format("ScaleDegrees", "m", "deg")): out = cs(10 * u.m) s1._input_units_allow_dimensionless = False cs = s1 | s2 cs = cs.rename("TestModel") with pytest.raises(UnitsError, match=MESSAGE.format("ScaleDegrees", "", "deg")): out = cs(10) s1._input_units_allow_dimensionless = True cs = s2 | s1 cs = cs.rename("TestModel") out = cs(10) assert_quantity_allclose(out, 40 * u.one) out = cs(10 * u.arcsec) assert_quantity_allclose(out, 40 * u.arcsec) with pytest.raises(UnitsError, match=MESSAGE.format("ScaleDegrees", "m", "deg")): out = cs(10 * u.m) s1._input_units_allow_dimensionless = False cs = s2 | s1 with pytest.raises(UnitsError, match=MESSAGE.format("ScaleDegrees", "", "deg")): out = cs(10) s1._input_units_allow_dimensionless = True s1 = ScaleDegrees(2) s1._input_units_allow_dimensionless = True s2 = ScaleDegrees(2) s2._input_units_allow_dimensionless = False cs = s1 & s2 cs = cs.rename("TestModel") out = cs(10, 10 * u.arcsec) assert_quantity_allclose(out[0], 20 * u.one) assert_quantity_allclose(out[1], 20 * u.arcsec) with pytest.raises(UnitsError, match=MESSAGE.format("ScaleDegrees", "", "deg")): out = cs(10, 10) def test_compound_return_units(): """ Test that return_units on the first model in the chain is respected for the input to the second. """ class PassModel(Model): n_inputs = 2 n_outputs = 2 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @property def input_units(self): """Input units.""" return {"x0": u.deg, "x1": u.deg} @property def return_units(self): """Output units.""" return {"x0": u.deg, "x1": u.deg} def evaluate(self, x, y): return x.value, y.value cs = Pix2Sky_TAN() | PassModel() assert_quantity_allclose(cs(0 * u.deg, 0 * u.deg), (0, 90) * u.deg)