# Licensed under a 3-clause BSD style license - see PYFITS.rst

import sys

import numpy as np

from astropy.io import fits

from .conftest import FitsTestCase


def _match_field_name(name, names):
    """
    Return the entry of *names* that corresponds to the field *name*.

    The exact spelling is tried first, then the lower-cased spelling, then
    that lower-cased spelling upper-cased.

    Raises
    ------
    ValueError
        If none of the candidate spellings is present in *names*.
    """
    lowered = name.lower()
    for candidate in (name, lowered, lowered.upper()):
        if candidate in names:
            return candidate
    raise ValueError(f"field name {name} not found in array 1")


def compare_arrays(arr1in, arr2in, verbose=False):
    """
    Compare the values field-by-field in two sets of numpy arrays or recarrays.

    Every field of ``arr2in`` is looked up in ``arr1in``. The lookup tries
    the exact name, then the all-lowercase name, then the all-uppercase name.
    For each pair of columns the shapes are compared first; only if they
    match are the elements compared.  Fields present only in ``arr1in`` are
    not checked.

    Parameters
    ----------
    arr1in, arr2in : numpy.ndarray (structured) or subclass
        The arrays to compare.
    verbose : bool, optional
        If True, write a per-field progress report to ``sys.stdout``.

    Returns
    -------
    bool
        True if no differences were found, False otherwise.

    Raises
    ------
    ValueError
        If a field of ``arr2in`` has no matching field in ``arr1in``.
    """
    # Strip any ndarray subclass so field access uses plain numpy indexing.
    arr1 = arr1in.view(np.ndarray)
    arr2 = arr2in.view(np.ndarray)

    nfail = 0
    for n2 in arr2.dtype.names:
        n1 = _match_field_name(n2, arr1.dtype.names)

        if verbose:
            sys.stdout.write(f" testing field: '{n2}'\n")
            sys.stdout.write(" shape...........")
        if arr2[n2].shape != arr1[n1].shape:
            nfail += 1
            if verbose:
                sys.stdout.write("shapes differ\n")
            # An element-wise comparison is meaningless for mismatched shapes.
            continue

        if verbose:
            sys.stdout.write("OK\n")
            sys.stdout.write(" elements........")
        (w,) = np.where(arr1[n1].ravel() != arr2[n2].ravel())
        if w.size > 0:
            nfail += 1
            if verbose:
                sys.stdout.write(f"\n {w.size} elements in field {n2} differ\n")
        elif verbose:
            sys.stdout.write("OK\n")

    if nfail == 0:
        if verbose:
            sys.stdout.write("All tests passed\n")
        return True

    if verbose:
        sys.stdout.write(f"{nfail} differences found\n")
    return False


def get_test_data(verbose=False):
    """
    Build a small structured array for FITS round-trip tests.

    Parameters
    ----------
    verbose : bool, optional
        Unused; kept for backward compatibility with existing callers.

    Returns
    -------
    numpy.ndarray
        A 3-row structured array with these fields:

        - ``f1`` (``i4``)
        - ``f2`` (``S6``)
        - ``f3`` (``>2f8``: two big-endian doubles per row, filled with
          reproducible pseudo-random values)
    """
    st = np.zeros(3, [("f1", "i4"), ("f2", "S6"), ("f3", ">2f8")])

    # A private seeded Generator keeps the data reproducible without
    # reseeding numpy's global RNG, which would leak into other tests.
    rng = np.random.default_rng(35)
    st["f1"] = [1, 3, 5]
    st["f2"] = ["hello", "world", "byebye"]
    st["f3"] = rng.random(st["f3"].shape)

    return st


class TestStructured(FitsTestCase):
    def test_structured(self):
        """
        Round-trip tables through a new FITS file and compare them.

        The tables are the first two extensions of ``stddata.fits`` plus a
        synthetic structured array from ``get_test_data``.
        """
        fname = self.data("stddata.fits")

        data1, h1 = fits.getdata(fname, ext=1, header=True)
        data2, h2 = fits.getdata(fname, ext=2, header=True)

        st = get_test_data()

        # Build the output file: data1 first, then data2 and st as
        # further HDUs.
        outfile = self.temp("test.fits")
        fits.writeto(outfile, data1, overwrite=True)
        fits.append(outfile, data2)

        fits.append(outfile, st)
        # Check the in-memory array after it has been appended.
        assert st.dtype.isnative
        assert np.all(st["f1"] == [1, 3, 5])

        data1check, h1check = fits.getdata(outfile, ext=1, header=True)
        data2check, h2check = fits.getdata(outfile, ext=2, header=True)
        stcheck, sthcheck = fits.getdata(outfile, ext=3, header=True)

        assert compare_arrays(data1, data1check, verbose=True)
        assert compare_arrays(data2, data2check, verbose=True)
        assert compare_arrays(st, stcheck, verbose=True)

        # Reading back as a plain ndarray view must give the same values.
        dataviewcheck, hviewcheck = fits.getdata(
            outfile, ext=2, header=True, view=np.ndarray
        )
        assert compare_arrays(data2, dataviewcheck, verbose=True)