Chris@87
|
1 """
|
Chris@87
|
2 Utility function to facilitate testing.
|
Chris@87
|
3
|
Chris@87
|
4 """
|
Chris@87
|
5 from __future__ import division, absolute_import, print_function
|
Chris@87
|
6
|
Chris@87
|
7 import os
|
Chris@87
|
8 import sys
|
Chris@87
|
9 import re
|
Chris@87
|
10 import operator
|
Chris@87
|
11 import warnings
|
Chris@87
|
12 from functools import partial
|
Chris@87
|
13 import shutil
|
Chris@87
|
14 import contextlib
|
Chris@87
|
15 from tempfile import mkdtemp
|
Chris@87
|
16 from .nosetester import import_nose
|
Chris@87
|
17 from numpy.core import float32, empty, arange, array_repr, ndarray
|
Chris@87
|
18
|
Chris@87
|
19 if sys.version_info[0] >= 3:
|
Chris@87
|
20 from io import StringIO
|
Chris@87
|
21 else:
|
Chris@87
|
22 from StringIO import StringIO
|
Chris@87
|
23
|
Chris@87
|
__all__ = [
    'assert_equal',
    'assert_almost_equal',
    'assert_approx_equal',
    'assert_array_equal',
    'assert_array_less',
    'assert_string_equal',
    'assert_array_almost_equal',
    'assert_raises',
    'build_err_msg',
    'decorate_methods',
    'jiffies',
    'memusage',
    'print_assert_equal',
    'raises',
    'rand',
    'rundocs',
    'runstring',
    'verbose',
    'measure',
    'assert_',
    'assert_array_almost_equal_nulp',
    'assert_raises_regex',
    'assert_array_max_ulp',
    'assert_warns',
    'assert_no_warnings',
    'assert_allclose',
    'IgnoreException',
]


# Module-level verbosity setting; it is part of the public API (see __all__).
verbose = 0
|
Chris@87
|
35
|
Chris@87
|
36
|
Chris@87
|
def assert_(val, msg=''):
    """
    Assert that works in release mode.

    The Python built-in ``assert`` statement is compiled away when the
    interpreter runs in optimized mode (the ``-O`` flag). This function is
    an ordinary call that raises ``AssertionError`` explicitly, so the check
    is always performed.

    Parameters
    ----------
    val : object
        Value whose truthiness is tested.
    msg : str or callable, optional
        Failure message. It is evaluated only on failure. If ``msg()`` can
        be called, its result is used as the message. If the call raises
        ``TypeError`` (e.g. because `msg` is a plain string), `msg` itself
        is used.

    Raises
    ------
    AssertionError
        If `val` is falsy.
    """
    if val:
        return
    try:
        message = msg()
    except TypeError:
        # Not callable (or the callable rejected the call): use it verbatim.
        message = msg
    raise AssertionError(message)
|
Chris@87
|
54
|
Chris@87
|
def gisnan(x):
    """
    Like `isnan`, but always raise an error if the type is not supported.

    Notes
    -----
    ``isnan`` and other ufuncs sometimes return the ``NotImplemented``
    object instead of raising an exception. This wrapper turns that
    sentinel into a ``TypeError``, so an exception is always raised for
    unsupported input.

    This should be removed once this problem is solved at the ufunc level.
    """
    from numpy.core import isnan
    result = isnan(x)
    if result is NotImplemented:
        raise TypeError("isnan not supported for this type")
    return result
|
Chris@87
|
71
|
Chris@87
|
def gisfinite(x):
    """
    Like `isfinite`, but always raise an error if the type is not supported.

    Floating-point "invalid" warnings are suppressed while evaluating.

    Notes
    -----
    ``isfinite`` and other ufuncs sometimes return the ``NotImplemented``
    object instead of raising an exception. This wrapper turns that
    sentinel into a ``TypeError``, so an exception is always raised for
    unsupported input.

    This should be removed once this problem is solved at the ufunc level.
    """
    from numpy.core import isfinite, errstate
    with errstate(invalid='ignore'):
        result = isfinite(x)
        if result is NotImplemented:
            raise TypeError("isfinite not supported for this type")
    return result
|
Chris@87
|
89
|
Chris@87
|
def gisinf(x):
    """
    Like `isinf`, but always raise an error if the type is not supported.

    Floating-point "invalid" warnings are suppressed while evaluating.

    Notes
    -----
    ``isinf`` and other ufuncs sometimes return the ``NotImplemented``
    object instead of raising an exception. This wrapper turns that
    sentinel into a ``TypeError``, so an exception is always raised for
    unsupported input.

    This should be removed once this problem is solved at the ufunc level.
    """
    from numpy.core import isinf, errstate
    with errstate(invalid='ignore'):
        result = isinf(x)
        if result is NotImplemented:
            raise TypeError("isinf not supported for this type")
    return result
|
Chris@87
|
107
|
Chris@87
|
def rand(*args):
    """
    Return an array of random numbers with the given shape.

    The values come from the standard library's ``random.random``, not
    from numpy's generator, which makes this helper convenient for testing
    purposes.

    Parameters
    ----------
    *args : ints
        Dimensions of the returned array.

    Returns
    -------
    out : ndarray of float64
        Array of shape `args` filled with values in ``[0.0, 1.0)``.
    """
    import random
    from numpy.core import zeros, float64
    out = zeros(args, float64)
    # Draw one sample per element, in flat (C) order.
    samples = [random.random() for _ in range(out.size)]
    for position, value in enumerate(samples):
        out.flat[position] = value
    return out
|
Chris@87
|
120
|
Chris@87
|
if sys.platform[:5] == 'linux':
    def _proc_stat_field(path, field):
        """
        Return field number `field` of the first line of a stat file, as int.

        Fields are numbered from 1, as in proc(5). Field 2 (``comm``) is the
        executable name in parentheses and may itself contain spaces. So when
        a ')' is present, the fields are counted from the text after the last
        ')', which begins at field 3. This keeps the indices stable whatever
        the process name is.

        Raises IOError/OSError if the file cannot be read, and
        ValueError/IndexError if the line cannot be parsed.
        """
        # ``with`` closes the file even if reading fails.
        with open(path, 'r') as f:
            line = f.readline()
        _, paren, tail = line.rpartition(')')
        if paren:
            fields, first = tail.split(), 3
        else:
            fields, first = line.split(), 1
        return int(fields[field - first])

    def jiffies(_proc_pid_stat='/proc/%s/stat' % (os.getpid()),
                _load_time=[]):
        """
        Return number of jiffies that this process has been scheduled in
        user mode.

        This is the ``utime`` field (field 14) of ``/proc/<pid>/stat``. It is
        expressed in clock ticks, which this module treats as 1/100ths of a
        second; see man 5 proc. If the stat file cannot be read or parsed,
        the result falls back to 1/100ths of a second of wall-clock time
        elapsed since the first call.
        """
        import time
        # The mutable default is deliberate: it remembers the time of the
        # first call across invocations, for use by the fallback below.
        if not _load_time:
            _load_time.append(time.time())
        try:
            return _proc_stat_field(_proc_pid_stat, 14)
        except (IOError, OSError, ValueError, IndexError):
            return int(100*(time.time()-_load_time[0]))

    def memusage(_proc_pid_stat='/proc/%s/stat' % (os.getpid())):
        """
        Return virtual memory size in bytes of the running python.

        This is the ``vsize`` field (field 23) of ``/proc/<pid>/stat``.
        Returns None if the stat file cannot be read or parsed.
        """
        try:
            return _proc_stat_field(_proc_pid_stat, 23)
        except (IOError, OSError, ValueError, IndexError):
            return None
