Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/numpy/testing/_private/utils.py : 11%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2Utility function to facilitate testing.
4"""
5import os
6import sys
7import platform
8import re
9import gc
10import operator
11import warnings
12from functools import partial, wraps
13import shutil
14import contextlib
15from tempfile import mkdtemp, mkstemp
16from unittest.case import SkipTest
17from warnings import WarningMessage
18import pprint
20from numpy.core import(
21 intp, float32, empty, arange, array_repr, ndarray, isnat, array)
22import numpy.linalg.lapack_lite
24from io import StringIO
# Public names exported by ``from <this module> import *``.
__all__ = [
    'assert_equal',
    'assert_almost_equal',
    'assert_approx_equal',
    'assert_array_equal',
    'assert_array_less',
    'assert_string_equal',
    'assert_array_almost_equal',
    'assert_raises',
    'build_err_msg',
    'decorate_methods',
    'jiffies',
    'memusage',
    'print_assert_equal',
    'raises',
    'rundocs',
    'runstring',
    'verbose',
    'measure',
    'assert_',
    'assert_array_almost_equal_nulp',
    'assert_raises_regex',
    'assert_array_max_ulp',
    'assert_warns',
    'assert_no_warnings',
    'assert_allclose',
    'IgnoreException',
    'clear_and_catch_warnings',
    'SkipTest',
    'KnownFailureException',
    'temppath',
    'tempdir',
    'IS_PYPY',
    'HAS_REFCOUNT',
    'suppress_warnings',
    'assert_array_compare',
    '_assert_valid_refcount',
    '_gen_alignment_data',
    'assert_no_gc_cycles',
    'break_cycles',
    'HAS_LAPACK64',
]
class KnownFailureException(Exception):
    '''Raise this exception to mark a test as a known failing test.'''


# Older spelling of the exception name, kept for backwards compatibility.
KnownFailureTest = KnownFailureException

# Module-wide verbosity level; defaults to 0.
verbose = 0

# True when the running interpreter reports itself as PyPy.
IS_PYPY = (platform.python_implementation() == 'PyPy')
# True when the interpreter provides ``sys.getrefcount``.
HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None
# Mirrors ``numpy.linalg.lapack_lite._ilp64`` (presumably True when NumPy's
# LAPACK interface was built with 64-bit integers — confirm against the build).
HAS_LAPACK64 = numpy.linalg.lapack_lite._ilp64
def import_nose():
    """Import and return the ``nose`` module, loading it only when needed.

    Returns
    -------
    module
        The imported ``nose`` package.

    Raises
    ------
    ImportError
        If nose cannot be imported or its ``__versioninfo__`` is older than
        (1, 0, 0).
    """
    minimum_nose_version = (1, 0, 0)
    try:
        import nose
    except ImportError:
        nose = None

    if nose is None or nose.__versioninfo__ < minimum_nose_version:
        # Adjacent literals are joined before ``%`` is applied.
        msg = ('Need nose >= %d.%d.%d for tests - see '
               'https://nose.readthedocs.io' % minimum_nose_version)
        raise ImportError(msg)
    return nose
def assert_(val, msg=''):
    """
    Assert that works in release mode.

    Accepts a callable `msg` to allow deferring evaluation until failure.

    The Python built-in ``assert`` does not work when executing code in
    optimized mode (the ``-O`` flag) - no byte-code is generated for it.

    Parameters
    ----------
    val : object
        Value whose truthiness is checked.
    msg : str or callable, optional
        Failure message, or a zero-argument callable producing it. Anything
        whose call raises TypeError (e.g. a plain string) is used as-is.

    Raises
    ------
    AssertionError
        If `val` is falsy.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if val:
        return
    try:
        message = msg()
    except TypeError:
        # not callable: use the message object directly
        message = msg
    raise AssertionError(message)
def gisnan(x):
    """Like ``isnan``, but always raise if the type is unsupported.

    Notes
    -----
    ``isnan`` and other ufuncs sometimes return ``NotImplemented`` instead
    of raising. This wrapper turns that sentinel into a TypeError so callers
    always see an exception.

    This should be removed once this problem is solved at the Ufunc level.
    """
    from numpy.core import isnan
    result = isnan(x)
    # NotImplemented is a singleton, so an identity test is sufficient.
    if result is NotImplemented:
        raise TypeError("isnan not supported for this type")
    return result
def gisfinite(x):
    """Like ``isfinite``, but always raise if the type is unsupported.

    Invalid-value floating point warnings are suppressed while evaluating.

    Notes
    -----
    ``isfinite`` and other ufuncs sometimes return ``NotImplemented``
    instead of raising. This wrapper turns that sentinel into a TypeError
    so callers always see an exception.

    This should be removed once this problem is solved at the Ufunc level.
    """
    from numpy.core import isfinite, errstate
    with errstate(invalid='ignore'):
        result = isfinite(x)
    if result is NotImplemented:
        raise TypeError("isfinite not supported for this type")
    return result
def gisinf(x):
    """Like ``isinf``, but always raise if the type is unsupported.

    Invalid-value floating point warnings are suppressed while evaluating.

    Notes
    -----
    ``isinf`` and other ufuncs sometimes return ``NotImplemented`` instead
    of raising. This wrapper turns that sentinel into a TypeError so callers
    always see an exception.

    This should be removed once this problem is solved at the Ufunc level.
    """
    from numpy.core import isinf, errstate
    with errstate(invalid='ignore'):
        result = isinf(x)
    if result is NotImplemented:
        raise TypeError("isinf not supported for this type")
    return result
if os.name == 'nt':
    # Code "stolen" from enthought/debug/memusage.py
    def GetPerformanceAttributes(object, counter, instance=None,
                                 inum=-1, format=None, machine=None):
        """Read a single Windows performance counter through ``win32pdh``.

        `format` defaults to ``win32pdh.PDH_FMT_LONG``. The query and the
        counter handle are always released, even if reading fails.
        """
        # NOTE: Many counters require 2 samples to give accurate results,
        # including "% Processor Time" (as by definition, at any instant, a
        # thread's CPU usage is either 0 or 100). To read counters like this,
        # you should copy this function, but keep the counter open, and call
        # CollectQueryData() each time you need to know.
        # See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp (dead link)
        # My older explanation for this was that the "AddCounter" process forced
        # the CPU to 100%, but the above makes more sense :)
        import win32pdh
        if format is None:
            format = win32pdh.PDH_FMT_LONG
        counter_path = win32pdh.MakeCounterPath(
            (machine, object, instance, None, inum, counter))
        query = win32pdh.OpenQuery()
        try:
            handle = win32pdh.AddCounter(query, counter_path)
            try:
                win32pdh.CollectQueryData(query)
                _, value = win32pdh.GetFormattedCounterValue(handle, format)
                return value
            finally:
                win32pdh.RemoveCounter(handle)
        finally:
            win32pdh.CloseQuery(query)

    def memusage(processName="python", instance=0):
        """Return the "Virtual Bytes" performance counter of a process."""
        # from win32pdhutil, part of the win32all package
        import win32pdh
        return GetPerformanceAttributes("Process", "Virtual Bytes",
                                        processName, instance,
                                        win32pdh.PDH_FMT_LONG, None)
elif sys.platform.startswith('linux'):

    def memusage(_proc_pid_stat='/proc/%s/stat' % (os.getpid())):
        """
        Return virtual memory size in bytes of the running python.

        The value is field 23 (vsize) of ``/proc/<pid>/stat``. The path is
        bound once, when the function is defined. Returns None if the file
        cannot be read or parsed.
        """
        try:
            with open(_proc_pid_stat, 'r') as stat_file:
                fields = stat_file.readline().split(' ')
            return int(fields[22])
        except Exception:
            return None
else:
    def memusage():
        """
        Return memory usage of running python. [Not implemented]

        """
        raise NotImplementedError
if sys.platform.startswith('linux'):
    def jiffies(_proc_pid_stat='/proc/%s/stat' % (os.getpid()),
                _load_time=[]):
        """
        Return number of jiffies elapsed.

        Reads field 14 (utime) of ``/proc/<pid>/stat``: the time this process
        has been scheduled in user mode, in clock ticks (usually 1/100 s).
        See man 5 proc. If the file cannot be read, falls back to hundredths
        of a second of wall time since the first call.
        """
        import time
        # The mutable default is deliberate: it remembers the first-call
        # timestamp that the fallback measures from.
        if not _load_time:
            _load_time.append(time.time())
        try:
            with open(_proc_pid_stat, 'r') as stat_file:
                fields = stat_file.readline().split(' ')
            return int(fields[13])
        except Exception:
            return int(100*(time.time()-_load_time[0]))
else:
    # os.getpid is not in all platforms available.
    # Using time is safe but inaccurate, especially when process
    # was suspended or sleeping.
    def jiffies(_load_time=[]):
        """
        Return number of jiffies elapsed.

        Approximated as hundredths of a second of wall time elapsed since
        the first call (the start time is kept in the mutable default).
        """
        import time
        if not _load_time:
            _load_time.append(time.time())
        elapsed = time.time() - _load_time[0]
        return int(100*elapsed)
