# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Tests for `NDData` construction, properties, pickling and representation,
and for the collapse operations of `NDDataArray`."""

import pickle
import textwrap
from itertools import chain, permutations, product

import numpy as np
import pytest
from numpy.testing import assert_array_equal

from astropy import units as u
from astropy.nddata import NDDataArray
from astropy.nddata import _testing as nd_testing
from astropy.nddata.nddata import NDData
from astropy.nddata.nduncertainty import StdDevUncertainty
from astropy.utils import NumpyRNGContext
from astropy.utils.compat.optional_deps import HAS_DASK
from astropy.utils.masked import Masked
from astropy.utils.metadata.tests.test_metadata import MetaBaseTest
from astropy.wcs import WCS
from astropy.wcs.wcsapi import BaseHighLevelWCS, HighLevelWCSWrapper, SlicedLowLevelWCS

from .test_nduncertainty import FakeUncertainty


class FakeNumpyArray:
    """
    Class that has a few of the attributes of a numpy array.

    These attributes are checked for by NDData. None of them return anything
    meaningful; only their presence matters.
    """

    def __init__(self):
        super().__init__()

    def shape(self):
        pass

    def __getitem__(self, key):
        pass

    def __array__(self, dtype=None, copy=None):
        pass

    @property
    def dtype(self):
        return "fake"


class MinimalUncertainty:
    """
    Define the minimum attributes acceptable as an uncertainty object.
    """

    def __init__(self, value):
        self._uncertainty = value

    @property
    def uncertainty_type(self):
        return "totally and completely fake"


class BadNDDataSubclass(NDData):
    """
    NDData subclass whose ``__init__`` skips all validation.

    It writes the private attributes directly, so an instance can carry values
    that ``NDData`` itself would reject.
    """

    def __init__(
        self,
        data,
        uncertainty=None,
        mask=None,
        wcs=None,
        meta=None,
        unit=None,
        psf=None,
    ):
        self._data = data
        self._uncertainty = uncertainty
        self._mask = mask
        self._wcs = wcs
        self._psf = psf
        self._unit = unit
        self._meta = meta


# Setter tests
def test_uncertainty_setter():
    """The uncertainty setter accepts duck-typed objects and wraps plain values."""
    nd = NDData([1, 2, 3])
    good_uncertainty = MinimalUncertainty(5)
    nd.uncertainty = good_uncertainty
    assert nd.uncertainty is good_uncertainty
    # Check the fake uncertainty (minimal does not work since it has no
    # parent_nddata attribute from NDUncertainty)
    nd.uncertainty = FakeUncertainty(5)
    assert nd.uncertainty.parent_nddata is nd
    # Check that it works if the uncertainty was set during init
    nd = NDData(nd)
    assert isinstance(nd.uncertainty, FakeUncertainty)
    nd.uncertainty = 10
    assert not isinstance(nd.uncertainty, FakeUncertainty)
    assert nd.uncertainty.array == 10


def test_mask_setter():
    """The mask setter replaces the mask, including one given at init."""
    # Since it just changes the _mask attribute everything should work
    nd = NDData([1, 2, 3])
    nd.mask = True
    assert nd.mask
    nd.mask = False
    assert not nd.mask
    # Check that it replaces a mask from init
    nd = NDData(nd, mask=True)
    assert nd.mask
    nd.mask = False
    assert not nd.mask


# Init tests
def test_nddata_empty():
    with pytest.raises(TypeError):
        NDData()  # empty initializer should fail


def test_nddata_init_data_nonarray():
    inp = [1, 2, 3]
    nd = NDData(inp)
    assert (np.array(inp) == nd.data).all()


def test_nddata_init_data_ndarray():
    """ndarray input keeps shape/dtype and is stored by reference unless copy=True."""
    # random floats
    with NumpyRNGContext(123):
        nd = NDData(np.random.random((10, 10)))
    assert nd.data.shape == (10, 10)
    assert nd.data.size == 100
    assert nd.data.dtype == np.dtype(float)

    # specific integers
    nd = NDData(np.array([[1, 2, 3], [4, 5, 6]]))
    assert nd.data.size == 6
    assert nd.data.dtype == np.dtype(int)

    # Tests to ensure that creating a new NDData object copies by *reference*.
    a = np.ones((10, 10))
    nd_ref = NDData(a)
    a[0, 0] = 0
    assert nd_ref.data[0, 0] == 0

    # Except we choose copy=True
    a = np.ones((10, 10))
    nd_ref = NDData(a, copy=True)
    a[0, 0] = 0
    assert nd_ref.data[0, 0] != 0


def test_nddata_init_data_maskedarray():
    """Masked-array data and mask are taken by reference unless copy=True."""
    with NumpyRNGContext(456):
        NDData(np.random.random((10, 10)), mask=np.random.random((10, 10)) > 0.5)

    # Another test (just copied here)
    with NumpyRNGContext(12345):
        a = np.random.randn(100)
        marr = np.ma.masked_where(a > 0, a)
    nd = NDData(marr)
    # check that masks and data match
    assert_array_equal(nd.mask, marr.mask)
    assert_array_equal(nd.data, marr.data)
    # check that they are both by reference
    marr.mask[10] = ~marr.mask[10]
    marr.data[11] = 123456789
    assert_array_equal(nd.mask, marr.mask)
    assert_array_equal(nd.data, marr.data)

    # or not if we choose copy=True
    nd = NDData(marr, copy=True)
    marr.mask[10] = ~marr.mask[10]
    marr.data[11] = 0
    assert nd.mask[10] != marr.mask[10]
    assert nd.data[11] != marr.data[11]


@pytest.mark.parametrize("data", [np.array([1, 2, 3]), 5])
def test_nddata_init_data_quantity(data):
    """Quantity input is split into data and unit; an explicit unit converts."""
    # Test an array and a scalar because a scalar Quantity does not always
    # behave the same way as an array.
    quantity = data * u.adu
    ndd = NDData(quantity)
    assert ndd.unit == quantity.unit
    assert_array_equal(ndd.data, np.array(quantity))
    if ndd.data.size > 1:
        # check that if it is an array it is not copied
        quantity.value[1] = 100
        assert ndd.data[1] == quantity.value[1]

        # or is copied if we choose copy=True
        ndd = NDData(quantity, copy=True)
        quantity.value[1] = 5
        assert ndd.data[1] != quantity.value[1]

    # provide a quantity and override the unit
    ndd_unit = NDData(data * u.erg, unit=u.J)
    assert ndd_unit.unit == u.J
    np.testing.assert_allclose((ndd_unit.data * ndd_unit.unit).to_value(u.erg), data)


def test_nddata_init_data_masked_quantity():
    """A Masked Quantity is split into plain data, unit and mask."""
    a = np.array([2, 3])
    q = a * u.m
    m = False
    mq = Masked(q, mask=m)
    nd = NDData(mq)
    assert_array_equal(nd.data, a)
    # This test failed before the change in nddata init because the masked
    # arrays data (which in fact was a quantity was directly saved)
    assert nd.unit == u.m
    assert not isinstance(nd.data, u.Quantity)
    np.testing.assert_array_equal(nd.mask, np.array(m))


def test_nddata_init_data_nddata():
    """NDData input passes its attributes through unless they are overridden."""
    nd1 = NDData(np.array([1]))
    nd2 = NDData(nd1)
    assert nd2.wcs == nd1.wcs
    assert nd2.uncertainty == nd1.uncertainty
    assert nd2.mask == nd1.mask
    assert nd2.unit == nd1.unit
    assert nd2.meta == nd1.meta
    assert nd2.psf == nd1.psf

    # Check that it is copied by reference
    nd1 = NDData(np.ones((5, 5)))
    nd2 = NDData(nd1)
    assert nd1.data is nd2.data

    # Check that it is really copied if copy=True
    nd2 = NDData(nd1, copy=True)
    nd1.data[2, 3] = 10
    assert nd1.data[2, 3] != nd2.data[2, 3]

    # Now let's see what happens if we have all explicitly set
    nd1 = NDData(
        np.array([1]),
        mask=False,
        uncertainty=StdDevUncertainty(10),
        unit=u.s,
        meta={"dest": "mordor"},
        wcs=WCS(naxis=1),
        psf=np.array([10]),
    )
    nd2 = NDData(nd1)
    assert nd2.data is nd1.data
    assert nd2.wcs is nd1.wcs
    assert nd2.uncertainty.array == nd1.uncertainty.array
    assert nd2.mask == nd1.mask
    assert nd2.unit == nd1.unit
    assert nd2.meta == nd1.meta
    assert nd2.psf == nd1.psf

    # now what happens if we overwrite them all too
    nd3 = NDData(
        nd1,
        mask=True,
        uncertainty=StdDevUncertainty(200),
        unit=u.km,
        meta={"observer": "ME"},
        wcs=WCS(naxis=1),
        psf=np.array([20]),
    )
    assert nd3.data is nd1.data
    assert nd3.wcs is not nd1.wcs
    assert nd3.uncertainty.array != nd1.uncertainty.array
    assert nd3.mask != nd1.mask
    assert nd3.unit != nd1.unit
    assert nd3.meta != nd1.meta
    assert nd3.psf != nd1.psf


def test_nddata_init_data_nddata_subclass():
    """Attributes of an NDData subclass are re-validated when wrapped."""
    uncert = StdDevUncertainty(3)
    # There might be some incompatible subclasses of NDData around.
    bnd = BadNDDataSubclass(False, True, 3, 2, "gollum", 100, 12)
    # Before changing the NDData init this would not have raised an error but
    # would have lead to a compromised nddata instance
    with pytest.raises(TypeError):
        NDData(bnd)

    # but if it has no actual incompatible attributes it passes
    bnd_good = BadNDDataSubclass(
        np.array([1, 2]),
        uncert,
        3,
        HighLevelWCSWrapper(WCS(naxis=1)),
        {"enemy": "black knight"},
        u.km,
    )
    nd = NDData(bnd_good)
    assert nd.unit == bnd_good.unit
    assert nd.meta == bnd_good.meta
    assert nd.uncertainty == bnd_good.uncertainty
    assert nd.mask == bnd_good.mask
    assert nd.wcs is bnd_good.wcs
    assert nd.data is bnd_good.data


def test_nddata_init_data_fail():
    """Data that is not both sliceable and shaped is rejected."""
    # First one is sliceable but has no shape, so should fail.
    with pytest.raises(TypeError):
        NDData({"a": "dict"})

    # This has a shape but is not sliceable
    class Shape:
        def __init__(self):
            self.shape = 5

        def __repr__(self):
            return "7"

    with pytest.raises(TypeError):
        NDData(Shape())


def test_nddata_init_data_fakes():
    """Array-like objects are stored as-is rather than converted to numpy."""
    ndd1 = NDData(FakeNumpyArray())
    # First make sure that NDData isn't converting its data to a numpy array.
    assert isinstance(ndd1.data, FakeNumpyArray)
    # Make a new NDData initialized from an NDData
    ndd2 = NDData(ndd1)
    # Check that the data wasn't converted to numpy
    assert isinstance(ndd2.data, FakeNumpyArray)


# Specific parameters
def test_param_uncertainty():
    """The uncertainty is linked back to its parent and can be overridden."""
    # Named ``uncert`` so as not to shadow the ``astropy.units`` alias ``u``.
    uncert = StdDevUncertainty(array=np.ones((5, 5)))
    d = NDData(np.ones((5, 5)), uncertainty=uncert)
    # Test that the parent_nddata is set.
    assert d.uncertainty.parent_nddata is d
    # Test conflicting uncertainties (other NDData)
    uncert2 = StdDevUncertainty(array=np.ones((5, 5)) * 2)
    d2 = NDData(d, uncertainty=uncert2)
    assert d2.uncertainty is uncert2
    assert d2.uncertainty.parent_nddata is d2


def test_param_wcs():
    # Since everything is allowed we only need to test something
    nd = NDData([1], wcs=WCS(naxis=1))
    assert nd.wcs is not None
    # Test conflicting wcs (other NDData)
    nd2 = NDData(nd, wcs=WCS(naxis=1))
    assert nd2.wcs is not None and nd2.wcs is not nd.wcs


def test_param_meta():
    # only dict-like meta is allowed; anything else raises TypeError
    with pytest.raises(TypeError):
        NDData([1], meta=3)
    nd = NDData([1, 2, 3], meta={})
    assert len(nd.meta) == 0
    nd = NDData([1, 2, 3])
    assert isinstance(nd.meta, dict)
    assert len(nd.meta) == 0
    # Test conflicting meta (other NDData)
    nd2 = NDData(nd, meta={"image": "sun"})
    assert len(nd2.meta) == 1
    nd3 = NDData(nd2, meta={"image": "moon"})
    assert len(nd3.meta) == 1
    assert nd3.meta["image"] == "moon"


def test_param_mask():
    # Since everything is allowed we only need to test something
    nd = NDData([1], mask=False)
    assert not nd.mask
    # Test conflicting mask (other NDData)
    nd2 = NDData(nd, mask=True)
    assert nd2.mask
    # (masked array)
    nd3 = NDData(np.ma.array([1], mask=False), mask=True)
    assert nd3.mask
    # (masked quantity)
    mq = np.ma.array(np.array([2, 3]) * u.m, mask=False)
    nd4 = NDData(mq, mask=True)
    assert nd4.mask


def test_param_unit():
    """Invalid units raise; an explicit unit overrides the unit of the data."""
    with pytest.raises(ValueError):
        NDData(np.ones((5, 5)), unit="NotAValidUnit")
    NDData([1, 2, 3], unit="meter")
    # Test conflicting units (quantity as data)
    q = np.array([1, 2, 3]) * u.m
    nd = NDData(q, unit="cm")
    assert nd.unit != q.unit
    assert nd.unit == u.cm
    # (masked quantity)
    mq = np.ma.array(np.array([2, 3]) * u.m, mask=False)
    nd2 = NDData(mq, unit=u.pc)
    assert nd2.unit == u.pc
    # (another NDData as data)
    nd3 = NDData(nd, unit="km")
    assert nd3.unit == u.km
    # (MaskedQuantity given to NDData)
    mq_astropy = Masked.from_unmasked(q, False)
    nd4 = NDData(mq_astropy, unit="km")
    assert nd4.unit == u.km


def test_pickle_nddata_with_uncertainty():
    """Pickling round-trips the uncertainty and re-links it to its parent."""
    # The uncertainty has the same shape as the data it describes.
    ndd = NDData(
        np.ones(3), uncertainty=StdDevUncertainty(np.ones(3), unit=u.m), unit=u.m
    )
    ndd_dumped = pickle.dumps(ndd)
    ndd_restored = pickle.loads(ndd_dumped)
    assert type(ndd_restored.uncertainty) is StdDevUncertainty
    assert ndd_restored.uncertainty.parent_nddata is ndd_restored
    assert ndd_restored.uncertainty.unit == u.m


def test_pickle_uncertainty_only():
    """An uncertainty pickled on its own loses its (weakly referenced) parent."""
    ndd = NDData(
        np.ones(3), uncertainty=StdDevUncertainty(np.ones(3), unit=u.m), unit=u.m
    )
    uncertainty_dumped = pickle.dumps(ndd.uncertainty)
    uncertainty_restored = pickle.loads(uncertainty_dumped)
    np.testing.assert_array_equal(ndd.uncertainty.array, uncertainty_restored.array)
    assert ndd.uncertainty.unit == uncertainty_restored.unit
    # Even though it has a parent there is no one that references the parent
    # after unpickling so the weakref "dies" immediately after unpickling
    # finishes.
    assert uncertainty_restored.parent_nddata is None


def test_pickle_nddata_without_uncertainty():
    ndd = NDData(np.ones(3), unit=u.m)
    dumped = pickle.dumps(ndd)
    ndd_restored = pickle.loads(dumped)
    np.testing.assert_array_equal(ndd.data, ndd_restored.data)


# Check that the meta descriptor is working as expected. The MetaBaseTest class
# takes care of defining all the tests, and we simply have to define the class
# and any minimal set of args to pass.
class TestMetaNDData(MetaBaseTest):
    test_class = NDData
    args = np.array([[1.0]])


# Representation tests
def test_nddata_str():
    """``str`` is the numpy ``str`` of the data, followed by the unit if any."""
    arr1d = NDData(np.array([1, 2, 3]))
    assert str(arr1d) == "[1 2 3]"

    arr2d = NDData(np.array([[1, 2], [3, 4]]))
    assert str(arr2d) == textwrap.dedent(
        """
        [[1 2]
         [3 4]]"""[1:]
    )

    arr3d = NDData(np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))
    assert str(arr3d) == textwrap.dedent(
        """
        [[[1 2]
          [3 4]]

         [[5 6]
          [7 8]]]"""[1:]
    )

    # let's add units!
    arr = NDData(np.array([1, 2, 3]), unit="km")
    assert str(arr) == "[1 2 3] km"

    # what if it had these units?
    arr = NDData(np.array([1, 2, 3]), unit="erg cm^-2 s^-1 A^-1")
    assert str(arr) == "[1 2 3] erg / (A s cm2)"


def test_nddata_repr():
    """``repr`` is prefix-aligned like numpy's and round-trips through ``eval``."""
    # The big test is eval(repr()) should be equal to the original!
    # but this must be modified slightly since adopting the
    # repr machinery from astropy.utils.masked
    arr1d = NDData(np.array([1, 2, 3]))
    s = repr(arr1d)
    assert s == "NDData([1, 2, 3])"
    got = eval(s)
    assert np.all(got.data == arr1d.data)
    assert got.unit == arr1d.unit

    # Continuation rows are aligned under the first row, i.e. indented by
    # len("NDData(") plus one space per nesting level.
    arr2d = NDData(np.array([[1, 2], [3, 4]]))
    s = repr(arr2d)
    assert s == ("NDData([[1, 2],\n        [3, 4]])")
    got = eval(s)
    assert np.all(got.data == arr2d.data)
    assert got.unit == arr2d.unit

    arr3d = NDData(np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))
    s = repr(arr3d)
    assert s == (
        "NDData([[[1, 2],\n         [3, 4]],\n\n        [[5, 6],\n         [7, 8]]])"
    )
    got = eval(s)
    assert np.all(got.data == arr3d.data)
    assert got.unit == arr3d.unit

    # let's add units!
    arr = NDData(np.array([1, 2, 3]), unit="km")
    s = repr(arr)
    assert s == "NDData([1, 2, 3], unit='km')"
    got = eval(s)
    assert np.all(got.data == arr.data)
    assert got.unit == arr.unit