else:
    # There is no /proc/<pid>/stat to read here, so jiffies falls back to
    # wall-clock time. That is safe but inaccurate, especially when the
    # process was suspended or sleeping.
    def jiffies(_load_time=[]):
        """
        Return number of jiffies (1/100ths of a second) that this process
        has been scheduled in user mode. [Emulation with time.time: this is
        wall-clock time elapsed since the first call.]
        """
        import time
        # The mutable default deliberately caches the time of the first call.
        if not _load_time:
            _load_time.append(time.time())
        return int(100*(time.time()-_load_time[0]))

    def memusage():
        """Return memory usage of running python. [Not implemented]"""
        raise NotImplementedError
|
Chris@87
|
161
|
Chris@87
|
# Compare version tuples: a string comparison of sys.version is fragile.
if os.name == 'nt' and sys.version_info[:2] > (2, 3):
    # Code "stolen" from enthought/debug/memusage.py
    def GetPerformanceAttributes(object, counter, instance=None,
                                 inum=-1, format=None, machine=None):
        """
        Return the formatted value of one Windows performance counter.

        A PDH query is opened for the counter path built from `machine`,
        `object`, `instance`, `inum` and `counter`. The query data are
        collected once and the formatted value is returned. The counter and
        the query are always released, even on error. `format` defaults to
        ``win32pdh.PDH_FMT_LONG``.
        """
        # NOTE: Many counters require 2 samples to give accurate results,
        # including "% Processor Time" (as by definition, at any instant, a
        # thread's CPU usage is either 0 or 100). To read counters like this,
        # you should copy this function, but keep the counter open, and call
        # CollectQueryData() each time you need to know.
        # See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp
        # My older explanation for this was that the "AddCounter" process forced
        # the CPU to 100%, but the above makes more sense :)
        import win32pdh
        if format is None:
            format = win32pdh.PDH_FMT_LONG
        path = win32pdh.MakeCounterPath(
            (machine, object, instance, None, inum, counter))
        hq = win32pdh.OpenQuery()
        try:
            hc = win32pdh.AddCounter(hq, path)
            try:
                win32pdh.CollectQueryData(hq)
                # The first element is the counter type; only the value is used.
                _, val = win32pdh.GetFormattedCounterValue(hc, format)
                return val
            finally:
                win32pdh.RemoveCounter(hc)
        finally:
            win32pdh.CloseQuery(hq)

    def memusage(processName="python", instance=0):
        """
        Return the "Virtual Bytes" performance counter of a process.

        The process is selected by `processName` and `instance` (instance 0
        by default), not by process id. So this is not necessarily the
        current interpreter when several processes share the name.
        """
        # from win32pdhutil, part of the win32all package
        import win32pdh
        return GetPerformanceAttributes("Process", "Virtual Bytes",
                                        processName, instance,
                                        win32pdh.PDH_FMT_LONG, None)
|
Chris@87
|
195
|
Chris@87
|
def build_err_msg(arrays, err_msg, header='Items are not equal:',
                  verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
    """
    Build the failure message used by the assertion helpers.

    Parameters
    ----------
    arrays : sequence
        Objects to display. The i-th object is labelled with ``names[i]``.
    err_msg : str
        Extra user message. A short single-line message is appended to the
        header line; any other message is placed on its own line.
    header : str, optional
        First line of the message.
    verbose : bool, optional
        If True, append a repr of each object, truncated to three lines.
    names : sequence of str, optional
        Labels for the objects in `arrays`.
    precision : int, optional
        Precision passed to `array_repr` for ndarray objects.

    Returns
    -------
    msg : str
        The assembled message, which starts with a newline.
    """
    msg = ['\n' + header]
    if err_msg:
        if '\n' not in err_msg and len(err_msg) < 79 - len(header):
            msg = [msg[0] + ' ' + err_msg]
        else:
            msg.append(err_msg)
    if verbose:
        for i, a in enumerate(arrays):
            if isinstance(a, ndarray):
                # precision argument is only needed if the objects are ndarrays
                r_func = partial(array_repr, precision=precision)
            else:
                r_func = repr
            try:
                r = r_func(a)
            except Exception:
                # A broken __repr__ must not hide the assertion being reported,
                # but KeyboardInterrupt/SystemExit are allowed to propagate.
                r = '[repr failed]'
            if r.count('\n') > 3:
                r = '\n'.join(r.splitlines()[:3])
                r += '...'
            msg.append(' %s: %s' % (names[i], r))
    return '\n'.join(msg)
|
Chris@87
|
222
|
Chris@87
|
def assert_equal(actual,desired,err_msg='',verbose=True):
    """
    Raises an AssertionError if two objects are not equal.

    Given two objects (scalars, lists, tuples, dictionaries or numpy arrays),
    check that all elements of these objects are equal. An exception is raised
    at the first conflicting values.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal.

    Notes
    -----
    Dictionaries are compared key by key, and lists/tuples element by
    element; both are compared recursively. If either argument is an
    ndarray, the comparison is delegated to `assert_array_equal`.
    Otherwise, a NaN compares equal to a NaN, and two zeros must also
    agree in sign (``0.0`` vs ``-0.0`` fails) whenever the numpy helpers
    used for that check accept the inputs.

    Examples
    --------
    >>> np.testing.assert_equal([4,5], [4,6])
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal:
    item=1
    <BLANKLINE>
     ACTUAL: 5
     DESIRED: 6

    """
    # Dicts: actual must be a dict of the same length in which every key of
    # desired is present with an equal value (compared recursively).
    if isinstance(desired, dict):
        if not isinstance(actual, dict) :
            raise AssertionError(repr(type(actual)))
        assert_equal(len(actual), len(desired), err_msg, verbose)
        # `i` is unused; the value is read again as desired[k] below.
        for k, i in desired.items():
            if k not in actual :
                raise AssertionError(repr(k))
            assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg), verbose)
        return
    # Lists/tuples (only when BOTH sides are list or tuple): same length,
    # then element-wise recursive comparison tagged with the item index.
    if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for k in range(len(desired)):
            assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg), verbose)
        return
    from numpy.core import ndarray, isscalar, signbit
    from numpy.lib import iscomplexobj, real, imag
    # Anything involving an ndarray is handled by the array comparison.
    if isinstance(actual, ndarray) or isinstance(desired, ndarray):
        return assert_array_equal(actual, desired, err_msg, verbose)
    msg = build_err_msg([actual, desired], err_msg, verbose=verbose)

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except ValueError:
        usecomplex = False

    if usecomplex:
        if iscomplexobj(actual):
            actualr = real(actual)
            actuali = imag(actual)
        else:
            actualr = actual
            actuali = 0
        if iscomplexobj(desired):
            desiredr = real(desired)
            desiredi = imag(desired)
        else:
            desiredr = desired
            desiredi = 0
        # The part comparisons use the default err_msg/verbose; on failure the
        # outer message (which includes the caller's err_msg) is raised.
        try:
            assert_equal(actualr, desiredr)
            assert_equal(actuali, desiredi)
        except AssertionError:
            raise AssertionError(msg)
        # No return here: execution continues into the checks below.

    # Inf/nan/negative zero handling
    try:
        # isscalar test to check cases such as [np.nan] != np.nan
        if isscalar(desired) != isscalar(actual):
            raise AssertionError(msg)

        # If one of desired/actual is not finite, handle it specially here:
        # check that both are nan if any is a nan, and test for equality
        # otherwise
        if not (gisfinite(desired) and gisfinite(actual)):
            isdesnan = gisnan(desired)
            isactnan = gisnan(actual)
            if isdesnan or isactnan:
                if not (isdesnan and isactnan):
                    raise AssertionError(msg)
            else:
                if not desired == actual:
                    raise AssertionError(msg)
            return
        # Both zero: the sign bits must match, so 0.0 and -0.0 are unequal.
        elif desired == 0 and actual == 0:
            if not signbit(desired) == signbit(actual):
                raise AssertionError(msg)
    # If TypeError or ValueError raised while using isnan and co, just handle
    # as before (AssertionError is not caught here and propagates).
    except (TypeError, ValueError, NotImplementedError):
        pass

    # Explicitly use __eq__ for comparison, ticket #2552
    if not (desired == actual):
        raise AssertionError(msg)