def build_err_msg(arrays, err_msg, header='Items are not equal:',
                  verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
    """Build the failure message used by the assertion helpers.

    Parameters
    ----------
    arrays : sequence
        Objects to display; the i-th one is labelled ``names[i]``.
    err_msg : str
        Extra text. If it is short and single-line it is appended to the
        header line, otherwise it gets a line of its own.
    header : str, optional
        First line of the message (preceded by a newline).
    verbose : bool, optional
        If True, append a line with the repr of every object.
    names : sequence of str, optional
        Labels for the objects in `arrays`.
    precision : int, optional
        Precision passed to ``array_repr`` for ndarray objects.

    Returns
    -------
    str
        The newline-joined message. Reprs longer than three lines are
        truncated and suffixed with ``...``.
    """
    first_line = '\n' + header
    extra_lines = []
    if err_msg:
        fits_on_header = ('\n' not in err_msg
                          and len(err_msg) < 79 - len(header))
        if fits_on_header:
            first_line += ' ' + err_msg
        else:
            extra_lines.append(err_msg)

    if verbose:
        for idx, obj in enumerate(arrays):
            if isinstance(obj, ndarray):
                # precision argument is only needed if the objects are ndarrays
                formatter = partial(array_repr, precision=precision)
            else:
                formatter = repr

            try:
                text = formatter(obj)
            except Exception as exc:
                text = '[repr failed for <{}>: {}]'.format(type(obj).__name__, exc)
            if text.count('\n') > 3:
                text = '\n'.join(text.splitlines()[:3]) + '...'
            extra_lines.append(' %s: %s' % (names[idx], text))

    return '\n'.join([first_line] + extra_lines)
def assert_equal(actual, desired, err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal.

    Given two objects (scalars, lists, tuples, dictionaries or numpy arrays),
    check that all elements of these objects are equal. An exception is raised
    at the first conflicting values.

    When one of `actual` and `desired` is a scalar and the other is array_like,
    the function checks that each element of the array_like object is equal to
    the scalar.

    This function handles NaN comparisons as if NaN was a "normal" number.
    That is, AssertionError is not raised if both objects have NaNs in the same
    positions. This is in contrast to the IEEE standard on NaNs, which says
    that NaN compared to anything must return False.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal.

    Examples
    --------
    >>> np.testing.assert_equal([4,5], [4,6])
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal:
    item=1
     ACTUAL: 5
     DESIRED: 6

    The following comparison does not raise an exception. There are NaNs
    in the inputs, but they are in the same positions.

    >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for k in desired:
            if k not in actual:
                raise AssertionError(repr(k))
            assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg), verbose)
        return
    if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
        assert_equal(len(actual), len(desired), err_msg, verbose)
        # The lengths are known to match here, so zip() visits every item.
        for k, (actual_item, desired_item) in enumerate(zip(actual, desired)):
            assert_equal(actual_item, desired_item,
                         'item=%r\n%s' % (k, err_msg), verbose)
        return
    # Import from the top-level namespace: ``numpy.lib.iscomplexobj`` and
    # friends no longer exist there in NumPy >= 2.
    from numpy import ndarray, isscalar, signbit, iscomplexobj, real, imag
    if isinstance(actual, ndarray) or isinstance(desired, ndarray):
        return assert_array_equal(actual, desired, err_msg, verbose)
    msg = build_err_msg([actual, desired], err_msg, verbose=verbose)

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except (ValueError, TypeError):
        usecomplex = False

    if usecomplex:
        if iscomplexobj(actual):
            actualr = real(actual)
            actuali = imag(actual)
        else:
            actualr = actual
            actuali = 0
        if iscomplexobj(desired):
            desiredr = real(desired)
            desiredi = imag(desired)
        else:
            desiredr = desired
            desiredi = 0
        try:
            assert_equal(actualr, desiredr)
            assert_equal(actuali, desiredi)
        except AssertionError:
            raise AssertionError(msg)

    # isscalar test to check cases such as [np.nan] != np.nan
    if isscalar(desired) != isscalar(actual):
        raise AssertionError(msg)

    try:
        isdesnat = isnat(desired)
        isactnat = isnat(actual)
        dtypes_match = array(desired).dtype.type == array(actual).dtype.type
        if isdesnat and isactnat:
            # If both are NaT (and have the same dtype -- datetime or
            # timedelta) they are considered equal.
            if dtypes_match:
                return
            else:
                raise AssertionError(msg)

    except (TypeError, ValueError, NotImplementedError):
        pass

    # Inf/nan/negative zero handling
    try:
        isdesnan = gisnan(desired)
        isactnan = gisnan(actual)
        if isdesnan and isactnan:
            return  # both nan, so equal

        # handle signed zero specially for floats
        array_actual = array(actual)
        array_desired = array(desired)
        if (array_actual.dtype.char in 'Mm' or
                array_desired.dtype.char in 'Mm'):
            # version 1.18
            # until this version, gisnan failed for datetime64 and timedelta64.
            # Now it succeeds but comparison to scalar with a different type
            # emits a DeprecationWarning.
            # Avoid that by skipping the next check
            raise NotImplementedError('cannot compare to a scalar '
                                      'with a different type')

        if desired == 0 and actual == 0:
            if not signbit(desired) == signbit(actual):
                raise AssertionError(msg)

    except (TypeError, ValueError, NotImplementedError):
        pass

    try:
        # Explicitly use __eq__ for comparison, gh-2552
        if not (desired == actual):
            raise AssertionError(msg)

    except (DeprecationWarning, FutureWarning) as e:
        # this handles the case when the two types are not even comparable.
        # str(e) (rather than e.args[0]) cannot raise IndexError for a
        # warning created without arguments.
        if 'elementwise == comparison' in str(e):
            raise AssertionError(msg)
        else:
            raise