@pytest.mark.skipif(not HAS_DASK, reason="requires dask to be available")
def test_nddata_repr_dask():
    """Non-numpy data get a multi-line ``data=..., unit=...`` repr."""
    import dask.array as da

    arr = NDData(da.arange(3), unit="km")
    s = repr(arr)
    # just check the repr structure for dask arrays, not round-tripping. The
    # dask summary (``dask.array<...>``) embeds a platform-dependent integer
    # dtype, so compare against the array's own repr instead of hard-coding it.
    assert s.startswith("NDData(\n")
    assert f"data={arr.data!r}," in s
    assert 'unit=Unit("km")' in s
    assert s.endswith("\n)")


# Not supported features
def test_slicing_not_supported():
    ndd = NDData(np.ones((5, 5)))
    with pytest.raises(TypeError):
        ndd[0]


def test_arithmetic_not_supported():
    ndd = NDData(np.ones((5, 5)))
    with pytest.raises(TypeError):
        ndd + ndd


def test_nddata_wcs_setter_error_cases():
    """The wcs setter rejects non-WCS values and refuses to replace a set WCS."""
    ndd = NDData(np.ones((5, 5)))

    # Setting with a non-WCS should raise an error
    with pytest.raises(TypeError):
        ndd.wcs = "I am not a WCS"

    naxis = 2
    wcs_kwargs = {
        "naxis": naxis,
        "ctype": ["deg"] * naxis,
        "crpix": [0] * naxis,
        "crval": [10] * naxis,
        "cdelt": [1] * naxis,
    }
    # This should succeed since the WCS is currently None
    ndd.wcs = nd_testing._create_wcs_simple(**wcs_kwargs)
    with pytest.raises(ValueError):
        # This should fail since the WCS is not None
        ndd.wcs = nd_testing._create_wcs_simple(**wcs_kwargs)