|
Chris@87
|
335
|
Chris@87
|
def print_assert_equal(test_string, actual, desired):
    """
    Test if two objects are equal, and raise a readable error if not.

    The test is performed with ``actual == desired``. On failure, the
    message contains `test_string` followed by pretty-printed (`pprint`)
    representations of both objects.

    Parameters
    ----------
    test_string : str
        The message supplied to AssertionError.
    actual : object
        The object to test for equality against `desired`.
    desired : object
        The expected result.

    Raises
    ------
    AssertionError
        If ``actual == desired`` is falsy.

    Examples
    --------
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1])
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2])
    Traceback (most recent call last):
    ...
    AssertionError: Test XYZ of func xyz failed
    ACTUAL:
    [0, 1]
    DESIRED:
    [0, 2]

    """
    import pprint

    if actual == desired:
        return
    header = test_string + ' failed\nACTUAL: \n'
    # pprint.pprint(obj, stream) writes exactly pformat(obj) plus a newline.
    report = ''.join([header,
                      pprint.pformat(actual), '\n',
                      'DESIRED: \n',
                      pprint.pformat(desired), '\n'])
    raise AssertionError(report)
|
Chris@87
|
374
|
Chris@87
|
def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
    """
    Raises an AssertionError if two items are not equal up to desired
    precision.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    For scalars the test is ``round(abs(desired - actual), decimal) == 0``,
    which is approximately ``abs(desired-actual) < 0.5 * 10**(-decimal)``.

    Given two objects (numbers or ndarrays), check that all elements of these
    objects are almost equal. An exception is raised at conflicting values.
    If either object is an ndarray, tuple or list, this delegates to
    `assert_array_almost_equal`. Complex values are compared by real and
    imaginary parts separately. NaNs compare equal to NaNs, and infinities
    must match exactly.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    decimal : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> import numpy.testing as npt
    >>> npt.assert_almost_equal(2.3333333333333, 2.33333334)
    >>> npt.assert_almost_equal(2.3333333333333, 2.33333334, decimal=10)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 10 decimals
     ACTUAL: 2.3333333333333
     DESIRED: 2.33333334

    >>> npt.assert_almost_equal(np.array([1.0,2.3333333333333]),
    ...                         np.array([1.0,2.33333334]), decimal=9)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal
    <BLANKLINE>
    (mismatch 50.0%)
     x: array([ 1.        ,  2.33333333])
     y: array([ 1.        ,  2.33333334])

    """
    from numpy.core import ndarray
    from numpy.lib import iscomplexobj, real, imag

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except ValueError:
        usecomplex = False

    def _build_err_msg():
        # Built lazily: only needed (and only paid for) on failure.
        header = ('Arrays are not almost equal to %d decimals' % decimal)
        return build_err_msg([actual, desired], err_msg, verbose=verbose,
                             header=header)

    if usecomplex:
        if iscomplexobj(actual):
            actualr = real(actual)
            actuali = imag(actual)
        else:
            actualr = actual
            actuali = 0
        if iscomplexobj(desired):
            desiredr = real(desired)
            desiredi = imag(desired)
        else:
            desiredr = desired
            desiredi = 0
        try:
            assert_almost_equal(actualr, desiredr, decimal=decimal)
            assert_almost_equal(actuali, desiredi, decimal=decimal)
        except AssertionError:
            raise AssertionError(_build_err_msg())

    if isinstance(actual, (ndarray, tuple, list)) \
            or isinstance(desired, (ndarray, tuple, list)):
        # Forward `verbose` too, so verbose=False is honoured for arrays.
        return assert_array_almost_equal(actual, desired, decimal, err_msg,
                                         verbose)
    try:
        # If one of desired/actual is not finite, handle it specially here:
        # check that both are nan if any is a nan, and test for equality
        # otherwise
        if not (gisfinite(desired) and gisfinite(actual)):
            if gisnan(desired) or gisnan(actual):
                if not (gisnan(desired) and gisnan(actual)):
                    raise AssertionError(_build_err_msg())
            else:
                if not desired == actual:
                    raise AssertionError(_build_err_msg())
            return
    except (NotImplementedError, TypeError):
        # Types the nan/inf helpers do not support fall through to the
        # plain numeric test below.
        pass
    if round(abs(desired - actual), decimal) != 0:
        raise AssertionError(_build_err_msg())
|
Chris@87
|
491
|
Chris@87
|
492
|
Chris@87
|
def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
    """
    Raises an AssertionError if two items are not equal up to significant
    digits.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    Given two numbers, check that they are approximately equal.
    Approximately equal is defined as the number of significant digits
    that agree.

    Both arguments are converted with ``float``. They are divided by a power
    of ten derived from their mean magnitude. The check then fails if the
    scaled values differ by ``10**-(significant-1)`` or more. NaNs compare
    equal to NaNs, and infinities must match exactly.

    Parameters
    ----------
    actual : scalar
        The object to check.
    desired : scalar
        The expected object.
    significant : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20,
                                   significant=8)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20,
                                   significant=8)
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal to 8 significant digits:
     ACTUAL: 1.234567e-21
     DESIRED: 1.2345672e-21

    the evaluated condition that raises the exception is

    >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1)
    True

    """
    import numpy as np

    (actual, desired) = map(float, (actual, desired))
    # Exact equality (including inf == inf) needs no scaling.
    if desired==actual:
        return
    # Normalize the numbers to be in range (-10.0,10.0): scale is the power
    # of ten at or below the mean magnitude of the two values.
    with np.errstate(invalid='ignore'):
        scale = 0.5*(np.abs(desired) + np.abs(actual))
        scale = np.power(10, np.floor(np.log10(scale)))
    # NOTE(review): `scale` is a numpy float64, so dividing by a zero scale
    # likely yields inf/nan (with a RuntimeWarning) rather than raising
    # ZeroDivisionError; these handlers may be unreachable -- confirm.
    try:
        sc_desired = desired/scale
    except ZeroDivisionError:
        sc_desired = 0.0
    try:
        sc_actual = actual/scale
    except ZeroDivisionError:
        sc_actual = 0.0
    msg = build_err_msg([actual, desired], err_msg,
                header='Items are not equal to %d significant digits:' %
                                 significant,
                verbose=verbose)
    try:
        # If one of desired/actual is not finite, handle it specially here:
        # check that both are nan if any is a nan, and test for equality
        # otherwise
        if not (gisfinite(desired) and gisfinite(actual)):
            if gisnan(desired) or gisnan(actual):
                if not (gisnan(desired) and gisnan(actual)):
                    raise AssertionError(msg)
            else:
                if not desired == actual:
                    raise AssertionError(msg)
            return
    except (TypeError, NotImplementedError):
        pass
    # Scaled values must agree to within one unit in the last significant digit.
    if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)) :
        raise AssertionError(msg)