def print_assert_equal(test_string, actual, desired):
    """
    Test if two objects are equal, and print an error message if test fails.

    The test is performed with ``actual == desired``.

    Parameters
    ----------
    test_string : str
        The message supplied to AssertionError.
    actual : object
        The object to test for equality against `desired`.
    desired : object
        The expected result.

    Raises
    ------
    AssertionError
        If ``actual == desired`` is falsy. The message contains `test_string`
        followed by pretty-printed ACTUAL and DESIRED values.

    Examples
    --------
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1])
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2])
    Traceback (most recent call last):
    ...
    AssertionError: Test XYZ of func xyz failed
    ACTUAL:
    [0, 1]
    DESIRED:
    [0, 2]

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import pprint

    if actual == desired:
        return

    report = StringIO()
    report.write(test_string)
    report.write(' failed\nACTUAL: \n')
    pprint.pprint(actual, stream=report)
    report.write('DESIRED: \n')
    pprint.pprint(desired, stream=report)
    raise AssertionError(report.getvalue())
def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
    """
    Raises an AssertionError if two items are not equal up to desired
    precision.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    The test verifies that the elements of ``actual`` and ``desired`` satisfy.

        ``abs(desired-actual) < 1.5 * 10**(-decimal)``

    That is a looser test than originally documented, but agrees with what the
    actual implementation in `assert_array_almost_equal` did up to rounding
    vagaries. An exception is raised at conflicting values. For ndarrays this
    delegates to assert_array_almost_equal

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    decimal : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> import numpy.testing as npt
    >>> npt.assert_almost_equal(2.3333333333333, 2.33333334)
    >>> npt.assert_almost_equal(2.3333333333333, 2.33333334, decimal=10)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 10 decimals
     ACTUAL: 2.3333333333333
     DESIRED: 2.33333334

    >>> npt.assert_almost_equal(np.array([1.0,2.3333333333333]),
    ...                         np.array([1.0,2.33333334]), decimal=9)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 9 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 6.66669964e-09
    Max relative difference: 2.85715698e-09
     x: array([1.         , 2.333333333])
     y: array([1.        , 2.33333334])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Top-level imports work on every NumPy version (``numpy.lib`` no longer
    # exposes these names in NumPy >= 2).
    from numpy import ndarray, iscomplexobj, real, imag

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail.
    # TypeError is caught too, consistent with assert_equal.
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except (ValueError, TypeError):
        usecomplex = False

    def _build_err_msg():
        header = ('Arrays are not almost equal to %d decimals' % decimal)
        return build_err_msg([actual, desired], err_msg, verbose=verbose,
                             header=header)

    if usecomplex:
        if iscomplexobj(actual):
            actualr = real(actual)
            actuali = imag(actual)
        else:
            actualr = actual
            actuali = 0
        if iscomplexobj(desired):
            desiredr = real(desired)
            desiredi = imag(desired)
        else:
            desiredr = desired
            desiredi = 0
        try:
            assert_almost_equal(actualr, desiredr, decimal=decimal)
            assert_almost_equal(actuali, desiredi, decimal=decimal)
        except AssertionError:
            raise AssertionError(_build_err_msg())

    if isinstance(actual, (ndarray, tuple, list)) \
            or isinstance(desired, (ndarray, tuple, list)):
        # Forward ``verbose`` as well; it used to be silently dropped here.
        return assert_array_almost_equal(actual, desired, decimal, err_msg,
                                         verbose)
    try:
        # If one of desired/actual is not finite, handle it specially here:
        # check that both are nan if any is a nan, and test for equality
        # otherwise
        if not (gisfinite(desired) and gisfinite(actual)):
            if gisnan(desired) or gisnan(actual):
                if not (gisnan(desired) and gisnan(actual)):
                    raise AssertionError(_build_err_msg())
            else:
                if not desired == actual:
                    raise AssertionError(_build_err_msg())
            return
    except (NotImplementedError, TypeError):
        pass
    if abs(desired - actual) >= 1.5 * 10.0**(-decimal):
        raise AssertionError(_build_err_msg())
def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
    """
    Raises an AssertionError if two items are not equal up to significant
    digits.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    Given two numbers, check that they are approximately equal.
    Approximately equal is defined as the number of significant digits
    that agree.

    Parameters
    ----------
    actual : scalar
        The object to check.
    desired : scalar
        The expected object.
    significant : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20,
    ...                                significant=8)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20,
    ...                                significant=8)
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal to 8 significant digits:
     ACTUAL: 1.234567e-21
     DESIRED: 1.2345672e-21

    the evaluated condition that raises the exception is

    >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1)
    True

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    (actual, desired) = map(float, (actual, desired))
    if desired == actual:
        return
    # Normalize the numbers to be in range (-10.0,10.0)
    # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual))))))
    #
    # `scale` is a NumPy float64, so dividing by a zero scale yields inf/nan
    # (with a RuntimeWarning) rather than raising ZeroDivisionError. Test
    # for that case explicitly; the scaled values are then defined as 0.0,
    # which is what the former ZeroDivisionError handlers intended. The
    # log10(0) and x/0 floating point warnings are silenced as well.
    with np.errstate(invalid='ignore', divide='ignore'):
        scale = 0.5*(np.abs(desired) + np.abs(actual))
        scale = np.power(10, np.floor(np.log10(scale)))
        if scale == 0:
            sc_desired = 0.0
            sc_actual = 0.0
        else:
            sc_desired = desired/scale
            sc_actual = actual/scale
    msg = build_err_msg(
        [actual, desired], err_msg,
        header='Items are not equal to %d significant digits:' % significant,
        verbose=verbose)
    try:
        # If one of desired/actual is not finite, handle it specially here:
        # check that both are nan if any is a nan, and test for equality
        # otherwise
        if not (gisfinite(desired) and gisfinite(actual)):
            if gisnan(desired) or gisnan(actual):
                if not (gisnan(desired) and gisnan(actual)):
                    raise AssertionError(msg)
            else:
                if not desired == actual:
                    raise AssertionError(msg)
            return
    except (TypeError, NotImplementedError):
        pass
    if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)):
        raise AssertionError(msg)
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
                         header='', precision=6, equal_nan=True,
                         equal_inf=True):
    """
    Compare two array_like objects element-wise and raise on mismatch.

    Shared implementation of `assert_array_equal`,
    `assert_array_almost_equal` and similar helpers.

    Parameters
    ----------
    comparison : callable
        ``comparison(x, y)`` returns a bool or a boolean array that is True
        where elements are considered equal.
    x, y : array_like
        Objects to compare. Their shapes must match unless one of them is
        0-d.
    err_msg : str, optional
        Extra text for the error message.
    verbose : bool, optional
        If True, the compared objects are appended to the error message.
    header : str, optional
        First line of the error message.
    precision : int, optional
        Precision used when printing ndarrays in the error message.
    equal_nan : bool, optional
        If True, NaNs in numeric inputs (and NaTs when both inputs have the
        same datetime/timedelta type) must be in the same positions. Those
        positions are then excluded from `comparison`.
    equal_inf : bool, optional
        If True, +inf and -inf in numeric inputs must be in the same
        positions. Those positions are then excluded from `comparison`.

    Raises
    ------
    AssertionError
        On shape mismatch, NaN/inf/NaT location mismatch, or when
        `comparison` reports unequal elements.
    ValueError
        If a ValueError occurs while comparing. It is re-raised with the
        formatted message and the original traceback text in the header.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy import (array, array2string, asanyarray, isnan, inf, bool_,
                       errstate, all, max, object_)

    # asanyarray == array(copy=False, subok=True) on old NumPy, but does not
    # raise on NumPy >= 2 when a copy cannot be avoided (e.g. list input).
    x = asanyarray(x)
    y = asanyarray(y)

    # original array for output formatting
    ox, oy = x, y

    def isnumber(x):
        return x.dtype.char in '?bhilqpBHILQPefdgFDG'

    def istime(x):
        return x.dtype.char in "Mm"

    def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
        """Handling nan/inf.

        Combine results of running func on x and y, checking that they are True
        at the same locations.

        """
        x_id = func(x)
        y_id = func(y)
        # We include work-arounds here to handle three types of slightly
        # pathological ndarray subclasses:
        # (1) all() on `masked` array scalars can return masked arrays, so we
        #     use != True
        # (2) __eq__ on some ndarray subclasses returns Python booleans
        #     instead of element-wise comparisons, so we cast to bool_() and
        #     use isinstance(..., bool) checks
        # (3) subclasses with bare-bones __array_function__ implementations may
        #     not implement np.all(), so favor using the .all() method
        # We are not committed to supporting such subclasses, but it's nice to
        # support them if possible.
        if bool_(x_id == y_id).all() != True:
            msg = build_err_msg([x, y],
                                err_msg + '\nx and y %s location mismatch:'
                                % (hasval), verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)
        # If there is a scalar, then here we know the array has the same
        # flag as it everywhere, so we should return the scalar flag.
        if isinstance(x_id, bool) or x_id.ndim == 0:
            return bool_(x_id)
        # Test y_id here (this used to re-test x_id, so a Python-bool y_id
        # crashed on the .ndim access).
        elif isinstance(y_id, bool) or y_id.ndim == 0:
            return bool_(y_id)
        else:
            return y_id

    try:
        cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
        if not cond:
            msg = build_err_msg([x, y],
                                err_msg
                                + '\n(shapes %s, %s mismatch)' % (x.shape,
                                                                 y.shape),
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)

        flagged = bool_(False)
        if isnumber(x) and isnumber(y):
            if equal_nan:
                flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')

            if equal_inf:
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == +inf,
                                                hasval='+inf')
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == -inf,
                                                hasval='-inf')

        elif istime(x) and istime(y):
            # If one is datetime64 and the other timedelta64 there is no point
            if equal_nan and x.dtype.type == y.dtype.type:
                flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT")

        if flagged.ndim > 0:
            x, y = x[~flagged], y[~flagged]
            # Only do the comparison if actual values are left
            if x.size == 0:
                return
        elif flagged:
            # no sense doing comparison if everything is flagged.
            return

        val = comparison(x, y)

        if isinstance(val, bool):
            cond = val
            reduced = array([val])
        else:
            reduced = val.ravel()
            cond = reduced.all()

        # The below comparison is a hack to ensure that fully masked
        # results, for which val.ravel().all() returns np.ma.masked,
        # do not trigger a failure (np.ma.masked != True evaluates as
        # np.ma.masked, which is falsy).
        if cond != True:
            n_mismatch = reduced.size - reduced.sum(dtype=intp)
            n_elements = flagged.size if flagged.ndim != 0 else reduced.size
            percent_mismatch = 100 * n_mismatch / n_elements
            remarks = [
                'Mismatched elements: {} / {} ({:.3g}%)'.format(
                    n_mismatch, n_elements, percent_mismatch)]

            with errstate(invalid='ignore', divide='ignore'):
                # ignore errors for non-numeric types
                with contextlib.suppress(TypeError):
                    error = abs(x - y)
                    max_abs_error = max(error)
                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append('Max absolute difference: '
                                       + str(max_abs_error))
                    else:
                        remarks.append('Max absolute difference: '
                                       + array2string(max_abs_error))

                    # note: this definition of relative error matches that one
                    # used by assert_allclose (found in np.isclose)
                    # Filter values where the divisor would be zero
                    nonzero = bool_(y != 0)
                    if all(~nonzero):
                        max_rel_error = array(inf)
                    else:
                        max_rel_error = max(error[nonzero] / abs(y[nonzero]))
                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append('Max relative difference: '
                                       + str(max_rel_error))
                    else:
                        remarks.append('Max relative difference: '
                                       + array2string(max_rel_error))

            err_msg += '\n' + '\n'.join(remarks)
            msg = build_err_msg([ox, oy], err_msg,
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)
    except ValueError:
        import traceback
        efmt = traceback.format_exc()
        header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header)

        msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
                            names=('x', 'y'), precision=precision)
        raise ValueError(msg)