def test_nddata_wcs_setter_with_low_level_wcs():
    ndd = NDData(np.ones((5, 5)))
    wcs = WCS()
    # If the wcs property is set with a low level WCS it should get
    # wrapped to high level.
    low_level = SlicedLowLevelWCS(wcs, 5)
    assert not isinstance(low_level, BaseHighLevelWCS)

    ndd.wcs = low_level

    assert isinstance(ndd.wcs, BaseHighLevelWCS)


def test_nddata_init_with_low_level_wcs():
    wcs = WCS()
    low_level = SlicedLowLevelWCS(wcs, 5)
    ndd = NDData(np.ones((5, 5)), wcs=low_level)
    assert isinstance(ndd.wcs, BaseHighLevelWCS)


class NDDataCustomWCS(NDData):
    """NDData subclass whose ``wcs`` is a read-only property (no setter)."""

    @property
    def wcs(self):
        return WCS()


def test_overriden_wcs():
    # Check that a sub-class that overrides `.wcs` without providing a setter
    # works
    NDDataCustomWCS(np.ones((5, 5)))


# set up parameters for test_collapse:
collapse_units = [None, u.Jy]
collapse_propagate = [True, False]
collapse_data_shapes = [
    # 3D example:
    (4, 3, 2),
    # 5D example
    (6, 5, 4, 3, 2),
]
collapse_ignore_masked = [True, False]