|
Chris@87
|
588
|
Chris@87
|
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
                         header='', precision=6):
    """
    Raise AssertionError unless ``comparison(x, y)`` holds for all elements.

    `x` and `y` are converted to arrays (subclasses preserved, no copy). Their
    shapes must be equal unless one of them is 0-d. For numeric dtypes,
    NaN, +inf and -inf must occur at the same positions in both. Those
    positions are then excluded before `comparison` is applied. A
    ValueError raised during the check is re-raised as a ValueError whose
    message includes the formatted traceback.

    Parameters
    ----------
    comparison : callable
        Element-wise comparison, e.g. ``operator.__eq__``.
    x, y : array_like
        Objects to compare.
    err_msg : str, optional
        Extra text for the failure message.
    verbose : bool, optional
        If True, include the reprs of `x` and `y` in the message.
    header : str, optional
        First line of the failure message.
    precision : int, optional
        Precision used when printing the arrays.
    """
    from numpy.core import array, isnan, isinf, any, all, inf
    x = array(x, copy=False, subok=True)
    y = array(y, copy=False, subok=True)

    def isnumber(x):
        # Boolean, integer, float and complex dtype codes.
        return x.dtype.char in '?bhilqpBHILQPefdgFDG'

    def chk_same_position(x_id, y_id, hasval='nan'):
        """Handling nan/inf: check that x and y have the nan/inf at the same
        locations."""
        try:
            assert_array_equal(x_id, y_id)
        except AssertionError:
            msg = build_err_msg([x, y],
                                err_msg + '\nx and y %s location mismatch:' \
                                % (hasval), verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)

    try:
        # A 0-d operand broadcasts against anything; otherwise shapes must match.
        cond = (x.shape==() or y.shape==()) or x.shape == y.shape
        if not cond:
            msg = build_err_msg([x, y],
                                err_msg
                                + '\n(shapes %s, %s mismatch)' % (x.shape,
                                                                 y.shape),
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            if not cond :
                raise AssertionError(msg)

        if isnumber(x) and isnumber(y):
            x_isnan, y_isnan = isnan(x), isnan(y)
            x_isinf, y_isinf = isinf(x), isinf(y)

            # Validate that the special values are in the same place
            if any(x_isnan) or any(y_isnan):
                chk_same_position(x_isnan, y_isnan, hasval='nan')
            if any(x_isinf) or any(y_isinf):
                # Check +inf and -inf separately, since they are different
                chk_same_position(x == +inf, y == +inf, hasval='+inf')
                chk_same_position(x == -inf, y == -inf, hasval='-inf')

            # Combine all the special values. Note that x_id/y_id ARE the
            # x_isnan/y_isnan arrays, so `|=` updates those in place; they are
            # not used again afterwards.
            x_id, y_id = x_isnan, y_isnan
            x_id |= x_isinf
            y_id |= y_isinf

            # Only do the comparison if actual values are left
            if all(x_id):
                return

            # Compare only the ordinary (non-nan/inf) entries.
            if any(x_id):
                val = comparison(x[~x_id], y[~y_id])
            else:
                val = comparison(x, y)
        else:
            val = comparison(x, y)

        # `comparison` may return a plain Python bool instead of an array.
        if isinstance(val, bool):
            cond = val
            reduced = [0]
        else:
            reduced = val.ravel()
            cond = reduced.all()
            reduced = reduced.tolist()
        if not cond:
            # Despite its name, `match` is the percentage of MISmatching
            # elements: 100 minus the percentage of True results.
            match = 100-100.0*reduced.count(1)/len(reduced)
            msg = build_err_msg([x, y],
                                err_msg
                                + '\n(mismatch %s%%)' % (match,),
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            if not cond :
                raise AssertionError(msg)
    except ValueError as e:
        # Report the underlying error together with the operands.
        import traceback
        efmt = traceback.format_exc()
        header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header)

        msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
                            names=('x', 'y'), precision=precision)
        raise ValueError(msg)
|
Chris@87
|
674
|
Chris@87
|
def assert_array_equal(x, y, err_msg='', verbose=True):
    """
    Raises an AssertionError if two array_like objects are not equal.

    Given two array_like objects, check that the shape is equal and all
    elements of these objects are equal. An exception is raised at
    shape mismatch or conflicting values. In contrast to the standard usage
    in numpy, NaNs are compared like numbers: no assertion is raised if
    both objects have NaNs in the same positions.

    The usual caution for verifying equality with floating point numbers is
    advised.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired objects are not equal.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    The first assert does not raise an exception:

    >>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
    ...                               [np.exp(0),2.33333, np.nan])

    Assert fails with numerical imprecision with floats:

    >>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
    ...                               [1, np.sqrt(np.pi)**2, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    (mismatch 50.0%)
     x: array([ 1.        ,  3.14159265,         NaN])
     y: array([ 1.        ,  3.14159265,         NaN])

    Use `assert_allclose` or one of the nulp (number of floating point values)
    functions for these cases instead:

    >>> np.testing.assert_allclose([1.0,np.pi,np.nan],
    ...                            [1, np.sqrt(np.pi)**2, np.nan],
    ...                            rtol=1e-10, atol=0)

    """
    # operator.eq is the very same function object as operator.__eq__.
    same = operator.eq
    assert_array_compare(same, x, y,
                         err_msg=err_msg,
                         verbose=verbose,
                         header='Arrays are not equal')
