Chris@87: """Miscellaneous functions for testing masked arrays and subclasses Chris@87: Chris@87: :author: Pierre Gerard-Marchant Chris@87: :contact: pierregm_at_uga_dot_edu Chris@87: :version: $Id: testutils.py 3529 2007-11-13 08:01:14Z jarrod.millman $ Chris@87: Chris@87: """ Chris@87: from __future__ import division, absolute_import, print_function Chris@87: Chris@87: __author__ = "Pierre GF Gerard-Marchant ($Author: jarrod.millman $)" Chris@87: __version__ = "1.0" Chris@87: __revision__ = "$Revision: 3529 $" Chris@87: __date__ = "$Date: 2007-11-13 10:01:14 +0200 (Tue, 13 Nov 2007) $" Chris@87: Chris@87: Chris@87: import operator Chris@87: Chris@87: import numpy as np Chris@87: from numpy import ndarray, float_ Chris@87: import numpy.core.umath as umath Chris@87: from numpy.testing import * Chris@87: import numpy.testing.utils as utils Chris@87: Chris@87: from .core import mask_or, getmask, masked_array, nomask, masked, filled, \ Chris@87: equal, less Chris@87: Chris@87: #------------------------------------------------------------------------------ Chris@87: def approx (a, b, fill_value=True, rtol=1e-5, atol=1e-8): Chris@87: """Returns true if all components of a and b are equal subject to given tolerances. Chris@87: Chris@87: If fill_value is True, masked values considered equal. Otherwise, masked values Chris@87: are considered unequal. Chris@87: The relative error rtol should be positive and << 1.0 Chris@87: The absolute error atol comes into play for those elements of b that are very Chris@87: small or zero; it says how small a must be also. Chris@87: """ Chris@87: m = mask_or(getmask(a), getmask(b)) Chris@87: d1 = filled(a) Chris@87: d2 = filled(b) Chris@87: if d1.dtype.char == "O" or d2.dtype.char == "O": Chris@87: return np.equal(d1, d2).ravel() Chris@87: x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_) Chris@87: y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_) Chris@87: d = np.less_equal(umath.absolute(x - y), atol + rtol * umath.absolute(y)) Chris@87: return d.ravel() Chris@87: Chris@87: Chris@87: def almost(a, b, decimal=6, fill_value=True): Chris@87: """Returns True if a and b are equal up to decimal places. Chris@87: If fill_value is True, masked values considered equal. Otherwise, masked values Chris@87: are considered unequal. Chris@87: """ Chris@87: m = mask_or(getmask(a), getmask(b)) Chris@87: d1 = filled(a) Chris@87: d2 = filled(b) Chris@87: if d1.dtype.char == "O" or d2.dtype.char == "O": Chris@87: return np.equal(d1, d2).ravel() Chris@87: x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_) Chris@87: y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_) Chris@87: d = np.around(np.abs(x - y), decimal) <= 10.0 ** (-decimal) Chris@87: return d.ravel() Chris@87: Chris@87: Chris@87: #................................................ Chris@87: def _assert_equal_on_sequences(actual, desired, err_msg=''): Chris@87: "Asserts the equality of two non-array sequences." Chris@87: assert_equal(len(actual), len(desired), err_msg) Chris@87: for k in range(len(desired)): Chris@87: assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg)) Chris@87: return Chris@87: Chris@87: def assert_equal_records(a, b): Chris@87: """Asserts that two records are equal. Pretty crude for now.""" Chris@87: assert_equal(a.dtype, b.dtype) Chris@87: for f in a.dtype.names: Chris@87: (af, bf) = (operator.getitem(a, f), operator.getitem(b, f)) Chris@87: if not (af is masked) and not (bf is masked): Chris@87: assert_equal(operator.getitem(a, f), operator.getitem(b, f)) Chris@87: return Chris@87: Chris@87: Chris@87: def assert_equal(actual, desired, err_msg=''): Chris@87: "Asserts that two items are equal." Chris@87: # Case #1: dictionary ..... Chris@87: if isinstance(desired, dict): Chris@87: if not isinstance(actual, dict): Chris@87: raise AssertionError(repr(type(actual))) Chris@87: assert_equal(len(actual), len(desired), err_msg) Chris@87: for k, i in desired.items(): Chris@87: if not k in actual: Chris@87: raise AssertionError("%s not in %s" % (k, actual)) Chris@87: assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg)) Chris@87: return Chris@87: # Case #2: lists ..... Chris@87: if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)): Chris@87: return _assert_equal_on_sequences(actual, desired, err_msg='') Chris@87: if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)): Chris@87: msg = build_err_msg([actual, desired], err_msg,) Chris@87: if not desired == actual: Chris@87: raise AssertionError(msg) Chris@87: return Chris@87: # Case #4. arrays or equivalent Chris@87: if ((actual is masked) and not (desired is masked)) or \ Chris@87: ((desired is masked) and not (actual is masked)): Chris@87: msg = build_err_msg([actual, desired], Chris@87: err_msg, header='', names=('x', 'y')) Chris@87: raise ValueError(msg) Chris@87: actual = np.array(actual, copy=False, subok=True) Chris@87: desired = np.array(desired, copy=False, subok=True) Chris@87: (actual_dtype, desired_dtype) = (actual.dtype, desired.dtype) Chris@87: if actual_dtype.char == "S" and desired_dtype.char == "S": Chris@87: return _assert_equal_on_sequences(actual.tolist(), Chris@87: desired.tolist(), Chris@87: err_msg='') Chris@87: # elif actual_dtype.char in "OV" and desired_dtype.char in "OV": Chris@87: # if (actual_dtype != desired_dtype) and actual_dtype: Chris@87: # msg = build_err_msg([actual_dtype, desired_dtype], Chris@87: # err_msg, header='', names=('actual', 'desired')) Chris@87: # raise ValueError(msg) Chris@87: # return _assert_equal_on_sequences(actual.tolist(), Chris@87: # desired.tolist(), Chris@87: # err_msg='') Chris@87: return assert_array_equal(actual, desired, err_msg) Chris@87: Chris@87: Chris@87: def fail_if_equal(actual, desired, err_msg='',): Chris@87: """Raises an assertion error if two items are equal. Chris@87: """ Chris@87: if isinstance(desired, dict): Chris@87: if not isinstance(actual, dict): Chris@87: raise AssertionError(repr(type(actual))) Chris@87: fail_if_equal(len(actual), len(desired), err_msg) Chris@87: for k, i in desired.items(): Chris@87: if not k in actual: Chris@87: raise AssertionError(repr(k)) Chris@87: fail_if_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg)) Chris@87: return Chris@87: if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)): Chris@87: fail_if_equal(len(actual), len(desired), err_msg) Chris@87: for k in range(len(desired)): Chris@87: fail_if_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg)) Chris@87: return Chris@87: if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray): Chris@87: return fail_if_array_equal(actual, desired, err_msg) Chris@87: msg = build_err_msg([actual, desired], err_msg) Chris@87: if not desired != actual: Chris@87: raise AssertionError(msg) Chris@87: Chris@87: assert_not_equal = fail_if_equal Chris@87: Chris@87: Chris@87: def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True): Chris@87: """Asserts that two items are almost equal. Chris@87: The test is equivalent to abs(desired-actual) < 0.5 * 10**(-decimal) Chris@87: """ Chris@87: if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray): Chris@87: return assert_array_almost_equal(actual, desired, decimal=decimal, Chris@87: err_msg=err_msg, verbose=verbose) Chris@87: msg = build_err_msg([actual, desired], Chris@87: err_msg=err_msg, verbose=verbose) Chris@87: if not round(abs(desired - actual), decimal) == 0: Chris@87: raise AssertionError(msg) Chris@87: Chris@87: Chris@87: assert_close = assert_almost_equal Chris@87: Chris@87: Chris@87: def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='', Chris@87: fill_value=True): Chris@87: """Asserts that a comparison relation between two masked arrays is satisfied Chris@87: elementwise.""" Chris@87: # Fill the data first Chris@87: # xf = filled(x) Chris@87: # yf = filled(y) Chris@87: # Allocate a common mask and refill Chris@87: m = mask_or(getmask(x), getmask(y)) Chris@87: x = masked_array(x, copy=False, mask=m, keep_mask=False, subok=False) Chris@87: y = masked_array(y, copy=False, mask=m, keep_mask=False, subok=False) Chris@87: if ((x is masked) and not (y is masked)) or \ Chris@87: ((y is masked) and not (x is masked)): Chris@87: msg = build_err_msg([x, y], err_msg=err_msg, verbose=verbose, Chris@87: header=header, names=('x', 'y')) Chris@87: raise ValueError(msg) Chris@87: # OK, now run the basic tests on filled versions Chris@87: return utils.assert_array_compare(comparison, Chris@87: x.filled(fill_value), Chris@87: y.filled(fill_value), Chris@87: err_msg=err_msg, Chris@87: verbose=verbose, header=header) Chris@87: Chris@87: Chris@87: def assert_array_equal(x, y, err_msg='', verbose=True): Chris@87: """Checks the elementwise equality of two masked arrays.""" Chris@87: assert_array_compare(operator.__eq__, x, y, Chris@87: err_msg=err_msg, verbose=verbose, Chris@87: header='Arrays are not equal') Chris@87: Chris@87: Chris@87: def fail_if_array_equal(x, y, err_msg='', verbose=True): Chris@87: "Raises an assertion error if two masked arrays are not equal (elementwise)." Chris@87: def compare(x, y): Chris@87: return (not np.alltrue(approx(x, y))) Chris@87: assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, Chris@87: header='Arrays are not equal') Chris@87: Chris@87: Chris@87: def assert_array_approx_equal(x, y, decimal=6, err_msg='', verbose=True): Chris@87: """Checks the elementwise equality of two masked arrays, up to a given Chris@87: number of decimals.""" Chris@87: def compare(x, y): Chris@87: "Returns the result of the loose comparison between x and y)." Chris@87: return approx(x, y, rtol=10. ** -decimal) Chris@87: assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, Chris@87: header='Arrays are not almost equal') Chris@87: Chris@87: Chris@87: def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True): Chris@87: """Checks the elementwise equality of two masked arrays, up to a given Chris@87: number of decimals.""" Chris@87: def compare(x, y): Chris@87: "Returns the result of the loose comparison between x and y)." Chris@87: return almost(x, y, decimal) Chris@87: assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose, Chris@87: header='Arrays are not almost equal') Chris@87: Chris@87: Chris@87: def assert_array_less(x, y, err_msg='', verbose=True): Chris@87: "Checks that x is smaller than y elementwise." Chris@87: assert_array_compare(operator.__lt__, x, y, Chris@87: err_msg=err_msg, verbose=verbose, Chris@87: header='Arrays are not less-ordered') Chris@87: Chris@87: Chris@87: def assert_mask_equal(m1, m2, err_msg=''): Chris@87: """Asserts the equality of two masks.""" Chris@87: if m1 is nomask: Chris@87: assert_(m2 is nomask) Chris@87: if m2 is nomask: Chris@87: assert_(m1 is nomask) Chris@87: assert_array_equal(m1, m2, err_msg=err_msg)