# Seed inside a context manager so generating the masks does not reseed the
# global numpy RNG as an import-time side effect (same masks as seed(42)).
with NumpyRNGContext(42):
    collapse_masks = list(
        chain.from_iterable(
            [
                # try the operations without a mask (all False):
                np.zeros(collapse_data_shape).astype(bool)
            ]
            + [
                # assemble a bunch of random masks:
                np.random.randint(0, 2, size=collapse_data_shape).astype(bool)
                for _ in range(10)
            ]
            for collapse_data_shape in collapse_data_shapes
        )
    )

# the following provides pytest.mark.parametrize with every combination of
# (1) the masks (which cover the data shapes of different ndim), (2) the
# units, (3) propagating/not propagating uncertainties, and (4) whether the
# operation ignores the mask. (Zipping equal-period repeated lists would only
# ever pair the units/propagate/ignore options in lockstep.)
collapse_parameters = list(
    product(collapse_masks, collapse_units, collapse_propagate, collapse_ignore_masked)
)


@pytest.mark.parametrize(
    "mask, unit, propagate_uncertainties, operation_ignores_mask",
    collapse_parameters,
)
def test_collapse(mask, unit, propagate_uncertainties, operation_ignores_mask):
    """NDDataArray sum/mean/min/max agree with numpy and astropy masked arrays."""
    # every ordered pair of distinct axes, e.g. (0, 1) and (1, 0):
    axes_permutations = {tuple(axes[:2]) for axes in permutations(range(mask.ndim))}
    # each of the single axis slices:
    axes_permutations.update(set(range(mask.ndim)))
    # and collapsing over all axes at once:
    axes_permutations.update({None})

    cube = np.arange(np.prod(mask.shape)).reshape(mask.shape)
    numpy_cube = np.ma.masked_array(cube, mask=mask)
    ma_cube = Masked(cube, mask=mask)
    ndarr = NDDataArray(cube, uncertainty=StdDevUncertainty(cube), unit=unit, mask=mask)

    # By construction, the minimum value along each axis is always the zeroth index and
    # the maximum is always the last along that axis. We verify that here, so we can
    # test that the correct uncertainties are extracted during the
    # `NDDataArray.min` and `NDDataArray.max` methods later:
    for axis in range(cube.ndim):
        assert np.all(np.equal(cube.argmin(axis=axis), 0))
        assert np.all(np.equal(cube.argmax(axis=axis), cube.shape[axis] - 1))

    # confirm that supported nddata methods agree with corresponding numpy methods
    # for the masked data array:
    sum_methods = ["sum", "mean"]
    ext_methods = ["min", "max"]
    all_methods = sum_methods + ext_methods

    # for all supported methods, ensure the masking is propagated:
    for method in all_methods:
        for axes in axes_permutations:
            astropy_method = getattr(ma_cube, method)(axis=axes)
            numpy_method = getattr(numpy_cube, method)(axis=axes)
            nddata_method = getattr(ndarr, method)(
                axis=axes,
                propagate_uncertainties=propagate_uncertainties,
                operation_ignores_mask=operation_ignores_mask,
            )
            astropy_unmasked = astropy_method.base[~astropy_method.mask]
            nddata_unmasked = nddata_method.data[~nddata_method.mask]

            # check if the units are passed through correctly:
            assert unit == nddata_method.unit

            # check if the numpy and astropy.utils.masked results agree when
            # the result is not fully masked:
            if len(astropy_unmasked) > 0:
                if not operation_ignores_mask:
                    # compare with astropy
                    assert np.all(np.equal(astropy_unmasked, nddata_unmasked))
                    assert np.all(np.equal(astropy_method.mask, nddata_method.mask))
                else:
                    # compare with numpy
                    assert np.ma.all(
                        np.ma.equal(numpy_method, np.asanyarray(nddata_method))
                    )

            # For extremum methods, ensure the uncertainty returned corresponds to the
            # min/max data value. We've created the uncertainties to have the same value
            # as the data array, so we can just check for equality:
            if method in ext_methods and propagate_uncertainties:
                assert np.ma.all(np.ma.equal(astropy_method, nddata_method))