|
Chris@87
|
740
|
Chris@87
|
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal up to desired
    precision.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    The test verifies identical shapes and verifies values with
    ``around(abs(desired-actual), decimal) <= 10**(-decimal)``.  Because of
    the rounding step, absolute differences somewhat larger than
    ``10**(-decimal)`` (up to roughly ``1.5 * 10**(-decimal)``) are accepted.

    Given two array_like objects, check that the shape is equal and all
    elements of these objects are almost equal. An exception is raised at
    shape mismatch or conflicting values. In contrast to the standard usage
    in numpy, NaNs are compared like numbers, no assertion is raised if
    both objects have NaNs in the same positions.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    decimal : int, optional
        Desired precision, default is 6.
    err_msg : str, optional
      The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    the first assert does not raise an exception

    >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan],
    ...                                      [1.0,2.333,np.nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33339,np.nan], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    ...

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33333, 5], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    ...

    """
    from numpy.core import around, number, float_, result_type, array
    from numpy.core.numerictypes import issubdtype
    from numpy.core.fromnumeric import any as npany

    def compare(x, y):
        # Element-wise "almost equal" test handed to assert_array_compare.
        try:
            if npany(gisinf(x)) or npany(gisinf(y)):
                xinfid = gisinf(x)
                yinfid = gisinf(y)
                # Infinities must sit in the same positions.  Compare
                # element-wise: ``not xinfid == yinfid`` raises ValueError
                # ("truth value ... is ambiguous") for multi-element arrays.
                if npany(xinfid != yinfid):
                    return False
                # if one item, x and y is +- inf
                if x.size == y.size == 1:
                    return x == y
                x = x[~xinfid]
                y = y[~yinfid]
        except (TypeError, NotImplementedError):
            pass

        # make sure y is an inexact type to avoid abs(MIN_INT); will cause
        # casting of x later.
        dtype = result_type(y, 1.)
        y = array(y, dtype=dtype, copy=False, subok=True)
        z = abs(x - y)

        if not issubdtype(z.dtype, number):
            z = z.astype(float_)  # handle object arrays

        return around(z, decimal) <= 10.0**(-decimal)

    assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
                         header=('Arrays are not almost equal to %d decimals' % decimal),
                         precision=decimal)
|
Chris@87
|
843
|
Chris@87
|
844
|
Chris@87
|
def assert_array_less(x, y, err_msg='', verbose=True):
    """
    Raises an AssertionError if two array_like objects are not ordered by less
    than.

    Given two array_like objects, check that the shape is equal and all
    elements of the first object are strictly smaller than those of the
    second object. An exception is raised at shape mismatch or incorrectly
    ordered values. Shape mismatch does not raise if an object has zero
    dimension. In contrast to the standard usage in numpy, NaNs are
    compared, no assertion is raised if both objects have NaNs in the same
    positions.

    Parameters
    ----------
    x : array_like
      The smaller object to check.
    y : array_like
      The larger object to compare.
    err_msg : str, optional
      The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If `x` is not strictly smaller than `y` element-wise, or if the shapes
      mismatch.

    See Also
    --------
    assert_array_equal: tests objects for equality
    assert_array_almost_equal: test objects for equality up to precision

    Examples
    --------
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan])
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    (mismatch 50.0%)
     x: array([  1.,   1.,  NaN])
     y: array([  1.,   2.,  NaN])

    >>> np.testing.assert_array_less([1.0, 4.0], 3)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    (mismatch 50.0%)
     x: array([ 1.,  4.])
     y: array(3)

    >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    (shapes (3,), (1,) mismatch)
     x: array([ 1.,  2.,  3.])
     y: array([4])

    """
    assert_array_compare(operator.__lt__, x, y, err_msg=err_msg,
                         verbose=verbose,
                         header='Arrays are not less-ordered')
|
Chris@87
|
914
|
Chris@87
|
def runstring(astr, dict):
    """
    Execute the Python source string `astr` with `dict` as its globals.

    Since no separate locals mapping is given, names assigned by the
    executed code are stored in `dict`.  (The parameter name shadows the
    ``dict`` builtin inside this function.)
    """
    namespace = dict
    exec(astr, namespace)
|
Chris@87
|
917
|
Chris@87
|
def assert_string_equal(actual, desired):
    """
    Test if two strings are equal.

    If the given strings are equal, `assert_string_equal` does nothing.
    If they are not equal, an AssertionError is raised, and the diff
    between the strings is shown.

    Parameters
    ----------
    actual : str
        The string to test for equality against the expected string.
    desired : str
        The expected string.

    Raises
    ------
    AssertionError
        If either argument is not a ``str`` (the message is the repr of the
        offending type), or if the two strings differ (the message contains
        the line diff).

    Examples
    --------
    >>> np.testing.assert_string_equal('abc', 'abc')
    >>> np.testing.assert_string_equal('abc', 'abcd')
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    ...
    AssertionError: Differences in strings:
    - abc+ abcd?    +

    """
    # delay import of difflib to reduce startup time
    import difflib

    if not isinstance(actual, str):
        raise AssertionError(repr(type(actual)))
    if not isinstance(desired, str):
        raise AssertionError(repr(type(desired)))

    # Plain comparison: `desired` is data, not a regular expression.  Using
    # it as a pattern made inputs such as 'a(' fail with re.error and let
    # mismatches such as ('aaa', 'a*') pass silently.
    if actual == desired:
        return

    diff = difflib.Differ().compare(actual.splitlines(True),
                                    desired.splitlines(True))
    # Keep every line that is not common to both inputs ('- ', '+ ' and the
    # '? ' hint lines), in the order Differ produces them.  Unlike the old
    # pop-based pairing, this cannot raise IndexError on a short diff.
    diff_list = [line for line in diff if not line.startswith('  ')]
    msg = 'Differences in strings:\n%s' % (''.join(diff_list)).rstrip()
    raise AssertionError(msg)
|
Chris@87
|
984
|
Chris@87
|
985
|
Chris@87
|
def rundocs(filename=None, raise_on_error=True):
    """
    Run doctests found in the given file.

    By default `rundocs` raises an AssertionError on failure.

    Parameters
    ----------
    filename : str
        The path to the file for which the doctests are run.  If None
        (default), the ``__file__`` of the calling module is used.
    raise_on_error : bool
        Whether to raise an AssertionError when a doctest fails. Default is
        True.  When False, the doctest report is written by the runner's
        default output instead of being collected into the exception.

    Raises
    ------
    AssertionError
        If `raise_on_error` is True and at least one doctest failed.

    Notes
    -----
    The doctests can be run by the user/developer by adding the ``doctests``
    argument to the ``test()`` call. For example, to run all tests (including
    doctests) for `numpy.lib`:

    >>> np.lib.test(doctests=True) #doctest: +SKIP
    """
    import doctest
    import imp
    if filename is None:
        f = sys._getframe(1)
        filename = f.f_globals['__file__']
    name = os.path.splitext(os.path.basename(filename))[0]
    path = [os.path.dirname(filename)]
    module_file, pathname, description = imp.find_module(name, path)
    try:
        m = imp.load_module(name, module_file, pathname, description)
    finally:
        # find_module returns None instead of an open file when the match is
        # a package directory; only close what was actually opened.
        if module_file is not None:
            module_file.close()

    tests = doctest.DocTestFinder().find(m)
    runner = doctest.DocTestRunner(verbose=False)

    msg = []
    # Collect the report only when it is going to be raised; otherwise let
    # the runner print it (out=None).
    out = msg.append if raise_on_error else None

    for test in tests:
        runner.run(test, out=out)

    if runner.failures > 0 and raise_on_error:
        raise AssertionError("Some doctests failed:\n%s" % "\n".join(msg))
|
Chris@87
|
1034
|
Chris@87
|
1035
|
Chris@87
|
def raises(*args, **kwargs):
    """
    Return nose's ``raises`` decorator, built from the given arguments.

    All positional and keyword arguments are forwarded unchanged to
    ``nose.tools.raises``; nose itself is obtained via ``import_nose``.
    """
    nose_tools = import_nose().tools
    return nose_tools.raises(*args, **kwargs)
|
Chris@87
|
1039
|
Chris@87
|
1040
|
Chris@87
|
def assert_raises(*args, **kwargs):
    """
    assert_raises(exception_class, callable, *args, **kwargs)

    Fail unless an exception of class exception_class is thrown
    by callable when invoked with arguments args and keyword
    arguments kwargs. If a different type of exception is
    thrown, it will not be caught, and the test case will be
    deemed to have suffered an error, exactly as for an
    unexpected exception.

    All arguments are forwarded unchanged to ``nose.tools.assert_raises``
    and its return value is passed back to the caller.
    """
    nose_tools = import_nose().tools
    return nose_tools.assert_raises(*args, **kwargs)
|
Chris@87
|
1055
|
Chris@87
|
1056
|
Chris@87
|
# Cached implementation chosen on first use of assert_raises_regex (nose's
# assert_raises_regex / assert_raises_regexp, or the local fallback below).
assert_raises_regex_impl = None


def assert_raises_regex(exception_class, expected_regexp,
                        callable_obj=None, *args, **kwargs):
    """
    Fail unless an exception of class exception_class and with message that
    matches expected_regexp is thrown by callable when invoked with arguments
    args and keyword arguments kwargs.

    If `callable_obj` is None, a context manager performing the same check
    is returned instead (as with the unittest methods this mirrors).

    Name of this function adheres to Python 3.2+ reference, but should work in
    all versions down to 2.6.

    """
    nose = import_nose()

    global assert_raises_regex_impl
    if assert_raises_regex_impl is None:
        try:
            # Python 3.2+
            assert_raises_regex_impl = nose.tools.assert_raises_regex
        except AttributeError:
            try:
                # 2.7+
                assert_raises_regex_impl = nose.tools.assert_raises_regexp
            except AttributeError:
                # 2.6

                # This class is copied from Python2.7 stdlib almost verbatim
                class _AssertRaisesContext(object):
                    """A context manager used to implement TestCase.assertRaises* methods."""

                    def __init__(self, expected, expected_regexp=None):
                        self.expected = expected
                        self.expected_regexp = expected_regexp

                    def failureException(self, msg):
                        return AssertionError(msg)

                    def __enter__(self):
                        return self

                    def __exit__(self, exc_type, exc_value, tb):
                        if exc_type is None:
                            try:
                                exc_name = self.expected.__name__
                            except AttributeError:
                                exc_name = str(self.expected)
                            raise self.failureException(
                                "{0} not raised".format(exc_name))
                        if not issubclass(exc_type, self.expected):
                            # let unexpected exceptions pass through
                            return False
                        self.exception = exc_value  # store for later retrieval
                        if self.expected_regexp is None:
                            return True

                        # re.compile returns an already-compiled pattern
                        # unchanged, so no string-type check is needed (the
                        # previous ``basestring`` check only exists on
                        # Python 2).
                        expected_regexp = re.compile(self.expected_regexp)
                        if not expected_regexp.search(str(exc_value)):
                            raise self.failureException(
                                '"%s" does not match "%s"' %
                                (expected_regexp.pattern, str(exc_value)))
                        return True

                def impl(cls, regex, callable_obj, *a, **kw):
                    mgr = _AssertRaisesContext(cls, regex)
                    if callable_obj is None:
                        return mgr
                    with mgr:
                        callable_obj(*a, **kw)
                assert_raises_regex_impl = impl

    return assert_raises_regex_impl(exception_class, expected_regexp,
                                    callable_obj, *args, **kwargs)
|
Chris@87
|
1133
|
Chris@87
|
1134
|
Chris@87
|
1135 def decorate_methods(cls, decorator, testmatch=None):
|
Chris@87
|
1136 """
|
Chris@87
|
1137 Apply a decorator to all methods in a class matching a regular expression.
|
Chris@87
|
1138
|
Chris@87
|
1139 The given decorator is applied to all public methods of `cls` that are
|
Chris@87
|
1140 matched by the regular expression `testmatch`
|
Chris@87
|
1141 (``testmatch.search(methodname)``). Methods that are private, i.e. start
|
Chris@87
|
1142 with an underscore, are ignored.
|
Chris@87
|
1143
|
Chris@87
|
1144 Parameters
|
Chris@87
|
1145 ----------
|
Chris@87
|
1146 cls : class
|
Chris@87
|
1147 Class whose methods to decorate.
|
Chris@87
|
1148 decorator : function
|
Chris@87
|
1149 Decorator to apply to methods
|
Chris@87
|
1150 testmatch : compiled regexp or str, optional
|
Chris@87
|
1151 The regular expression. Default value is None, in which case the
|
Chris@87
|
1152 nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``)
|
Chris@87
|
1153 is used.
|
Chris@87
|
1154 If `testmatch` is a string, it is compiled to a regular expression
|
Chris@87
|
1155 first.
|
Chris@87
|
1156
|
Chris@87
|
1157 """
|
Chris@87
|
1158 if testmatch is None:
|
Chris@87
|
1159 testmatch = re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)
|
Chris@87
|
1160 else:
|
Chris@87
|
1161 testmatch = re.compile(testmatch)
|
Chris@87
|
1162 cls_attr = cls.__dict__
|
Chris@87
|
1163
|
Chris@87
|
1164 # delayed import to reduce startup time
|
Chris@87
|
1165 from inspect import isfunction
|
Chris@87
|
1166
|
Chris@87
|
1167 methods = [_m for _m in cls_attr.values() if isfunction(_m)]
|
Chris@87
|
1168 for function in methods:
|
Chris@87
|
1169 try:
|
Chris@87
|
1170 if hasattr(function, 'compat_func_name'):
|
Chris@87
|
1171 funcname = function.compat_func_name
|
Chris@87
|
1172 else:
|
Chris@87
|
1173 funcname = function.__name__
|
Chris@87
|
1174 except AttributeError:
|
Chris@87
|
1175 # not a function
|
Chris@87
|
1176 continue
|
Chris@87
|
1177 if testmatch.search(funcname) and not funcname.startswith('_'):
|
Chris@87
|
1178 setattr(cls, funcname, decorator(function))
|
Chris@87
|
1179 return
|
Chris@87
|
1180
|
Chris@87
|
1181
|
Chris@87
|
def measure(code_str, times=1, label=None):
    """
    Return elapsed time for executing code in the namespace of the caller.

    The supplied code string is compiled with the Python builtin ``compile``.
    The precision of the timing is 10 milli-seconds. If the code will execute
    fast on this timescale, it can be executed many times to get reasonable
    timing accuracy.

    Parameters
    ----------
    code_str : str
        The code to be timed.
    times : int, optional
        The number of times the code is executed. Default is 1. The code is
        only compiled once.
    label : str, optional
        A label to identify `code_str` with. This is passed into ``compile``
        as the second argument (for run-time error messages).

    Returns
    -------
    elapsed : float
        Total elapsed time in seconds for executing `code_str` `times` times.

    Examples
    --------
    >>> times = 10
    >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)',
    ...                            times=times)
    >>> print("Time for a single execution : ", etime / times, "s")
    Time for a single execution :  0.005 s

    """
    # Run the code with the caller's globals and locals.
    caller = sys._getframe(1)
    caller_locals = caller.f_locals
    caller_globals = caller.f_globals

    compiled = compile(code_str, 'Test name: %s ' % label, 'exec')
    start = jiffies()
    runs = 0
    while runs < times:
        runs += 1
        exec(compiled, caller_globals, caller_locals)
    # jiffies are counted in hundredths of a second.
    return 0.01 * (jiffies() - start)
|
Chris@87
|
1228
|
Chris@87
|
def _assert_valid_refcount(op):
    """
    Check that ufuncs don't mishandle refcount of object `1`.
    Used in a few regression tests.

    Parameters
    ----------
    op : callable
        Binary operation (presumably a ufunc) that is applied 15 times to two
        references to the same (100, 100) integer array.

    Raises
    ------
    AssertionError
        If the reference count of the int object ``1`` is lower after the
        calls than before them.
    """
    import numpy as np
    b = np.arange(100*100).reshape(100, 100)
    c = b

    i = 1

    rc = sys.getrefcount(i)
    for _ in range(15):
        # Keep the latest result bound, as before, so that whatever it
        # references stays alive until the check below.
        d = op(b, c)

    assert_(sys.getrefcount(i) >= rc)
|
Chris@87
|
1246
|
Chris@87
|
def assert_allclose(actual, desired, rtol=1e-7, atol=0,
                    err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal up to desired
    tolerance.

    The test is equivalent to ``allclose(actual, desired, rtol, atol)``.
    It compares the difference between `actual` and `desired` to
    ``atol + rtol * abs(desired)``.

    .. versionadded:: 1.5.0

    Parameters
    ----------
    actual : array_like
        Array obtained.
    desired : array_like
        Array desired.
    rtol : float, optional
        Relative tolerance.
    atol : float, optional
        Absolute tolerance.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_array_almost_equal_nulp, assert_array_max_ulp

    Examples
    --------
    >>> x = [1e-5, 1e-3, 1e-1]
    >>> y = np.arccos(np.cos(x))
    >>> assert_allclose(x, y, rtol=1e-5, atol=0)

    """
    import numpy as np

    def compare(x, y):
        # Element-wise result (np.allclose would collapse it to a single
        # bool), so the mismatch report counts individual elements.
        return np.isclose(x, y, rtol=rtol, atol=atol)

    actual, desired = np.asanyarray(actual), np.asanyarray(desired)
    header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
                         verbose=verbose, header=header)