def assert_array_equal(x, y, err_msg='', verbose=True):
    """
    Raises an AssertionError if two array_like objects are not equal.

    Given two array_like objects, check that the shape is equal and all
    elements of these objects are equal (but see the Notes for the special
    handling of a scalar). An exception is raised at shape mismatch or
    conflicting values. In contrast to the standard usage in numpy, NaNs
    are compared like numbers, no assertion is raised if both objects have
    NaNs in the same positions.

    The usual caution for verifying equality with floating point numbers is
    advised.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired objects are not equal.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Notes
    -----
    When one of `x` and `y` is a scalar and the other is array_like, the
    function checks that each element of the array_like object is equal to
    the scalar.

    Examples
    --------
    The first assert does not raise an exception:

    >>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
    ...                               [np.exp(0),2.33333, np.nan])

    Assert fails with numerical imprecision with floats:

    >>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
    ...                               [1, np.sqrt(np.pi)**2, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 4.4408921e-16
    Max relative difference: 1.41357986e-16
     x: array([1.      , 3.141593,      nan])
     y: array([1.      , 3.141593,      nan])

    Use `assert_allclose` or one of the nulp (number of floating point values)
    functions for these cases instead:

    >>> np.testing.assert_allclose([1.0,np.pi,np.nan],
    ...                            [1, np.sqrt(np.pi)**2, np.nan],
    ...                            rtol=1e-10, atol=0)

    As mentioned in the Notes section, `assert_array_equal` has special
    handling for scalars. Here the test checks that each value in `x` is 3:

    >>> x = np.full((2, 5), fill_value=3)
    >>> np.testing.assert_array_equal(x, 3)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # operator.eq is the same function object as operator.__eq__.
    assert_array_compare(operator.eq, x, y,
                         header='Arrays are not equal',
                         err_msg=err_msg, verbose=verbose)
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal up to desired
    precision.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    The test verifies identical shapes and that the elements of ``actual`` and
    ``desired`` satisfy.

        ``abs(desired-actual) < 1.5 * 10**(-decimal)``

    That is a looser test than originally documented, but agrees with what the
    actual implementation did up to rounding vagaries. An exception is raised
    at shape mismatch or conflicting values. In contrast to the standard usage
    in numpy, NaNs are compared like numbers, no assertion is raised if both
    objects have NaNs in the same positions.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    decimal : int, optional
        Desired precision, default is 6.
    err_msg : str, optional
      The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    the first assert does not raise an exception

    >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan],
    ...                                      [1.0,2.333,np.nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33339,np.nan], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 6.e-05
    Max relative difference: 2.57136612e-05
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33339,     nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33333, 5], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    x and y nan location mismatch:
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33333, 5.     ])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Top-level names exist on every NumPy version; ``float_`` was removed in
    # NumPy 2 (float64 is the same type).
    from numpy import number, float64, result_type, asanyarray, issubdtype
    from numpy import any as npany

    def compare(x, y):
        """Return a boolean (array) that is True where x and y are close.

        Infinities must be in the same positions; they are compared exactly
        and removed before the tolerance test.
        """
        try:
            # Compute the inf masks once; they are reused below.
            xinfid = gisinf(x)
            yinfid = gisinf(y)
            if npany(xinfid) or npany(yinfid):
                if not (xinfid == yinfid).all():
                    return False
                # if one item, x and y is +- inf
                if x.size == y.size == 1:
                    return x == y
                x = x[~xinfid]
                y = y[~yinfid]
        except (TypeError, NotImplementedError):
            pass

        # make sure y is an inexact type to avoid abs(MIN_INT); will cause
        # casting of x later. asanyarray avoids a copy when possible without
        # raising (array(..., copy=False) raises on NumPy >= 2 if a cast is
        # needed).
        dtype = result_type(y, 1.)
        y = asanyarray(y, dtype)
        z = abs(x - y)

        if not issubdtype(z.dtype, number):
            z = z.astype(float64)  # handle object arrays

        return z < 1.5 * 10.0**(-decimal)

    assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
             header=('Arrays are not almost equal to %d decimals' % decimal),
             precision=decimal)
def assert_array_less(x, y, err_msg='', verbose=True):
    """
    Raises an AssertionError if two array_like objects are not ordered by less
    than.

    Checks that `x` and `y` have the same shape and that every element of `x`
    is strictly smaller than the matching element of `y`. A shape mismatch or
    a wrongly ordered pair raises. A shape mismatch is not reported when one
    of the objects is zero-dimensional. Unlike the usual numpy convention,
    NaNs take part in the comparison: no assertion is raised when both
    objects hold NaNs at the same positions.

    Parameters
    ----------
    x : array_like
      The smaller object to check.
    y : array_like
      The larger object to compare.
    err_msg : str
      The error message to be printed in case of failure.
    verbose : bool
      If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If `x` is not strictly less than `y` element-wise, or the shapes
      mismatch.

    See Also
    --------
    assert_array_equal: tests objects for equality
    assert_array_almost_equal: test objects for equality up to precision

    Examples
    --------
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan])
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 1.
    Max relative difference: 0.5
     x: array([ 1.,  1., nan])
     y: array([ 1.,  2., nan])

    >>> np.testing.assert_array_less([1.0, 4.0], 3)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 2.
    Max relative difference: 0.66666667
     x: array([1., 4.])
     y: array(3)

    >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    (shapes (3,), (1,) mismatch)
     x: array([1., 2., 3.])
     y: array([4])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    less_than = operator.__lt__
    # equal_inf=False: infinities are compared like any other value.
    assert_array_compare(less_than, x, y,
                         err_msg=err_msg, verbose=verbose,
                         header='Arrays are not less-ordered',
                         equal_inf=False)
def runstring(astr, dict):
    """Execute the source string `astr` with `dict` as its global namespace.

    Names bound by the executed code end up in `dict`.
    """
    namespace = dict
    exec(astr, namespace)
