Chris@87
|
1 """Miscellaneous functions for testing masked arrays and subclasses
|
Chris@87
|
2
|
Chris@87
|
3 :author: Pierre Gerard-Marchant
|
Chris@87
|
4 :contact: pierregm_at_uga_dot_edu
|
Chris@87
|
5 :version: $Id: testutils.py 3529 2007-11-13 08:01:14Z jarrod.millman $
|
Chris@87
|
6
|
Chris@87
|
7 """
|
Chris@87
|
8 from __future__ import division, absolute_import, print_function
|
Chris@87
|
9
|
Chris@87
|
10 __author__ = "Pierre GF Gerard-Marchant ($Author: jarrod.millman $)"
|
Chris@87
|
11 __version__ = "1.0"
|
Chris@87
|
12 __revision__ = "$Revision: 3529 $"
|
Chris@87
|
13 __date__ = "$Date: 2007-11-13 10:01:14 +0200 (Tue, 13 Nov 2007) $"
|
Chris@87
|
14
|
Chris@87
|
15
|
Chris@87
|
16 import operator
|
Chris@87
|
17
|
Chris@87
|
18 import numpy as np
|
Chris@87
|
19 from numpy import ndarray, float_
|
Chris@87
|
20 import numpy.core.umath as umath
|
Chris@87
|
21 from numpy.testing import *
|
Chris@87
|
22 import numpy.testing.utils as utils
|
Chris@87
|
23
|
Chris@87
|
24 from .core import mask_or, getmask, masked_array, nomask, masked, filled, \
|
Chris@87
|
25 equal, less
|
Chris@87
|
26
|
Chris@87
|
27 #------------------------------------------------------------------------------
|
Chris@87
|
28 def approx (a, b, fill_value=True, rtol=1e-5, atol=1e-8):
|
Chris@87
|
29 """Returns true if all components of a and b are equal subject to given tolerances.
|
Chris@87
|
30
|
Chris@87
|
31 If fill_value is True, masked values considered equal. Otherwise, masked values
|
Chris@87
|
32 are considered unequal.
|
Chris@87
|
33 The relative error rtol should be positive and << 1.0
|
Chris@87
|
34 The absolute error atol comes into play for those elements of b that are very
|
Chris@87
|
35 small or zero; it says how small a must be also.
|
Chris@87
|
36 """
|
Chris@87
|
37 m = mask_or(getmask(a), getmask(b))
|
Chris@87
|
38 d1 = filled(a)
|
Chris@87
|
39 d2 = filled(b)
|
Chris@87
|
40 if d1.dtype.char == "O" or d2.dtype.char == "O":
|
Chris@87
|
41 return np.equal(d1, d2).ravel()
|
Chris@87
|
42 x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_)
|
Chris@87
|
43 y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
|
Chris@87
|
44 d = np.less_equal(umath.absolute(x - y), atol + rtol * umath.absolute(y))
|
Chris@87
|
45 return d.ravel()
|
Chris@87
|
46
|
Chris@87
|
47
|
Chris@87
|
48 def almost(a, b, decimal=6, fill_value=True):
|
Chris@87
|
49 """Returns True if a and b are equal up to decimal places.
|
Chris@87
|
50 If fill_value is True, masked values considered equal. Otherwise, masked values
|
Chris@87
|
51 are considered unequal.
|
Chris@87
|
52 """
|
Chris@87
|
53 m = mask_or(getmask(a), getmask(b))
|
Chris@87
|
54 d1 = filled(a)
|
Chris@87
|
55 d2 = filled(b)
|
Chris@87
|
56 if d1.dtype.char == "O" or d2.dtype.char == "O":
|
Chris@87
|
57 return np.equal(d1, d2).ravel()
|
Chris@87
|
58 x = filled(masked_array(d1, copy=False, mask=m), fill_value).astype(float_)
|
Chris@87
|
59 y = filled(masked_array(d2, copy=False, mask=m), 1).astype(float_)
|
Chris@87
|
60 d = np.around(np.abs(x - y), decimal) <= 10.0 ** (-decimal)
|
Chris@87
|
61 return d.ravel()
|
Chris@87
|
62
|
Chris@87
|
63
|
Chris@87
|
64 #................................................
|
Chris@87
|
65 def _assert_equal_on_sequences(actual, desired, err_msg=''):
|
Chris@87
|
66 "Asserts the equality of two non-array sequences."
|
Chris@87
|
67 assert_equal(len(actual), len(desired), err_msg)
|
Chris@87
|
68 for k in range(len(desired)):
|
Chris@87
|
69 assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg))
|
Chris@87
|
70 return
|
Chris@87
|
71
|
Chris@87
|
72 def assert_equal_records(a, b):
|
Chris@87
|
73 """Asserts that two records are equal. Pretty crude for now."""
|
Chris@87
|
74 assert_equal(a.dtype, b.dtype)
|
Chris@87
|
75 for f in a.dtype.names:
|
Chris@87
|
76 (af, bf) = (operator.getitem(a, f), operator.getitem(b, f))
|
Chris@87
|
77 if not (af is masked) and not (bf is masked):
|
Chris@87
|
78 assert_equal(operator.getitem(a, f), operator.getitem(b, f))
|
Chris@87
|
79 return
|
Chris@87
|
80
|
Chris@87
|
81
|
Chris@87
|
82 def assert_equal(actual, desired, err_msg=''):
|
Chris@87
|
83 "Asserts that two items are equal."
|
Chris@87
|
84 # Case #1: dictionary .....
|
Chris@87
|
85 if isinstance(desired, dict):
|
Chris@87
|
86 if not isinstance(actual, dict):
|
Chris@87
|
87 raise AssertionError(repr(type(actual)))
|
Chris@87
|
88 assert_equal(len(actual), len(desired), err_msg)
|
Chris@87
|
89 for k, i in desired.items():
|
Chris@87
|
90 if not k in actual:
|
Chris@87
|
91 raise AssertionError("%s not in %s" % (k, actual))
|
Chris@87
|
92 assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg))
|
Chris@87
|
93 return
|
Chris@87
|
94 # Case #2: lists .....
|
Chris@87
|
95 if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
|
Chris@87
|
96 return _assert_equal_on_sequences(actual, desired, err_msg='')
|
Chris@87
|
97 if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)):
|
Chris@87
|
98 msg = build_err_msg([actual, desired], err_msg,)
|
Chris@87
|
99 if not desired == actual:
|
Chris@87
|
100 raise AssertionError(msg)
|
Chris@87
|
101 return
|
Chris@87
|
102 # Case #4. arrays or equivalent
|
Chris@87
|
103 if ((actual is masked) and not (desired is masked)) or \
|
Chris@87
|
104 ((desired is masked) and not (actual is masked)):
|
Chris@87
|
105 msg = build_err_msg([actual, desired],
|
Chris@87
|
106 err_msg, header='', names=('x', 'y'))
|
Chris@87
|
107 raise ValueError(msg)
|
Chris@87
|
108 actual = np.array(actual, copy=False, subok=True)
|
Chris@87
|
109 desired = np.array(desired, copy=False, subok=True)
|
Chris@87
|
110 (actual_dtype, desired_dtype) = (actual.dtype, desired.dtype)
|
Chris@87
|
111 if actual_dtype.char == "S" and desired_dtype.char == "S":
|
Chris@87
|
112 return _assert_equal_on_sequences(actual.tolist(),
|
Chris@87
|
113 desired.tolist(),
|
Chris@87
|
114 err_msg='')
|
Chris@87
|
115 # elif actual_dtype.char in "OV" and desired_dtype.char in "OV":
|
Chris@87
|
116 # if (actual_dtype != desired_dtype) and actual_dtype:
|
Chris@87
|
117 # msg = build_err_msg([actual_dtype, desired_dtype],
|
Chris@87
|
118 # err_msg, header='', names=('actual', 'desired'))
|
Chris@87
|
119 # raise ValueError(msg)
|
Chris@87
|
120 # return _assert_equal_on_sequences(actual.tolist(),
|
Chris@87
|
121 # desired.tolist(),
|
Chris@87
|
122 # err_msg='')
|
Chris@87
|
123 return assert_array_equal(actual, desired, err_msg)
|
Chris@87
|
124
|
Chris@87
|
125
|
Chris@87
|
126 def fail_if_equal(actual, desired, err_msg='',):
|
Chris@87
|
127 """Raises an assertion error if two items are equal.
|
Chris@87
|
128 """
|
Chris@87
|
129 if isinstance(desired, dict):
|
Chris@87
|
130 if not isinstance(actual, dict):
|
Chris@87
|
131 raise AssertionError(repr(type(actual)))
|
Chris@87
|
132 fail_if_equal(len(actual), len(desired), err_msg)
|
Chris@87
|
133 for k, i in desired.items():
|
Chris@87
|
134 if not k in actual:
|
Chris@87
|
135 raise AssertionError(repr(k))
|
Chris@87
|
136 fail_if_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg))
|
Chris@87
|
137 return
|
Chris@87
|
138 if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
|
Chris@87
|
139 fail_if_equal(len(actual), len(desired), err_msg)
|
Chris@87
|
140 for k in range(len(desired)):
|
Chris@87
|
141 fail_if_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg))
|
Chris@87
|
142 return
|
Chris@87
|
143 if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray):
|
Chris@87
|
144 return fail_if_array_equal(actual, desired, err_msg)
|
Chris@87
|
145 msg = build_err_msg([actual, desired], err_msg)
|
Chris@87
|
146 if not desired != actual:
|
Chris@87
|
147 raise AssertionError(msg)
|
Chris@87
|
148
|
Chris@87
|
149 assert_not_equal = fail_if_equal
|
Chris@87
|
150
|
Chris@87
|
151
|
Chris@87
|
152 def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
|
Chris@87
|
153 """Asserts that two items are almost equal.
|
Chris@87
|
154 The test is equivalent to abs(desired-actual) < 0.5 * 10**(-decimal)
|
Chris@87
|
155 """
|
Chris@87
|
156 if isinstance(actual, np.ndarray) or isinstance(desired, np.ndarray):
|
Chris@87
|
157 return assert_array_almost_equal(actual, desired, decimal=decimal,
|
Chris@87
|
158 err_msg=err_msg, verbose=verbose)
|
Chris@87
|
159 msg = build_err_msg([actual, desired],
|
Chris@87
|
160 err_msg=err_msg, verbose=verbose)
|
Chris@87
|
161 if not round(abs(desired - actual), decimal) == 0:
|
Chris@87
|
162 raise AssertionError(msg)
|
Chris@87
|
163
|
Chris@87
|
164
|
Chris@87
|
165 assert_close = assert_almost_equal
|
Chris@87
|
166
|
Chris@87
|
167
|
Chris@87
|
168 def assert_array_compare(comparison, x, y, err_msg='', verbose=True, header='',
|
Chris@87
|
169 fill_value=True):
|
Chris@87
|
170 """Asserts that a comparison relation between two masked arrays is satisfied
|
Chris@87
|
171 elementwise."""
|
Chris@87
|
172 # Fill the data first
|
Chris@87
|
173 # xf = filled(x)
|
Chris@87
|
174 # yf = filled(y)
|
Chris@87
|
175 # Allocate a common mask and refill
|
Chris@87
|
176 m = mask_or(getmask(x), getmask(y))
|
Chris@87
|
177 x = masked_array(x, copy=False, mask=m, keep_mask=False, subok=False)
|
Chris@87
|
178 y = masked_array(y, copy=False, mask=m, keep_mask=False, subok=False)
|
Chris@87
|
179 if ((x is masked) and not (y is masked)) or \
|
Chris@87
|
180 ((y is masked) and not (x is masked)):
|
Chris@87
|
181 msg = build_err_msg([x, y], err_msg=err_msg, verbose=verbose,
|
Chris@87
|
182 header=header, names=('x', 'y'))
|
Chris@87
|
183 raise ValueError(msg)
|
Chris@87
|
184 # OK, now run the basic tests on filled versions
|
Chris@87
|
185 return utils.assert_array_compare(comparison,
|
Chris@87
|
186 x.filled(fill_value),
|
Chris@87
|
187 y.filled(fill_value),
|
Chris@87
|
188 err_msg=err_msg,
|
Chris@87
|
189 verbose=verbose, header=header)
|
Chris@87
|
190
|
Chris@87
|
191
|
Chris@87
|
192 def assert_array_equal(x, y, err_msg='', verbose=True):
|
Chris@87
|
193 """Checks the elementwise equality of two masked arrays."""
|
Chris@87
|
194 assert_array_compare(operator.__eq__, x, y,
|
Chris@87
|
195 err_msg=err_msg, verbose=verbose,
|
Chris@87
|
196 header='Arrays are not equal')
|
Chris@87
|
197
|
Chris@87
|
198
|
Chris@87
|
199 def fail_if_array_equal(x, y, err_msg='', verbose=True):
|
Chris@87
|
200 "Raises an assertion error if two masked arrays are not equal (elementwise)."
|
Chris@87
|
201 def compare(x, y):
|
Chris@87
|
202 return (not np.alltrue(approx(x, y)))
|
Chris@87
|
203 assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
|
Chris@87
|
204 header='Arrays are not equal')
|
Chris@87
|
205
|
Chris@87
|
206
|
Chris@87
|
207 def assert_array_approx_equal(x, y, decimal=6, err_msg='', verbose=True):
|
Chris@87
|
208 """Checks the elementwise equality of two masked arrays, up to a given
|
Chris@87
|
209 number of decimals."""
|
Chris@87
|
210 def compare(x, y):
|
Chris@87
|
211 "Returns the result of the loose comparison between x and y)."
|
Chris@87
|
212 return approx(x, y, rtol=10. ** -decimal)
|
Chris@87
|
213 assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
|
Chris@87
|
214 header='Arrays are not almost equal')
|
Chris@87
|
215
|
Chris@87
|
216
|
Chris@87
|
217 def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
|
Chris@87
|
218 """Checks the elementwise equality of two masked arrays, up to a given
|
Chris@87
|
219 number of decimals."""
|
Chris@87
|
220 def compare(x, y):
|
Chris@87
|
221 "Returns the result of the loose comparison between x and y)."
|
Chris@87
|
222 return almost(x, y, decimal)
|
Chris@87
|
223 assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
|
Chris@87
|
224 header='Arrays are not almost equal')
|
Chris@87
|
225
|
Chris@87
|
226
|
Chris@87
|
227 def assert_array_less(x, y, err_msg='', verbose=True):
|
Chris@87
|
228 "Checks that x is smaller than y elementwise."
|
Chris@87
|
229 assert_array_compare(operator.__lt__, x, y,
|
Chris@87
|
230 err_msg=err_msg, verbose=verbose,
|
Chris@87
|
231 header='Arrays are not less-ordered')
|
Chris@87
|
232
|
Chris@87
|
233
|
Chris@87
|
234 def assert_mask_equal(m1, m2, err_msg=''):
|
Chris@87
|
235 """Asserts the equality of two masks."""
|
Chris@87
|
236 if m1 is nomask:
|
Chris@87
|
237 assert_(m2 is nomask)
|
Chris@87
|
238 if m2 is nomask:
|
Chris@87
|
239 assert_(m1 is nomask)
|
Chris@87
|
240 assert_array_equal(m1, m2, err_msg=err_msg)
|