|
Chris@87
|
1298
|
Chris@87
|
def assert_array_almost_equal_nulp(x, y, nulp=1):
    """
    Compare two arrays relatively to their spacing.

    This is a relatively robust method to compare two arrays whose amplitude
    is variable.

    Parameters
    ----------
    x, y : array_like
        Input arrays.
    nulp : int, optional
        The maximum number of unit in the last place for tolerance (see Notes).
        Default is 1.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If the spacing between `x` and `y` for one or more elements is larger
        than `nulp`.

    See Also
    --------
    assert_array_max_ulp : Check that all items of arrays differ in at most
        N Units in the Last Place.
    spacing : Return the distance between x and the nearest adjacent number.

    Notes
    -----
    An assertion is raised if the following condition is not met::

        abs(x - y) <= nulp * spacing(max(abs(x), abs(y)))

    Examples
    --------
    >>> x = np.array([1., 1e-10, 1e-20])
    >>> eps = np.finfo(x.dtype).eps
    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x)

    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x)
    Traceback (most recent call last):
      ...
    AssertionError: X and Y are not equal to 1 ULP (max is 2)

    """
    import numpy as np
    abs_x = np.abs(x)
    abs_y = np.abs(y)
    larger = np.where(abs_x > abs_y, abs_x, abs_y)
    tolerance = nulp * np.spacing(larger)
    if np.all(np.abs(x - y) <= tolerance):
        return

    # ULP distances are only computed for real inputs.
    if np.iscomplexobj(x) or np.iscomplexobj(y):
        msg = "X and Y are not equal to %d ULP" % nulp
    else:
        max_nulp = np.max(nulp_diff(x, y))
        msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp)
    raise AssertionError(msg)