def assert_string_equal(actual, desired):
    """
    Test if two strings are equal.

    If the given strings are equal, `assert_string_equal` does nothing.
    If they are not equal, an AssertionError is raised, and the diff
    between the strings is shown.

    Parameters
    ----------
    actual : str
        The string to test for equality against the expected string.
    desired : str
        The expected string.

    Raises
    ------
    AssertionError
        If either argument is not a ``str`` (the message is the repr of the
        offending type), or if the strings differ. In that case the message
        holds the line-wise ``difflib.Differ`` output with the unchanged
        lines left out. This includes the case where one string has extra
        lines the other lacks.

    Examples
    --------
    >>> np.testing.assert_string_equal('abc', 'abc')
    >>> np.testing.assert_string_equal('abc', 'abcd')
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    ...
    AssertionError: Differences in strings:
    - abc+ abcd?    +

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # delay import of difflib to reduce startup time
    import difflib

    if not isinstance(actual, str):
        raise AssertionError(repr(type(actual)))
    if not isinstance(desired, str):
        raise AssertionError(repr(type(desired)))
    if desired == actual:
        return

    # Keep every line Differ marks as removed ('- '), added ('+ ') or as an
    # intraline hint ('? '), in order; drop only unchanged lines ('  ').
    # The old pairing loop assumed each '- ' line was followed by a '+ '
    # line. That crashed with IndexError when `actual` had extra trailing
    # lines, and raised with a bare repr when `desired` had extra lines.
    diff = difflib.Differ().compare(actual.splitlines(True),
                                    desired.splitlines(True))
    diff_list = [line for line in diff if not line.startswith('  ')]
    # The strings differ, so Differ always reports at least one changed line.
    raise AssertionError('Differences in strings:\n%s'
                         % ''.join(diff_list).rstrip())
def rundocs(filename=None, raise_on_error=True):
    """
    Run doctests found in the given file.

    By default `rundocs` raises an AssertionError on failure.

    Parameters
    ----------
    filename : str
        The path to the file for which the doctests are run. When omitted,
        the ``__file__`` of the calling module is used.
    raise_on_error : bool
        Whether to raise an AssertionError when a doctest fails. Default is
        True. When False, the doctest reports are written by ``doctest``
        itself instead of being collected into the exception message.

    Notes
    -----
    The doctests can be run by the user/developer by adding the ``doctests``
    argument to the ``test()`` call. For example, to run all tests (including
    doctests) for `numpy.lib`:

    >>> np.lib.test(doctests=True)  # doctest: +SKIP
    """
    from numpy.compat import npy_load_module
    import doctest

    if filename is None:
        # Fall back to the module that called us.
        filename = sys._getframe(1).f_globals['__file__']
    module_name = os.path.splitext(os.path.basename(filename))[0]
    module = npy_load_module(module_name, filename)

    runner = doctest.DocTestRunner(verbose=False)
    reports = []
    # Capture the runner's output only when it will go into the exception.
    sink = reports.append if raise_on_error else None
    for case in doctest.DocTestFinder().find(module):
        runner.run(case, out=sink)

    if raise_on_error and runner.failures > 0:
        raise AssertionError("Some doctests failed:\n%s" % "\n".join(reports))
def raises(*args):
    """Decorator to check for raised exceptions.

    The decorated test function must raise one of the given exception types
    to pass. To check several exception-raising assertions inside a single
    test, `assert_raises` is usually the better tool.

    .. warning::
       This decorator is nose specific, do not use it if you are using a
       different test framework.

    Parameters
    ----------
    args : exceptions
        The test passes if any of the passed exceptions is raised.

    Raises
    ------
    AssertionError

    Examples
    --------

    Usage::

        @raises(TypeError, ValueError)
        def test_raises_type_error():
            raise TypeError("This test passes")

        @raises(Exception)
        def test_that_fails_by_passing():
            pass

    """
    nose = import_nose()
    make_decorator = nose.tools.raises
    return make_decorator(*args)
#
# assert_raises and assert_raises_regex delegate to the assertion methods of
# a throw-away unittest.TestCase instance.
#
import unittest


class _Dummy(unittest.TestCase):
    """Minimal TestCase that exists only to lend out its assert* methods."""

    def nop(self):
        """No-op test method; TestCase must be built from a real method name."""


_d = _Dummy(methodName='nop')
def assert_raises(*args, **kwargs):
    """
    assert_raises(exception_class, callable, *args, **kwargs)
    assert_raises(exception_class)

    Fail unless calling ``callable(*args, **kwargs)`` raises an exception of
    class `exception_class`. An exception of any other type is not caught:
    it propagates, and the test counts as having errored, exactly as for any
    other unexpected exception.

    When only `exception_class` is passed, the result is a context manager:

    >>> from numpy.testing import assert_raises
    >>> with assert_raises(ZeroDivisionError):
    ...     1 / 0

    is equivalent to

    >>> def div(x, y):
    ...     return x / y
    >>> assert_raises(ZeroDivisionError, div, 1, 0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    checker = _d.assertRaises
    return checker(*args, **kwargs)
def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs):
    """
    assert_raises_regex(exception_class, expected_regexp, callable, *args,
                        **kwargs)
    assert_raises_regex(exception_class, expected_regexp)

    Fail unless ``callable(*args, **kwargs)`` raises an exception of class
    `exception_class` whose message matches `expected_regexp` (in the
    ``re.search`` sense used by ``unittest``).

    Called with just the first two arguments, it works as a context manager,
    in the same way as `assert_raises`.

    The name follows the Python 3.2+ ``unittest`` spelling
    (``assertRaisesRegex``).

    Notes
    -----
    .. versionadded:: 1.9.0

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    checker = _d.assertRaisesRegex
    return checker(exception_class, expected_regexp, *args, **kwargs)
1352def decorate_methods(cls, decorator, testmatch=None):
1353 """
1354 Apply a decorator to all methods in a class matching a regular expression.
1356 The given decorator is applied to all public methods of `cls` that are
1357 matched by the regular expression `testmatch`
1358 (``testmatch.search(methodname)``). Methods that are private, i.e. start
1359 with an underscore, are ignored.
1361 Parameters
1362 ----------
1363 cls : class
1364 Class whose methods to decorate.
1365 decorator : function
1366 Decorator to apply to methods
1367 testmatch : compiled regexp or str, optional
1368 The regular expression. Default value is None, in which case the
1369 nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``)
1370 is used.
1371 If `testmatch` is a string, it is compiled to a regular expression
1372 first.
1374 """
1375 if testmatch is None:
1376 testmatch = re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)
1377 else:
1378 testmatch = re.compile(testmatch)
1379 cls_attr = cls.__dict__
1381 # delayed import to reduce startup time
1382 from inspect import isfunction
1384 methods = [_m for _m in cls_attr.values() if isfunction(_m)]
1385 for function in methods:
1386 try:
1387 if hasattr(function, 'compat_func_name'):
1388 funcname = function.compat_func_name
1389 else:
1390 funcname = function.__name__
1391 except AttributeError:
1392 # not a function
1393 continue
1394 if testmatch.search(funcname) and not funcname.startswith('_'):
1395 setattr(cls, funcname, decorator(function))
1396 return
def measure(code_str, times=1, label=None):
    """
    Return elapsed time for executing code in the namespace of the caller.

    `code_str` is compiled once with the builtin ``compile`` and then run
    `times` times with ``exec``, using the caller's globals and locals. The
    timing resolution is 10 milliseconds, so code that runs quickly should
    be repeated many times to get a meaningful figure.

    Parameters
    ----------
    code_str : str
        The code to be timed.
    times : int, optional
        The number of times the code is executed. Default is 1. The code is
        only compiled once.
    label : str, optional
        A label to identify `code_str` with. This is passed into ``compile``
        as the second argument (for run-time error messages).

    Returns
    -------
    elapsed : float
        Total elapsed time in seconds for executing `code_str` `times` times.

    Examples
    --------
    >>> times = 10
    >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)', times=times)
    >>> print("Time for a single execution : ", etime / times, "s")  # doctest: +SKIP
    Time for a single execution :  0.005 s

    """
    # Must stay directly inside measure(): frame 1 is the caller.
    caller = sys._getframe(1)
    caller_globals, caller_locals = caller.f_globals, caller.f_locals

    compiled = compile(code_str, 'Test name: %s ' % label, 'exec')
    runs = 0
    start = jiffies()
    while runs < times:
        exec(compiled, caller_globals, caller_locals)
        runs += 1
    # jiffies are hundredths of a second
    return 0.01 * (jiffies() - start)
def _assert_valid_refcount(op):
    """
    Check that ufuncs don't mishandle refcount of object `1`.
    Used in a few regression tests.

    Parameters
    ----------
    op : callable
        Binary operation. It is called 15 times as ``op(b, c)``, where ``b``
        and ``c`` are the same 100x100 array built from ``np.arange``.

    Returns
    -------
    True when ``sys.getrefcount`` is unavailable (the check is skipped),
    otherwise None.

    Raises
    ------
    AssertionError
        Via ``assert_``, if the reference count of the int ``1`` decreased
        across the calls. An exception raised by `op` itself propagates
        unchanged.
    """
    if not HAS_REFCOUNT:
        return True

    import gc
    import numpy as np

    b = np.arange(100*100).reshape(100, 100)
    c = b
    i = 1

    # Pre-bind ``d``. If ``op`` raises on its first call, the ``del d`` in
    # ``finally`` would otherwise raise UnboundLocalError and mask the
    # original exception.
    d = None
    # Turn off automatic garbage collection while refcounts are measured.
    gc.disable()
    try:
        rc = sys.getrefcount(i)
        for _ in range(15):
            d = op(b, c)
        assert_(sys.getrefcount(i) >= rc)
    finally:
        gc.enable()
        del d  # for pyflakes
def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
                    err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal up to desired
    tolerance.

    Element-wise, the difference between `actual` and `desired` is compared
    with ``atol + rtol * abs(desired)``. This is the same test as
    ``allclose(actual, desired, rtol, atol)``, but note that ``allclose`` has
    different default values.

    .. versionadded:: 1.5.0

    Parameters
    ----------
    actual : array_like
        Array obtained.
    desired : array_like
        Array desired.
    rtol : float, optional
        Relative tolerance.
    atol : float, optional
        Absolute tolerance.
    equal_nan : bool, optional.
        If True, NaNs will compare equal.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_array_almost_equal_nulp, assert_array_max_ulp

    Examples
    --------
    >>> x = [1e-5, 1e-3, 1e-1]
    >>> y = np.arccos(np.cos(x))
    >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    isclose = np.core.numeric.isclose

    def compare(x, y):
        return isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan)

    actual = np.asanyarray(actual)
    desired = np.asanyarray(desired)
    header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
    assert_array_compare(compare, actual, desired,
                         err_msg=str(err_msg), verbose=verbose,
                         header=header, equal_nan=equal_nan)
def assert_array_almost_equal_nulp(x, y, nulp=1):
    """
    Compare two arrays relatively to their spacing.

    This comparison stays robust when the magnitudes of the compared values
    vary widely.

    Parameters
    ----------
    x, y : array_like
        Input arrays.
    nulp : int, optional
        The maximum number of unit in the last place for tolerance (see Notes).
        Default is 1.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If the spacing between `x` and `y` for one or more elements is larger
        than `nulp`.

    See Also
    --------
    assert_array_max_ulp : Check that all items of arrays differ in at most
        N Units in the Last Place.
    spacing : Return the distance between x and the nearest adjacent number.

    Notes
    -----
    An assertion is raised if the following condition is not met::

        abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y)))

    For complex inputs the failure message omits the observed maximum ULP
    difference.

    Examples
    --------
    >>> x = np.array([1., 1e-10, 1e-20])
    >>> eps = np.finfo(x.dtype).eps
    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x)

    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x)
    Traceback (most recent call last):
      ...
    AssertionError: X and Y are not equal to 1 ULP (max is 2)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np
    abs_x = np.abs(x)
    abs_y = np.abs(y)
    tolerance = nulp * np.spacing(np.where(abs_x > abs_y, abs_x, abs_y))
    if np.all(np.abs(x-y) <= tolerance):
        return

    if np.iscomplexobj(x) or np.iscomplexobj(y):
        # nulp_diff does not support complex input
        msg = "X and Y are not equal to %d ULP" % nulp
    else:
        worst = np.max(nulp_diff(x, y))
        msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, worst)
    raise AssertionError(msg)