|
Chris@87
|
1359
|
Chris@87
|
def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
    """
    Check that all items of arrays differ in at most N Units in the Last Place.

    Parameters
    ----------
    a, b : array_like
        Input arrays to be compared.
    maxulp : int, optional
        The maximum number of units in the last place that elements of `a` and
        `b` can differ. Default is 1.
    dtype : dtype, optional
        Data-type to convert `a` and `b` to if given. Default is None.

    Returns
    -------
    ret : ndarray
        Array containing number of representable floating point numbers between
        items in `a` and `b`.

    Raises
    ------
    AssertionError
        If one or more elements differ by more than `maxulp`.

    See Also
    --------
    assert_array_almost_equal_nulp : Compare two arrays relatively to their
        spacing.

    Examples
    --------
    >>> a = np.linspace(0., 1., 100)
    >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a)))

    """
    import numpy as np
    ulp_counts = nulp_diff(a, b, dtype)
    all_within = np.all(ulp_counts <= maxulp)
    if not all_within:
        raise AssertionError("Arrays are not almost equal up to %g ULP" % maxulp)
    return ulp_counts
|
Chris@87
|
1402
|
Chris@87
|
def nulp_diff(x, y, dtype=None):
    """For each item in x and y, return the number of representable floating
    points between them.

    Parameters
    ----------
    x : array_like
        first input array
    y : array_like
        second input array
    dtype : dtype, optional
        Data-type to convert `x` and `y` to before comparing.  If None (the
        default), the inputs are converted with their own types.

    Returns
    -------
    nulp : array_like
        number of representable floating point numbers between each item in x
        and y.

    Raises
    ------
    NotImplementedError
        If either input is complex.
    ValueError
        If `x` and `y` do not have the same shape, or if their common type is
        neither float32 nor float64.

    Examples
    --------
    # By definition, epsilon is the smallest number such as 1 + eps != 1, so
    # there should be exactly one ULP between 1 and 1 + eps
    >>> nulp_diff(1, 1 + np.finfo(np.float64).eps)
    1.0
    """
    import numpy as np
    # ``if dtype:`` silently ignored np.dtype instances, whose truth value is
    # ``len(dtype)`` (0 for non-structured dtypes); test for None explicitly.
    if dtype is not None:
        x = np.array(x, dtype=dtype)
        y = np.array(y, dtype=dtype)
    else:
        x = np.array(x)
        y = np.array(y)

    t = np.common_type(x, y)
    if np.iscomplexobj(x) or np.iscomplexobj(y):
        raise NotImplementedError("_nulp not implemented for complex array")

    x = np.array(x, dtype=t)
    y = np.array(y, dtype=t)

    if not x.shape == y.shape:
        raise ValueError("x and y do not have the same shape: %s - %s" %
                         (x.shape, y.shape))

    def _diff(rx, ry, vdt):
        # Absolute difference of the integer representations, returned in
        # the floating type `vdt`.
        diff = np.array(rx - ry, dtype=vdt)
        return np.abs(diff)

    rx = integer_repr(x)
    ry = integer_repr(y)
    return _diff(rx, ry, t)
|
Chris@87
|
1453
|
Chris@87
|
def _integer_repr(x, vdt, comp):
    # Reinterpret the binary representation of the float as sign-magnitude
    # and convert it to two's complement ordering: a value with the sign bit
    # set becomes ``comp - rx``. When ``comp`` is the most negative value of
    # ``vdt`` (as ``integer_repr`` passes), that is minus its magnitude bits.
    # See also
    # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
    #
    # x    : float array or scalar whose itemsize matches ``vdt``
    # vdt  : signed integer type used to view the bits of ``x``
    # comp : value subtracted from for negative bit patterns
    rx = x.view(vdt)
    if rx.size != 1:
        # ``rx`` shares memory with ``x``. Copy it before the masked write so
        # the caller's floating point data is left untouched.
        rx = rx.copy()
        negative = rx < 0
        rx[negative] = comp - rx[negative]
    else:
        # Single element: rebinding (rather than writing in place) already
        # leaves ``x`` unmodified.
        if rx < 0:
            rx = comp - rx

    return rx
|
Chris@87
|
1467
|
Chris@87
|
def integer_repr(x):
    """Return the signed-magnitude interpretation of the binary representation of
    x.

    Parameters
    ----------
    x : ndarray
        Array (or numpy scalar) of dtype float16, float32 or float64.

    Returns
    -------
    ndarray
        Integers of matching width (int16, int32 or int64). A negative float
        is mapped to minus its magnitude bits, which is what lets `nulp_diff`
        count representable floats via integer differences.

    Raises
    ------
    ValueError
        If `x` has any other dtype.
    """
    import numpy as np
    # Each float type is viewed through the signed integer of the same width;
    # the second argument is that integer type's most negative value.
    if x.dtype == np.float16:
        return _integer_repr(x, np.int16, np.int16(-2**15))
    elif x.dtype == np.float32:
        return _integer_repr(x, np.int32, np.int32(-2**31))
    elif x.dtype == np.float64:
        return _integer_repr(x, np.int64, np.int64(-2**63))
    else:
        raise ValueError("Unsupported dtype %s" % x.dtype)
|
Chris@87
|
1478
|
Chris@87
|
1479 # The following two classes are copied from python 2.6 warnings module (context
|
Chris@87
|
1480 # manager)
|
Chris@87
|
class WarningMessage(object):

    """
    Record of the arguments passed to one showwarning() call.

    Deprecated in 1.8.0

    Notes
    -----
    `WarningMessage` is copied from the Python 2.6 warnings module,
    so it can be used in NumPy with older Python versions.

    """

    _WARNING_DETAILS = ("message", "category", "filename", "lineno", "file",
                        "line")

    def __init__(self, message, category, filename, lineno, file=None,
                 line=None):
        # One attribute per name in _WARNING_DETAILS, set in that order.
        self.message = message
        self.category = category
        self.filename = filename
        self.lineno = lineno
        self.file = file
        self.line = line
        # Cached so __str__ can show the class name rather than the class.
        self._category_name = category.__name__ if category else None

    def __str__(self):
        template = ("{message : %r, category : %r, filename : %r, lineno : %s, "
                    "line : %r}")
        fields = (self.message, self._category_name, self.filename,
                  self.lineno, self.line)
        return template % fields
|
Chris@87
|
1512
|
Chris@87
|
class WarningManager(object):
    """
    A context manager that copies and restores the warnings filter upon
    exiting the context.

    The 'record' argument specifies whether warnings should be captured by a
    custom implementation of ``warnings.showwarning()`` and be appended to a
    list returned by the context manager. Otherwise None is returned by the
    context manager. The objects appended to the list are arguments whose
    attributes mirror the arguments to ``showwarning()``.

    The 'module' argument is to specify an alternative module to the module
    named 'warnings' and imported under that name. This argument is only useful
    when testing the warnings module itself.

    An instance can be entered only once; entering it again raises
    RuntimeError.

    Deprecated in 1.8.0

    Notes
    -----
    `WarningManager` is a copy of the ``catch_warnings`` context manager
    from the Python 2.6 warnings module, with slight modifications.
    It is copied so it can be used in NumPy with older Python versions.

    """
    def __init__(self, record=False, module=None):
        self._record = record
        if module is None:
            self._module = sys.modules['warnings']
        else:
            self._module = module
        self._entered = False

    def __enter__(self):
        if self._entered:
            raise RuntimeError("Cannot enter %r twice" % self)
        self._entered = True
        # Save the original filter list and install a copy, so filter changes
        # made inside the context are discarded on exit.
        self._filters = self._module.filters
        self._module.filters = self._filters[:]
        self._showwarning = self._module.showwarning
        if self._record:
            log = []
            def showwarning(*args, **kwargs):
                log.append(WarningMessage(*args, **kwargs))
            self._module.showwarning = showwarning
            return log
        else:
            return None

    def __exit__(self, *exc_info):
        # The ``with`` statement always passes (exc_type, exc_value, traceback).
        # They are accepted and ignored; returning None lets any exception
        # raised in the body propagate.
        if not self._entered:
            raise RuntimeError("Cannot exit %r without entering first" % self)
        self._module.filters = self._filters
        self._module.showwarning = self._showwarning
|
Chris@87
|
1566
|
Chris@87
|
1567
|
Chris@87
|
def assert_warns(warning_class, func, *args, **kw):
    """
    Fail unless the given callable throws the specified warning.

    A warning of class warning_class should be thrown by the callable when
    invoked with arguments args and keyword arguments kwargs. All warnings
    emitted during the call are recorded (with the 'always' filter active).
    The check fails unless at least one warning was raised and the *first*
    one is exactly of class `warning_class`; a subclass does not match.

    .. versionadded:: 1.4.0

    Parameters
    ----------
    warning_class : class
        The class defining the warning that `func` is expected to throw.
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    Raises
    ------
    AssertionError
        If no warning is raised, or if the first warning is not of class
        `warning_class`.

    """
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        result = func(*args, **kw)
        # Callables such as functools.partial objects have no __name__;
        # fall back to repr so the intended AssertionError is still raised.
        name = getattr(func, '__name__', repr(func))
        if not caught:
            raise AssertionError("No warning raised when calling %s"
                                 % name)
        if caught[0].category is not warning_class:
            raise AssertionError("First warning for %s is not a "
                                 "%s( is %s)" % (name, warning_class,
                                                 caught[0]))
    return result
|
Chris@87
|
1605
|
Chris@87
|
def assert_no_warnings(func, *args, **kw):
    """
    Fail if the given callable produces any warnings.

    All warnings emitted during the call are recorded with the 'always'
    filter active, so warnings that would normally be suppressed or shown
    only once still count.

    .. versionadded:: 1.7.0

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    Raises
    ------
    AssertionError
        If any warning is raised during the call.

    """
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        result = func(*args, **kw)
        if caught:
            # Callables such as functools.partial objects have no __name__;
            # fall back to repr so the intended AssertionError is raised.
            name = getattr(func, '__name__', repr(func))
            raise AssertionError("Got warnings when calling %s: %s"
                                 % (name, caught))
    return result
|
Chris@87
|
1633
|
Chris@87
|
1634
|
Chris@87
|
def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
    """
    generator producing data with different alignment and offsets
    to test simd vectorization

    Parameters
    ----------
    dtype : dtype
        data type to produce
    type : string
        'unary': create data for unary operations, creates one input
                 and output array
        'binary': create data for binary operations, creates two input
                 and output array
        Any other value produces no data.
    max_size : integer
        upper bound (exclusive) on the unsliced array length; at least one
        length is always produced for each offset

    Returns
    -------
    if type is 'unary' yields one output, one input array and a message
    containing information on the data
    if type is 'binary' yields one output array, two input array and a message
    containing information on the data

    Notes
    -----
    In the 'in place' cases the output is the very same array object as the
    corresponding input. The 'aliased' cases are built from separate
    ``arange`` calls, so their arrays are shifted by one element but do not
    share memory.

    """
    ufmt = 'unary offset=(%d, %d), size=%d, dtype=%r, %s'
    bfmt = 'binary offset=(%d, %d, %d), size=%d, dtype=%r, %s'
    for o in range(3):
        for s in range(o + 2, max(o + 3, max_size)):
            def fresh():
                # New array of length s starting at offset o; called right
                # away, so the closure over the loop variables is safe.
                return arange(s, dtype=dtype)[o:]

            if type == 'unary':
                out = empty((s,), dtype=dtype)[o:]
                yield out, fresh(), ufmt % (o, o, s, dtype, 'out of place')
                # Yield one array twice; two fresh() calls would produce
                # distinct arrays and never exercise the in-place path.
                d = fresh()
                yield d, d, ufmt % (o, o, s, dtype, 'in place')
                yield (out[1:], fresh()[:-1],
                       ufmt % (o + 1, o, s - 1, dtype, 'out of place'))
                yield (out[:-1], fresh()[1:],
                       ufmt % (o, o + 1, s - 1, dtype, 'out of place'))
                yield (fresh()[:-1], fresh()[1:],
                       ufmt % (o, o + 1, s - 1, dtype, 'aliased'))
                yield (fresh()[1:], fresh()[:-1],
                       ufmt % (o + 1, o, s - 1, dtype, 'aliased'))
            if type == 'binary':
                out = empty((s,), dtype=dtype)[o:]
                yield (out, fresh(), fresh(),
                       bfmt % (o, o, o, s, dtype, 'out of place'))
                # Output aliases the first input exactly.
                d = fresh()
                yield d, d, fresh(), bfmt % (o, o, o, s, dtype, 'in place1')
                # Output aliases the second input exactly.
                d = fresh()
                yield d, fresh(), d, bfmt % (o, o, o, s, dtype, 'in place2')
                yield (out[1:], fresh()[:-1], fresh()[:-1],
                       bfmt % (o + 1, o, o, s - 1, dtype, 'out of place'))
                yield (out[:-1], fresh()[1:], fresh()[:-1],
                       bfmt % (o, o + 1, o, s - 1, dtype, 'out of place'))
                yield (out[:-1], fresh()[:-1], fresh()[1:],
                       bfmt % (o, o, o + 1, s - 1, dtype, 'out of place'))
                yield (fresh()[1:], fresh()[:-1], fresh()[:-1],
                       bfmt % (o + 1, o, o, s - 1, dtype, 'aliased'))
                yield (fresh()[:-1], fresh()[1:], fresh()[:-1],
                       bfmt % (o, o + 1, o, s - 1, dtype, 'aliased'))
                yield (fresh()[:-1], fresh()[:-1], fresh()[1:],
                       bfmt % (o, o, o + 1, s - 1, dtype, 'aliased'))
|
Chris@87
|
1699
|
Chris@87
|
1700
|
Chris@87
|
class IgnoreException(Exception):
    """
    Exception that is to be ignored because the feature it relates to is
    disabled.
    """
    pass
|
Chris@87
|
1703
|
Chris@87
|
1704
|
Chris@87
|
@contextlib.contextmanager
def tempdir(*args, **kwargs):
    """Context manager to provide a temporary test folder.

    All arguments are passed through unchanged to the underlying
    tempfile.mkdtemp function.

    Yields
    ------
    The path returned by ``mkdtemp``. The directory and its contents are
    removed when the context exits, including when the body raises.

    """
    tmpdir = mkdtemp(*args, **kwargs)
    try:
        yield tmpdir
    finally:
        # Without try/finally, an exception in the with-body would skip the
        # cleanup and leak the directory.
        shutil.rmtree(tmpdir)
|