def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
    """
    Check that all items of arrays differ in at most N Units in the Last Place.

    Parameters
    ----------
    a, b : array_like
        Input arrays to be compared.
    maxulp : int, optional
        The maximum number of units in the last place that elements of `a` and
        `b` can differ. Default is 1.
    dtype : dtype, optional
        Data-type to convert `a` and `b` to if given. Default is None.

    Returns
    -------
    ret : ndarray
        Array containing number of representable floating point numbers between
        items in `a` and `b`.

    Raises
    ------
    AssertionError
        If one or more elements differ by more than `maxulp`.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    See Also
    --------
    assert_array_almost_equal_nulp : Compare two arrays relatively to their
        spacing.

    Examples
    --------
    >>> a = np.linspace(0., 1., 100)
    >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a)))

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np
    ulp_counts = nulp_diff(a, b, dtype)
    if np.all(ulp_counts <= maxulp):
        return ulp_counts
    raise AssertionError("Arrays are not almost equal up to %g "
                         "ULP (max difference is %g ULP)" %
                         (maxulp, np.max(ulp_counts)))
def nulp_diff(x, y, dtype=None):
    """For each item in x and y, return the number of representable floating
    points between them.

    Parameters
    ----------
    x : array_like
        first input array
    y : array_like
        second input array
    dtype : dtype, optional
        Data-type to convert `x` and `y` to if given. Default is None.
        Any value other than None is honoured, including ``np.dtype``
        instances.

    Returns
    -------
    nulp : ndarray
        Number of representable floating point numbers between each item in
        x and y. The values are stored as floats of the inputs' common type.
        The result carries an extra leading axis of length 1.

    Raises
    ------
    NotImplementedError
        If either input is complex.
    ValueError
        If x and y do not have the same shape (no broadcasting is done), or
        if the common type is not float16, float32 or float64.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    Examples
    --------
    ``eps`` is the gap between 1.0 and the next representable float64, so
    exactly one ULP separates 1 and 1 + eps:

    >>> nulp_diff(1, 1 + np.finfo(np.float64).eps)
    array([1.])
    """
    import numpy as np
    # Compare against None explicitly: non-structured np.dtype objects are
    # falsy (len() == 0), so `if dtype:` silently ignored np.dtype('f4').
    if dtype is not None:
        x = np.array(x, dtype=dtype)
        y = np.array(y, dtype=dtype)
    else:
        x = np.array(x)
        y = np.array(y)

    t = np.common_type(x, y)
    if np.iscomplexobj(x) or np.iscomplexobj(y):
        raise NotImplementedError("_nulp not implemented for complex array")

    x = np.array([x], dtype=t)
    y = np.array([y], dtype=t)

    # Rewrite every NaN with np.nan so that NaNs with different bit
    # patterns get the same integer representation.
    x[np.isnan(x)] = np.nan
    y[np.isnan(y)] = np.nan

    if x.shape != y.shape:
        raise ValueError("x and y do not have the same shape: %s - %s" %
                         (x.shape, y.shape))

    def _diff(rx, ry, vdt):
        diff = np.array(rx-ry, dtype=vdt)
        return np.abs(diff)

    rx = integer_repr(x)
    ry = integer_repr(y)
    return _diff(rx, ry, t)
def _integer_repr(x, vdt, comp):
    """Reinterpret the bits of float array `x` as sign-magnitude integers.

    Parameters
    ----------
    x : ndarray
        Float array whose bit pattern is reinterpreted.
    vdt : integer dtype
        Signed integer type with the same itemsize as ``x.dtype``.
    comp : integer scalar
        The most negative value of `vdt`, used to fold negative patterns.

    Returns
    -------
    ndarray or integer scalar
        Bit patterns of `x` as `vdt`. Negative (sign-bit-set) patterns are
        mapped to ``comp - pattern``. `x` itself is never modified.
    """
    # Reinterpret binary representation of the float as sign-magnitude:
    # take into account two-complement representation
    # See also
    # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
    rx = x.view(vdt)
    if rx.size != 1:
        # `rx` shares memory with `x`. Rewriting it in place used to
        # silently corrupt the caller's float array, so work on a copy.
        rx = rx.copy()
        negative = rx < 0
        rx[negative] = comp - rx[negative]
    elif rx < 0:
        rx = comp - rx

    return rx
def integer_repr(x):
    """Return the signed-magnitude interpretation of the binary representation of
    x.

    Only float16, float32 and float64 arrays are supported; each is viewed as
    the signed integer type of the same width. Any other dtype raises
    ValueError.
    """
    import numpy as np
    # (float type, same-width signed int type, bit width)
    supported = ((np.float16, np.int16, 16),
                 (np.float32, np.int32, 32),
                 (np.float64, np.int64, 64))
    for float_type, int_type, bits in supported:
        if x.dtype == float_type:
            return _integer_repr(x, int_type, int_type(-2**(bits - 1)))
    raise ValueError("Unsupported dtype %s" % x.dtype)
@contextlib.contextmanager
def _assert_warns_context(warning_class, name=None):
    """Context manager that fails unless `warning_class` is warned in its body.

    Matching warnings are recorded through ``suppress_warnings.record``. If
    the body completes without any being recorded, AssertionError is raised.
    `name`, if given, is mentioned in the message.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    with suppress_warnings() as sup:
        recorded = sup.record(warning_class)
        yield
        if not recorded:
            suffix = "" if name is None else " when calling %s" % name
            raise AssertionError("No warning raised" + suffix)
def assert_warns(warning_class, *args, **kwargs):
    """
    Fail unless the given callable throws the specified warning.

    A warning of class warning_class should be thrown by the callable when
    invoked with arguments args and keyword arguments kwargs.
    If a different type of warning is thrown, it will not be caught.

    If called with all arguments other than the warning class omitted, may be
    used as a context manager:

        with assert_warns(SomeWarning):
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.4.0

    Parameters
    ----------
    warning_class : class
        The class defining the warning that `func` is expected to throw.
    func : callable, optional
        The callable to test. Callables without a ``__name__`` (for example
        ``functools.partial`` objects) are accepted; the failure message then
        omits the name.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`, or a context manager when only
    `warning_class` is given.

    Raises
    ------
    AssertionError
        If no matching warning was emitted.
    """
    if not args:
        return _assert_warns_context(warning_class)

    func, args = args[0], args[1:]
    # partial objects and callable instances have no __name__; don't die
    # with AttributeError before `func` has even been called.
    name = getattr(func, '__name__', None)
    with _assert_warns_context(warning_class, name=name):
        return func(*args, **kwargs)
@contextlib.contextmanager
def _assert_no_warnings_context(name=None):
    """Context manager that fails if any warning is emitted in its body.

    Every warning is recorded; the 'always' filter is installed so repeats
    are not deduplicated. If the body completes and anything was recorded,
    AssertionError listing the warnings is raised. `name`, if given, is
    mentioned in the message.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        yield
        if caught:
            suffix = "" if name is None else " when calling %s" % name
            raise AssertionError("Got warnings%s: %s" % (suffix, caught))
def assert_no_warnings(*args, **kwargs):
    """
    Fail if the given callable produces any warnings.

    If called with all arguments omitted, may be used as a context manager:

        with assert_no_warnings():
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.7.0

    Parameters
    ----------
    func : callable
        The callable to test. Callables without a ``__name__`` (for example
        ``functools.partial`` objects) are accepted; the failure message then
        omits the name.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`, or a context manager when called without
    arguments.

    Raises
    ------
    AssertionError
        If any warning was emitted.
    """
    if not args:
        return _assert_no_warnings_context()

    func, args = args[0], args[1:]
    # partial objects and callable instances have no __name__; don't die
    # with AttributeError before `func` has even been called.
    name = getattr(func, '__name__', None)
    with _assert_no_warnings_context(name=name):
        return func(*args, **kwargs)
def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
    """
    Generate arrays with assorted alignments and offsets, for testing SIMD
    vectorized loops.

    Parameters
    ----------
    dtype : dtype
        Data type of the produced arrays.
    type : str
        'unary': each case has one input and one output array.
        'binary': each case has two input arrays and one output array.
        Any other value produces nothing.
    max_size : int
        Upper bound (exclusive) on the generated array sizes. For offset
        ``o``, sizes start at ``o + 2``, and at least one size is always
        produced.

    Yields
    ------
    For 'unary': ``(out, inp, msg)``; for 'binary': ``(out, inp1, inp2, msg)``.
    ``msg`` describes the offsets, size, dtype and the kind of case
    ('out of place', 'in place', 'aliased', ...). Output arrays come from
    ``empty`` and are uninitialized. Input arrays are ``arange`` values
    starting at the offset.
    """
    ufmt = 'unary offset=(%d, %d), size=%d, dtype=%r, %s'
    bfmt = 'binary offset=(%d, %d, %d), size=%d, dtype=%r, %s'
    for off in range(3):
        for size in range(off + 2, max(off + 3, max_size)):
            def fresh(size=size, off=off):
                # A brand-new offset slice of 0..size-1 on every call.
                return arange(size, dtype=dtype)[off:]

            if type == 'unary':
                out = empty((size,), dtype=dtype)[off:]
                yield out, fresh(), ufmt % (off, off, size, dtype, 'out of place')
                same = fresh()
                yield same, same, ufmt % (off, off, size, dtype, 'in place')
                yield out[1:], fresh()[:-1], ufmt % \
                    (off + 1, off, size - 1, dtype, 'out of place')
                yield out[:-1], fresh()[1:], ufmt % \
                    (off, off + 1, size - 1, dtype, 'out of place')
                yield fresh()[:-1], fresh()[1:], ufmt % \
                    (off, off + 1, size - 1, dtype, 'aliased')
                yield fresh()[1:], fresh()[:-1], ufmt % \
                    (off + 1, off, size - 1, dtype, 'aliased')
            elif type == 'binary':
                out = empty((size,), dtype=dtype)[off:]
                yield out, fresh(), fresh(), bfmt % \
                    (off, off, off, size, dtype, 'out of place')
                same = fresh()
                yield same, same, fresh(), bfmt % \
                    (off, off, off, size, dtype, 'in place1')
                same = fresh()
                yield same, fresh(), same, bfmt % \
                    (off, off, off, size, dtype, 'in place2')
                yield out[1:], fresh()[:-1], fresh()[:-1], bfmt % \
                    (off + 1, off, off, size - 1, dtype, 'out of place')
                yield out[:-1], fresh()[1:], fresh()[:-1], bfmt % \
                    (off, off + 1, off, size - 1, dtype, 'out of place')
                yield out[:-1], fresh()[:-1], fresh()[1:], bfmt % \
                    (off, off, off + 1, size - 1, dtype, 'out of place')
                yield fresh()[1:], fresh()[:-1], fresh()[:-1], bfmt % \
                    (off + 1, off, off, size - 1, dtype, 'aliased')
                yield fresh()[:-1], fresh()[1:], fresh()[:-1], bfmt % \
                    (off, off + 1, off, size - 1, dtype, 'aliased')
                yield fresh()[:-1], fresh()[:-1], fresh()[1:], bfmt % \
                    (off, off, off + 1, size - 1, dtype, 'aliased')
class IgnoreException(Exception):
    """Ignoring this exception due to disabled feature.

    Marks a condition that is being ignored because the feature it depends on
    is disabled. It carries no extra state beyond ``Exception``.
    """
@contextlib.contextmanager
def tempdir(*args, **kwargs):
    """Context manager to provide a temporary test folder.

    All arguments are forwarded unchanged to ``tempfile.mkdtemp``. The
    directory's path is yielded. The directory and everything in it are
    removed with ``shutil.rmtree`` when the context exits, whether normally
    or through an exception.
    """
    path = mkdtemp(*args, **kwargs)
    with contextlib.ExitStack() as cleanup:
        cleanup.callback(shutil.rmtree, path)
        yield path
@contextlib.contextmanager
def temppath(*args, **kwargs):
    """Context manager for temporary files.

    A file is created with ``tempfile.mkstemp`` (all arguments are forwarded
    unchanged). Its descriptor is closed immediately and its path is
    yielded. The file is deleted with ``os.remove`` when the context exits,
    so it must be closed by then.

    Closing the descriptor up front matters on Windows, which refuses to
    open a file a second time while it is already open.
    """
    handle, path = mkstemp(*args, **kwargs)
    os.close(handle)
    with contextlib.ExitStack() as cleanup:
        cleanup.callback(os.remove, path)
        yield path
class clear_and_catch_warnings(warnings.catch_warnings):
    """ Context manager that resets warning registry for catching warnings

    Whenever a warning is triggered, Python adds a ``__warningregistry__``
    member to the *calling* module. After that, the same warning cannot be
    triggered again from that module, whatever the warnings filters say.
    This context manager takes a sequence of `modules` as a keyword argument
    to its constructor and:

    * on entry, saves and then empties the ``__warningregistry__`` of each of
      those modules;
    * on exit, restores each registry to the state it had on entry.

    Inside the context, any warning can therefore fire afresh, without
    disturbing the warning state outside it.

    For compatibility with Python 3.0, please consider all arguments to be
    keyword-only.

    Parameters
    ----------
    record : bool, optional
        Specifies whether warnings should be captured by a custom
        implementation of ``warnings.showwarning()`` and be appended to a list
        returned by the context manager. Otherwise None is returned by the
        context manager. The objects appended to the list are arguments whose
        attributes mirror the arguments to ``showwarning()``.
    modules : sequence, optional
        Sequence of modules for which to reset warnings registry on entry and
        restore on exit. To work correctly, all 'ignore' filters should
        filter by one of these modules.

    Examples
    --------
    >>> import warnings
    >>> with np.testing.clear_and_catch_warnings(
    ...         modules=[np.core.fromnumeric]):
    ...     warnings.simplefilter('always')
    ...     warnings.filterwarnings('ignore', module='np.core.fromnumeric')
    ...     # do something that raises a warning but ignore those in
    ...     # np.core.fromnumeric
    """
    # Modules always handled, in addition to those passed to the constructor
    # (subclasses may override this).
    class_modules = ()

    def __init__(self, record=False, modules=()):
        self.modules = set(modules).union(self.class_modules)
        self._warnreg_copies = {}  # module -> registry contents on entry
        super().__init__(record=record)

    def __enter__(self):
        for mod in self.modules:
            if not hasattr(mod, '__warningregistry__'):
                continue
            registry = mod.__warningregistry__
            self._warnreg_copies[mod] = registry.copy()
            registry.clear()
        return super().__enter__()

    def __exit__(self, *exc_info):
        super().__exit__(*exc_info)
        for mod in self.modules:
            if not hasattr(mod, '__warningregistry__'):
                continue
            # Discard whatever was registered inside the context, then put
            # back the entries saved on entry (if any).
            registry = mod.__warningregistry__
            registry.clear()
            if mod in self._warnreg_copies:
                registry.update(self._warnreg_copies[mod])
class suppress_warnings:
    """
    Context manager and decorator doing much the same as
    ``warnings.catch_warnings``.

    However, it also provides a filter mechanism to work around
    https://bugs.python.org/issue4180.

    This bug causes Python before 3.4 to not reliably show warnings again
    after they have been ignored once (even within catch_warnings). It
    means that no "ignore" filter can be used easily, since following
    tests might need to see the warning. Additionally it allows easier
    specificity for testing warnings and can be nested.

    Parameters
    ----------
    forwarding_rule : str, optional
        One of "always", "once", "module", or "location". Analogous to
        the usual warnings module filter mode, it is useful to reduce
        noise mostly on the outermost level. Unsuppressed and unrecorded
        warnings will be forwarded based on this rule. Defaults to "always".
        "location" is equivalent to the warnings "default", match by exact
        location the warning originated from.

    Raises
    ------
    ValueError
        If `forwarding_rule` is not one of the supported values.

    Notes
    -----
    Filters added inside the context manager will be discarded again
    when leaving it. Upon entering all filters defined outside a
    context will be applied automatically.

    When a recording filter is added, matching warnings are stored in the
    ``log`` attribute as well as in the list returned by ``record``.

    If filters are added and the ``module`` keyword is given, the
    warning registry of this module will additionally be cleared when
    applying it, entering the context, or exiting it. This could cause
    warnings to appear a second time after leaving the context if they
    were configured to be printed once (default) and were already
    printed before the context was entered.

    Nesting this context manager will work as expected when the
    forwarding rule is "always" (default). Unfiltered and unrecorded
    warnings will be passed out and be matched by the outer level.
    On the outermost level they will be printed (or caught by another
    warnings context). The forwarding rule argument can modify this
    behaviour.

    Like ``catch_warnings`` this context manager is not threadsafe.

    Examples
    --------

    With a context manager::

        with np.testing.suppress_warnings() as sup:
            sup.filter(DeprecationWarning, "Some text")
            sup.filter(module=np.ma.core)
            log = sup.record(FutureWarning, "Does this occur?")
            command_giving_warnings()
            # The FutureWarning was given once, the filtered warnings were
            # ignored. All other warnings abide outside settings (may be
            # printed/error)
            assert_(len(log) == 1)
            assert_(len(sup.log) == 1)  # also stored in log attribute

    Or as a decorator::

        sup = np.testing.suppress_warnings()
        sup.filter(module=np.ma.core)  # module must match exactly
        @sup
        def some_function():
            # do something which causes a warning in np.ma.core
            pass
    """
    def __init__(self, forwarding_rule="always"):
        self._entered = False

        # Suppressions are either instance or defined inside one with block:
        self._suppressions = []

        if forwarding_rule not in {"always", "module", "once", "location"}:
            raise ValueError("unsupported forwarding rule.")
        self._forwarding_rule = forwarding_rule

    def _clear_registries(self):
        """Make the warnings machinery forget which warnings were shown."""
        if hasattr(warnings, "_filters_mutated"):
            # clearing the registry should not be necessary on new pythons,
            # instead the filters should be mutated.
            warnings._filters_mutated()
            return
        # Simply clear the registry, this should normally be harmless,
        # note that on new pythons it would be invalidated anyway.
        for module in self._tmp_modules:
            if hasattr(module, "__warningregistry__"):
                module.__warningregistry__.clear()

    def _apply_filter(self, category, message, module):
        """
        Install an "always" filter into the active ``warnings.filters``.

        "always" guarantees that matching warnings reach ``_showwarning``,
        which then decides whether to drop, record or forward them.
        Modules given here are remembered in ``self._tmp_modules`` so that
        their registries can be cleared.
        """
        if module is None:
            warnings.filterwarnings(
                "always", category=category, message=message)
        else:
            # Match the exact dotted module name only (no submodules).
            module_regex = module.__name__.replace('.', r'\.') + '$'
            warnings.filterwarnings(
                "always", category=category, message=message,
                module=module_regex)
            self._tmp_modules.add(module)

    def _filter(self, category=Warning, message="", module=None, record=False):
        """
        Register a suppression; shared implementation of `filter`/`record`.

        Returns the list that matching warnings will be appended to when
        `record` is true, otherwise None.
        """
        record = [] if record else None  # The log where to store warnings
        if self._entered:
            # Inside a with block: apply now, forget again on __exit__.
            self._apply_filter(category, message, module)
            self._clear_registries()
            target = self._tmp_suppressions
        else:
            target = self._suppressions
        target.append(
            (category, message, re.compile(message, re.I), module, record))
        return record

    def filter(self, category=Warning, message="", module=None):
        """
        Add a new suppressing filter or apply it if the state is entered.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matched (case-insensitively) against the
            start of the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        self._filter(category=category, message=message, module=module,
                     record=False)

    def record(self, category=Warning, message="", module=None):
        """
        Append a new recording filter or apply it if the state is entered.

        All warnings matching will be appended to the ``log`` attribute.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matched (case-insensitively) against the
            start of the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Returns
        -------
        log : list
            A list which will be filled with all matched warnings.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        return self._filter(category=category, message=message, module=module,
                            record=True)

    def __enter__(self):
        if self._entered:
            raise RuntimeError("cannot enter suppress_warnings twice.")

        self._orig_show = warnings.showwarning
        self._filters = warnings.filters
        warnings.filters = self._filters[:]

        self._entered = True
        self._tmp_suppressions = []
        self._tmp_modules = set()
        self._forwarded = set()

        self.log = []  # reset global log (no need to keep same list)

        for cat, mess, _, mod, log in self._suppressions:
            if log is not None:
                del log[:]  # clear the log
            self._apply_filter(cat, mess, mod)
        warnings.showwarning = self._showwarning
        self._clear_registries()

        return self

    def __exit__(self, *exc_info):
        warnings.showwarning = self._orig_show
        warnings.filters = self._filters
        self._clear_registries()
        self._entered = False
        del self._orig_show
        del self._filters

    @staticmethod
    def _message_text(message):
        """
        Return the text the message regexes are matched against.

        Uses the first argument when it is a string (the usual case).
        Otherwise it uses ``str(message)``, as the warnings module does.
        That covers warnings created without arguments or with a
        non-string argument, which previously raised IndexError/TypeError.
        """
        args = getattr(message, "args", ())
        if args and isinstance(args[0], str):
            return args[0]
        return str(message)

    def _forward(self, message, category, filename, lineno, args, kwargs,
                 use_warnmsg):
        """Pass an unsuppressed warning on to the saved outer handler."""
        if use_warnmsg is None:
            self._orig_show(message, category, filename, lineno,
                            *args, **kwargs)
        else:
            # No separate "showwarnmsg" hook is saved on entry, so unpack
            # the WarningMessage and use the saved showwarning instead.
            self._orig_show(use_warnmsg.message, use_warnmsg.category,
                            use_warnmsg.filename, use_warnmsg.lineno,
                            use_warnmsg.file, use_warnmsg.line)

    def _showwarning(self, message, category, filename, lineno,
                     *args, use_warnmsg=None, **kwargs):
        text = self._message_text(message)
        # Most recently added suppressions take precedence.
        for cat, _, pattern, mod, rec in (
                self._suppressions + self._tmp_suppressions)[::-1]:
            if not issubclass(category, cat) or pattern.match(text) is None:
                continue
            # Use startswith, because warnings strips the c or o from
            # .pyc/.pyo files.
            if mod is not None and not mod.__file__.startswith(filename):
                continue
            # Category, message (and module, if given) match: either
            # record the warning or silently drop it.
            if rec is not None:
                msg = WarningMessage(message, category, filename,
                                     lineno, **kwargs)
                self.log.append(msg)
                rec.append(msg)
            return

        # There is no filter in place, so pass to the outside handler
        # unless we should only pass it once
        rule = self._forwarding_rule
        if rule != "always":
            if rule == "once":
                signature = (message.args, category)
            elif rule == "module":
                signature = (message.args, category, filename)
            else:  # "location"; other rules are rejected in __init__
                signature = (message.args, category, filename, lineno)
            if signature in self._forwarded:
                return
            self._forwarded.add(signature)

        self._forward(message, category, filename, lineno, args, kwargs,
                      use_warnmsg)

    def __call__(self, func):
        """
        Function decorator to apply certain suppressions to a whole
        function.
        """
        @wraps(func)
        def new_func(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return new_func
@contextlib.contextmanager
def _assert_no_gc_cycles_context(name=None):
    """
    Context manager that fails if its body leaves reference cycles behind.

    Garbage collection is disabled for the duration. Pre-existing garbage
    is drained first (at most 100 ``gc.collect`` passes, else RuntimeError).
    The body then runs under ``gc.DEBUG_SAVEALL``. If the ``gc.collect()``
    afterwards finds unreachable objects, an AssertionError listing them
    is raised. `name`, if given, is mentioned in that message. When
    ``HAS_REFCOUNT`` is false the body simply runs unchecked.
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    # not meaningful to test if there is no refcounting
    if not HAS_REFCOUNT:
        yield
        return

    assert_(gc.isenabled())
    gc.disable()
    saved_debug = gc.get_debug()
    try:
        # Drain existing garbage so only cycles created below are counted.
        for _attempt in range(100):
            if gc.collect() == 0:
                break
        else:
            raise RuntimeError(
                "Unable to fully collect garbage - perhaps a __del__ method is "
                "creating more reference cycles?")

        gc.set_debug(gc.DEBUG_SAVEALL)
        yield
        # gc.collect returns the number of unreachable objects in cycles that
        # were found -- we are checking that no cycles were created in the context
        n_objects_in_cycles = gc.collect()
        objects_in_cycles = gc.garbage[:]
    finally:
        del gc.garbage[:]
        gc.set_debug(saved_debug)
        gc.enable()

    if not n_objects_in_cycles:
        return

    name_str = "" if name is None else " when calling %s" % name
    details = ''.join(
        "\n {} object with id={}:\n {}".format(
            type(obj).__name__,
            id(obj),
            pprint.pformat(obj).replace('\n', '\n ')
        )
        for obj in objects_in_cycles
    )
    raise AssertionError(
        "Reference cycles were found{}: {} objects were collected, "
        "of which {} are shown below:{}".format(
            name_str, n_objects_in_cycles, len(objects_in_cycles), details))
def assert_no_gc_cycles(*args, **kwargs):
    """
    Fail if the given callable produces any reference cycles.

    If called with all arguments omitted, may be used as a context manager:

        with assert_no_gc_cycles():
            do_something()

    .. versionadded:: 1.15.0

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    Nothing. The result is deliberately discarded to ensure that all cycles
    are found.

    """
    if not args:
        return _assert_no_gc_cycles_context()

    func = args[0]
    args = args[1:]
    # Not every callable has a __name__ (e.g. functools.partial objects or
    # instances defining __call__); fall back to repr() for the message
    # instead of failing with AttributeError before running the check.
    name = getattr(func, '__name__', None)
    if name is None:
        name = repr(func)
    with _assert_no_gc_cycles_context(name=name):
        func(*args, **kwargs)
def break_cycles():
    """
    Break reference cycles by calling gc.collect

    Objects can call other objects' methods (for instance, another object's
    __del__) inside their own __del__. On PyPy, the interpreter only runs
    between calls to gc.collect, so multiple calls are needed to completely
    release all cycles.
    """
    # One pass everywhere. PyPy gets two extra: one lets the interpreter
    # run the __del__ methods of collected objects, one more to be sure.
    n_passes = 3 if IS_PYPY else 1
    for _ in range(n_passes):
        gc.collect()
def requires_memory(free_bytes):
    """
    Decorator to skip a test if not enough memory is available

    The check (``check_free_memory(free_bytes)``) runs each time the
    decorated test is called. A MemoryError raised by the test is turned
    into an expected failure instead of an error.
    """
    import pytest

    def decorator(func):
        @wraps(func)
        def wrapper(*a, **kw):
            skip_reason = check_free_memory(free_bytes)
            if skip_reason is not None:
                pytest.skip(skip_reason)

            try:
                return func(*a, **kw)
            except MemoryError:
                # Probably ran out of memory regardless: don't regard as failure
                pytest.xfail("MemoryError raised")

        return wrapper

    return decorator
def check_free_memory(free_bytes):
    """
    Check whether `free_bytes` amount of memory is currently free.

    The free amount comes from the ``NPY_AVAILABLE_MEM`` environment
    variable when it is set (parsed by ``_parse_size``). Otherwise it is
    queried via ``_get_mem_available``.

    Parameters
    ----------
    free_bytes : int or float
        Number of bytes required.

    Returns
    -------
    None or str
        None if enough memory is available, otherwise an error message
        explaining the shortfall.

    Raises
    ------
    ValueError
        If ``NPY_AVAILABLE_MEM`` is set but cannot be parsed.
    """
    env_var = 'NPY_AVAILABLE_MEM'
    env_value = os.environ.get(env_var)
    if env_value is not None:
        try:
            mem_free = _parse_size(env_value)
        except ValueError as exc:
            # Chain explicitly so the parse error is shown as the cause.
            raise ValueError('Invalid environment variable {}: {!s}'.format(
                env_var, exc)) from exc

        msg = ('{0} GB memory required, but environment variable '
               'NPY_AVAILABLE_MEM={1} set'.format(
                   free_bytes/1e9, env_value))
    else:
        mem_free = _get_mem_available()

        if mem_free is None:
            msg = ("Could not determine available memory; set NPY_AVAILABLE_MEM "
                   "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run "
                   "the test.")
            # Unknown amount: treat as insufficient for any non-negative
            # requirement so the message above is returned.
            mem_free = -1
        else:
            msg = '{0} GB memory required, but {1} GB available'.format(
                free_bytes/1e9, mem_free/1e9)

    return msg if mem_free < free_bytes else None
2451def _parse_size(size_str):
2452 """Convert memory size strings ('12 GB' etc.) to float"""
2453 suffixes = {'': 1, 'b': 1,
2454 'k': 1000, 'm': 1000**2, 'g': 1000**3, 't': 1000**4,
2455 'kb': 1000, 'mb': 1000**2, 'gb': 1000**3, 'tb': 1000**4,
2456 'kib': 1024, 'mib': 1024**2, 'gib': 1024**3, 'tib': 1024**4}
2458 size_re = re.compile(r'^\s*(\d+|\d+\.\d+)\s*({0})\s*$'.format(
2459 '|'.join(suffixes.keys())), re.I)
2461 m = size_re.match(size_str.lower())
2462 if not m or m.group(2) not in suffixes:
2463 raise ValueError("value {!r} not a valid size".format(size_str))
2464 return int(float(m.group(1)) * suffixes[m.group(2)])
2467def _get_mem_available():
2468 """Return available memory in bytes, or None if unknown."""
2469 try:
2470 import psutil
2471 return psutil.virtual_memory().available
2472 except (ImportError, AttributeError):
2473 pass
2475 if sys.platform.startswith('linux'):
2476 info = {}
2477 with open('/proc/meminfo', 'r') as f:
2478 for line in f:
2479 p = line.split()
2480 info[p[0].strip(':').lower()] = int(p[1]) * 1024
2482 if 'memavailable' in info:
2483 # Linux >= 3.14
2484 return info['memavailable']
2485 else:
2486 return info['memfree'] + info['cached']
2488 return None
2491def _no_tracing(func):
2492 """
2493 Decorator to temporarily turn off tracing for the duration of a test.
2494 Needed in tests that check refcounting, otherwise the tracing itself
2495 influences the refcounts
2496 """
2497 if not hasattr(sys, 'gettrace'):
2498 return func
2499 else:
2500 @wraps(func)
2501 def wrapper(*args, **kwargs):
2502 original_trace = sys.gettrace()
2503 try:
2504 sys.settrace(None)
2505 return func(*args, **kwargs)
2506 finally:
2507 sys.settrace(original_trace)
2508 return wrapper