Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

Utility functions to facilitate testing.

3 

4""" 

5import os 

6import sys 

7import platform 

8import re 

9import gc 

10import operator 

11import warnings 

12from functools import partial, wraps 

13import shutil 

14import contextlib 

15from tempfile import mkdtemp, mkstemp 

16from unittest.case import SkipTest 

17from warnings import WarningMessage 

18import pprint 

19 

20from numpy.core import( 

21 intp, float32, empty, arange, array_repr, ndarray, isnat, array) 

22import numpy.linalg.lapack_lite 

23 

24from io import StringIO 

25 

# Names exported by a star-import of this module.
__all__ = [
    'assert_equal',
    'assert_almost_equal',
    'assert_approx_equal',
    'assert_array_equal',
    'assert_array_less',
    'assert_string_equal',
    'assert_array_almost_equal',
    'assert_raises',
    'build_err_msg',
    'decorate_methods',
    'jiffies',
    'memusage',
    'print_assert_equal',
    'raises',
    'rundocs',
    'runstring',
    'verbose',
    'measure',
    'assert_',
    'assert_array_almost_equal_nulp',
    'assert_raises_regex',
    'assert_array_max_ulp',
    'assert_warns',
    'assert_no_warnings',
    'assert_allclose',
    'IgnoreException',
    'clear_and_catch_warnings',
    'SkipTest',
    'KnownFailureException',
    'temppath',
    'tempdir',
    'IS_PYPY',
    'HAS_REFCOUNT',
    'suppress_warnings',
    'assert_array_compare',
    '_assert_valid_refcount',
    '_gen_alignment_data',
    'assert_no_gc_cycles',
    'break_cycles',
    'HAS_LAPACK64',
]

40 

41 

class KnownFailureException(Exception):
    """Exception raised to mark a test as a known failing test."""

45 

46 

# Older name for KnownFailureException, kept for backwards compatibility.
KnownFailureTest = KnownFailureException
# Module-level verbosity setting; starts at 0.
verbose = 0

# True when the running interpreter identifies itself as PyPy.
IS_PYPY = 'PyPy' == platform.python_implementation()
# True when this interpreter provides ``sys.getrefcount``.
HAS_REFCOUNT = not (getattr(sys, 'getrefcount', None) is None)
# Mirrors lapack_lite's ``_ilp64`` flag.
# NOTE(review): presumably signals a 64-bit-integer (ILP64) LAPACK build --
# confirm against numpy.linalg.lapack_lite.
HAS_LAPACK64 = numpy.linalg.lapack_lite._ilp64

53 

54 

def import_nose():
    """
    Import the ``nose`` package on demand and return it.

    Returns
    -------
    module
        The imported ``nose`` module.

    Raises
    ------
    ImportError
        If nose cannot be imported or its ``__versioninfo__`` is older
        than (1, 0, 0).

    """
    required_version = (1, 0, 0)
    try:
        import nose
    except ImportError:
        usable = False
    else:
        usable = not (nose.__versioninfo__ < required_version)

    if usable:
        return nose

    # Raised outside the ``except`` clause so the error is not chained to
    # the original import failure.
    raise ImportError('Need nose >= %d.%d.%d for tests - see '
                      'https://nose.readthedocs.io' %
                      required_version)

75 

76 

def assert_(val, msg=''):
    """
    Assert that works in release mode.

    Unlike the built-in ``assert`` statement, for which no byte-code is
    generated in optimized mode (the ``-O`` flag), this function always
    performs the check.

    Parameters
    ----------
    val : object
        Value tested for truthiness.
    msg : str or callable, optional
        Failure message.  A callable is invoked (with no arguments) only
        when the assertion fails, which defers building expensive messages.

    Raises
    ------
    AssertionError
        If `val` is falsy.

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if val:
        return
    try:
        # A callable message is evaluated lazily, only on failure.
        text = msg()
    except TypeError:
        # Calling failed (e.g. a plain string): use the object itself.
        text = msg
    raise AssertionError(text)

95 

96 

def gisnan(x):
    """
    Like isnan, but always raise an error if the type is not supported
    instead of returning ``NotImplemented``.

    Notes
    -----
    isnan and other ufuncs sometimes return a NotImplementedType object
    instead of raising any exception. This function is a wrapper to make
    sure an exception is always raised.

    This should be removed once this problem is solved at the Ufunc level.

    """
    from numpy.core import isnan
    result = isnan(x)
    # NotImplemented is a singleton, so an identity test is sufficient.
    if result is NotImplemented:
        raise TypeError("isnan not supported for this type")
    return result

113 

114 

def gisfinite(x):
    """
    Like isfinite, but always raise an error if the type is not supported
    instead of returning ``NotImplemented``.

    Notes
    -----
    isfinite and other ufuncs sometimes return a NotImplementedType object
    instead of raising any exception. This function is a wrapper to make
    sure an exception is always raised.

    This should be removed once this problem is solved at the Ufunc level.

    """
    from numpy.core import isfinite, errstate
    # Floating-point "invalid" warnings are silenced during the check.
    with errstate(invalid='ignore'):
        result = isfinite(x)
        # NotImplemented is a singleton, so an identity test is sufficient.
        if result is NotImplemented:
            raise TypeError("isfinite not supported for this type")
    return result

132 

133 

def gisinf(x):
    """
    Like isinf, but always raise an error if the type is not supported
    instead of returning ``NotImplemented``.

    Notes
    -----
    isinf and other ufuncs sometimes return a NotImplementedType object
    instead of raising any exception. This function is a wrapper to make
    sure an exception is always raised.

    This should be removed once this problem is solved at the Ufunc level.

    """
    from numpy.core import isinf, errstate
    # Floating-point "invalid" warnings are silenced during the check.
    with errstate(invalid='ignore'):
        result = isinf(x)
        # NotImplemented is a singleton, so an identity test is sufficient.
        if result is NotImplemented:
            raise TypeError("isinf not supported for this type")
    return result

151 

152 

if os.name == 'nt':
    # Code "stolen" from enthought/debug/memusage.py
    def GetPerformanceAttributes(object, counter, instance=None,
                                 inum=-1, format=None, machine=None):
        """
        Read one formatted value from a Windows performance counter.

        A counter path is built from the arguments, added to a freshly
        opened query, sampled once, and the formatted value is returned.
        The counter and the query are released even if sampling fails.
        When `format` is None, ``win32pdh.PDH_FMT_LONG`` is used.
        """
        # NOTE: Many counters require 2 samples to give accurate results,
        # including "% Processor Time" (as by definition, at any instant, a
        # thread's CPU usage is either 0 or 100).  To read counters like this,
        # you should copy this function, but keep the counter open, and call
        # CollectQueryData() each time you need to know.
        # See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp (dead link)
        # My older explanation for this was that the "AddCounter" process forced
        # the CPU to 100%, but the above makes more sense :)
        import win32pdh
        if format is None:
            format = win32pdh.PDH_FMT_LONG
        path = win32pdh.MakeCounterPath(
            (machine, object, instance, None, inum, counter))
        hq = win32pdh.OpenQuery()
        try:
            hc = win32pdh.AddCounter(hq, path)
            try:
                win32pdh.CollectQueryData(hq)
                # The first element of the result is not needed here.
                _, val = win32pdh.GetFormattedCounterValue(hc, format)
                return val
            finally:
                win32pdh.RemoveCounter(hc)
        finally:
            win32pdh.CloseQuery(hq)

    def memusage(processName="python", instance=0):
        """
        Return the "Virtual Bytes" performance counter of a process.

        The counter is read for the process named `processName`
        (instance number `instance`).
        """
        # from win32pdhutil, part of the win32all package
        import win32pdh
        return GetPerformanceAttributes("Process", "Virtual Bytes",
                                        processName, instance,
                                        win32pdh.PDH_FMT_LONG, None)
elif sys.platform[:5] == 'linux':

    def memusage(_proc_pid_stat='/proc/%s/stat' % (os.getpid())):
        """
        Return virtual memory size in bytes of the running python.

        The value is field 23 (``vsize``, see man 5 proc) of the stat file.
        ``None`` is returned if the file cannot be read or parsed.

        """
        try:
            with open(_proc_pid_stat, 'r') as f:
                line = f.readline()
            # Field 2 is the command name in parentheses and may itself
            # contain spaces, so the fixed-position fields are counted from
            # the last closing parenthesis rather than from the line start.
            _, closing_paren, tail = line.rpartition(')')
            if closing_paren:
                # tail starts at field 3, so field 23 sits at index 20.
                return int(tail.split()[20])
            # No parenthesised name: fall back to plain positional parsing.
            return int(line.split(' ')[22])
        except Exception:
            # Best effort: any failure is reported as "unknown".
            return None
else:
    def memusage():
        """
        Return memory usage of running python. [Not implemented]

        """
        raise NotImplementedError

207 

208 

if sys.platform[:5] == 'linux':
    def jiffies(_proc_pid_stat='/proc/%s/stat' % (os.getpid()),
                _load_time=[]):
        """
        Return number of jiffies elapsed.

        Return number of jiffies (1/100ths of a second) that this
        process has been scheduled in user mode. See man 5 proc.

        If the stat file cannot be read or parsed, the wall-clock time
        since the first call (in 1/100ths of a second) is returned instead.

        """
        import time
        # The mutable default is deliberate: it remembers the time of the
        # first call across invocations for the fallback below.
        if not _load_time:
            _load_time.append(time.time())
        try:
            with open(_proc_pid_stat, 'r') as f:
                line = f.readline()
            # Field 2 is the command name in parentheses and may itself
            # contain spaces, so the fixed-position fields are counted from
            # the last closing parenthesis rather than from the line start.
            _, closing_paren, tail = line.rpartition(')')
            if closing_paren:
                # tail starts at field 3, so field 14 (utime) is index 11.
                return int(tail.split()[11])
            # No parenthesised name: fall back to plain positional parsing.
            return int(line.split(' ')[13])
        except Exception:
            return int(100*(time.time()-_load_time[0]))
else:
    # os.getpid is not in all platforms available.
    # Using time is safe but inaccurate, especially when process
    # was suspended or sleeping.
    def jiffies(_load_time=[]):
        """
        Return number of jiffies elapsed.

        On this platform the value is the wall-clock time, in 1/100ths of
        a second, elapsed since the first call of this function.

        """
        import time
        # The mutable default is deliberate: it remembers the time of the
        # first call across invocations.
        if not _load_time:
            _load_time.append(time.time())
        return int(100*(time.time()-_load_time[0]))

244 

245 

def build_err_msg(arrays, err_msg, header='Items are not equal:',
                  verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
    """
    Assemble the multi-line message used by the assertion helpers.

    Parameters
    ----------
    arrays : sequence
        Objects whose representations are listed when `verbose` is true.
    err_msg : str
        Extra text.  A short single-line text is put on the header line,
        anything else goes on its own line.
    header : str, optional
        First line of the message.
    verbose : bool, optional
        If True, append one `` NAME: repr`` line per entry of `arrays`.
    names : sequence of str, optional
        Labels used for the entries of `arrays`, by position.
    precision : int, optional
        Precision passed to ``array_repr`` for ndarray entries.

    Returns
    -------
    str
        The message, starting with a newline.

    """
    lines = ['\n' + header]
    if err_msg:
        single_line = err_msg.find('\n') < 0
        if single_line and len(err_msg) < 79-len(header):
            lines = [lines[0] + ' ' + err_msg]
        else:
            lines.append(err_msg)
    if not verbose:
        return '\n'.join(lines)

    for position, item in enumerate(arrays):
        # precision argument is only needed if the objects are ndarrays
        if isinstance(item, ndarray):
            formatter = partial(array_repr, precision=precision)
        else:
            formatter = repr

        try:
            text = formatter(item)
        except Exception as exc:
            text = '[repr failed for <{}>: {}]'.format(type(item).__name__, exc)
        # Long representations are cut to their first three lines.
        if text.count('\n') > 3:
            text = '\n'.join(text.splitlines()[:3]) + '...'
        lines.append(' %s: %s' % (names[position], text))
    return '\n'.join(lines)

272 

273 

def assert_equal(actual, desired, err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal.

    Given two objects (scalars, lists, tuples, dictionaries or numpy arrays),
    check that all elements of these objects are equal. An exception is raised
    at the first conflicting values.

    When one of `actual` and `desired` is a scalar and the other is array_like,
    the function checks that each element of the array_like object is equal to
    the scalar.

    This function handles NaN comparisons as if NaN was a "normal" number.
    That is, AssertionError is not raised if both objects have NaNs in the same
    positions.  This is in contrast to the IEEE standard on NaNs, which says
    that NaN compared to anything must return False.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal.

    Examples
    --------
    >>> np.testing.assert_equal([4,5], [4,6])
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal:
    item=1
     ACTUAL: 5
     DESIRED: 6

    The following comparison does not raise an exception.  There are NaNs
    in the inputs, but they are in the same positions.

    >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan])

    """
    __tracebackhide__ = True  # Hide traceback for py.test

    # Dictionaries: same number of entries, then a recursive per-key check.
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for key, _ in desired.items():
            if key not in actual:
                raise AssertionError(repr(key))
            assert_equal(actual[key], desired[key],
                         'key=%r\n%s' % (key, err_msg), verbose)
        return

    # Lists/tuples: same length, then a recursive per-position check.
    sequence_types = (list, tuple)
    if (isinstance(desired, sequence_types)
            and isinstance(actual, sequence_types)):
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for position in range(len(desired)):
            assert_equal(actual[position], desired[position],
                         'item=%r\n%s' % (position, err_msg), verbose)
        return

    from numpy.core import ndarray, isscalar, signbit
    from numpy.lib import iscomplexobj, real, imag
    if isinstance(actual, ndarray) or isinstance(desired, ndarray):
        return assert_array_equal(actual, desired, err_msg, verbose)
    msg = build_err_msg([actual, desired], err_msg, verbose=verbose)

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except (ValueError, TypeError):
        usecomplex = False

    if usecomplex:
        if iscomplexobj(actual):
            actualr, actuali = real(actual), imag(actual)
        else:
            actualr, actuali = actual, 0
        if iscomplexobj(desired):
            desiredr, desiredi = real(desired), imag(desired)
        else:
            desiredr, desiredi = desired, 0
        try:
            assert_equal(actualr, desiredr)
            assert_equal(actuali, desiredi)
        except AssertionError:
            raise AssertionError(msg)

    # isscalar test to check cases such as [np.nan] != np.nan
    if isscalar(desired) != isscalar(actual):
        raise AssertionError(msg)

    try:
        isdesnat = isnat(desired)
        isactnat = isnat(actual)
        dtypes_match = array(desired).dtype.type == array(actual).dtype.type
        if isdesnat and isactnat:
            # If both are NaT (and have the same dtype -- datetime or
            # timedelta) they are considered equal.
            if not dtypes_match:
                raise AssertionError(msg)
            return
    except (TypeError, ValueError, NotImplementedError):
        pass

    # Inf/nan/negative zero handling
    try:
        isdesnan = gisnan(desired)
        isactnan = gisnan(actual)
        if isdesnan and isactnan:
            return  # both nan, so equal

        # handle signed zero specially for floats
        array_actual = array(actual)
        array_desired = array(desired)
        involves_time = (array_actual.dtype.char in 'Mm' or
                         array_desired.dtype.char in 'Mm')
        if involves_time:
            # version 1.18
            # until this version, gisnan failed for datetime64 and timedelta64.
            # Now it succeeds but comparison to scalar with a different type
            # emits a DeprecationWarning.
            # Avoid that by skipping the next check
            raise NotImplementedError('cannot compare to a scalar '
                                      'with a different type')

        if (desired == 0 and actual == 0
                and not signbit(desired) == signbit(actual)):
            raise AssertionError(msg)

    except (TypeError, ValueError, NotImplementedError):
        pass

    try:
        # Explicitly use __eq__ for comparison, gh-2552
        if not (desired == actual):
            raise AssertionError(msg)

    except (DeprecationWarning, FutureWarning) as e:
        # this handles the case when the two types are not even comparable
        if 'elementwise == comparison' not in e.args[0]:
            raise
        raise AssertionError(msg)

429 

430 

def print_assert_equal(test_string, actual, desired):
    """
    Test if two objects are equal, and print an error message if test fails.

    The test is performed with ``actual == desired``.

    Parameters
    ----------
    test_string : str
        The message supplied to AssertionError.
    actual : object
        The object to test for equality against `desired`.
    desired : object
        The expected result.

    Raises
    ------
    AssertionError
        If ``actual == desired`` is false.  The message holds
        `test_string` followed by pretty-printed forms of both objects.

    Examples
    --------
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1])
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2])
    Traceback (most recent call last):
    ...
    AssertionError: Test XYZ of func xyz failed
    ACTUAL:
    [0, 1]
    DESIRED:
    [0, 2]

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import pprint

    if actual == desired:
        return

    report = StringIO()
    report.write(test_string)
    report.write(' failed\nACTUAL: \n')
    # pformat plus a newline is what pprint.pprint writes to a stream.
    report.write(pprint.pformat(actual) + '\n')
    report.write('DESIRED: \n')
    report.write(pprint.pformat(desired) + '\n')
    raise AssertionError(report.getvalue())

470 

471 

def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
    """
    Raises an AssertionError if two items are not equal up to desired
    precision.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    The test verifies that the elements of ``actual`` and ``desired`` satisfy.

        ``abs(desired-actual) < 1.5 * 10**(-decimal)``

    That is a looser test than originally documented, but agrees with what the
    actual implementation in `assert_array_almost_equal` did up to rounding
    vagaries. An exception is raised at conflicting values. For ndarrays,
    lists and tuples this delegates to assert_array_almost_equal.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    decimal : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> import numpy.testing as npt
    >>> npt.assert_almost_equal(2.3333333333333, 2.33333334)
    >>> npt.assert_almost_equal(2.3333333333333, 2.33333334, decimal=10)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 10 decimals
     ACTUAL: 2.3333333333333
     DESIRED: 2.33333334

    >>> npt.assert_almost_equal(np.array([1.0,2.3333333333333]),
    ...                         np.array([1.0,2.33333334]), decimal=9)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 9 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 6.66669964e-09
    Max relative difference: 2.85715698e-09
     x: array([1.         , 2.333333333])
     y: array([1.        , 2.33333334])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import ndarray
    from numpy.lib import iscomplexobj, real, imag

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except ValueError:
        usecomplex = False

    def _build_err_msg():
        # Built lazily: the message is only needed when a check fails.
        header = ('Arrays are not almost equal to %d decimals' % decimal)
        return build_err_msg([actual, desired], err_msg, verbose=verbose,
                             header=header)

    if usecomplex:
        if iscomplexobj(actual):
            actualr = real(actual)
            actuali = imag(actual)
        else:
            actualr = actual
            actuali = 0
        if iscomplexobj(desired):
            desiredr = real(desired)
            desiredi = imag(desired)
        else:
            desiredr = desired
            desiredi = 0
        try:
            assert_almost_equal(actualr, desiredr, decimal=decimal)
            assert_almost_equal(actuali, desiredi, decimal=decimal)
        except AssertionError:
            # Report the original complex values, not the split parts.
            raise AssertionError(_build_err_msg())

    if isinstance(actual, (ndarray, tuple, list)) \
            or isinstance(desired, (ndarray, tuple, list)):
        # Forward ``verbose`` as well, so that verbose=False is honoured
        # for array, list and tuple inputs too.
        return assert_array_almost_equal(actual, desired, decimal, err_msg,
                                         verbose)
    try:
        # If one of desired/actual is not finite, handle it specially here:
        # check that both are nan if any is a nan, and test for equality
        # otherwise
        if not (gisfinite(desired) and gisfinite(actual)):
            if gisnan(desired) or gisnan(actual):
                if not (gisnan(desired) and gisnan(actual)):
                    raise AssertionError(_build_err_msg())
            else:
                if not desired == actual:
                    raise AssertionError(_build_err_msg())
            return
    except (NotImplementedError, TypeError):
        pass
    if abs(desired - actual) >= 1.5 * 10.0**(-decimal):
        raise AssertionError(_build_err_msg())

596 

597 

def assert_approx_equal(actual, desired, significant=7, err_msg='',
                        verbose=True):
    """
    Raises an AssertionError if two items are not equal up to significant
    digits.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    Given two numbers, check that they are approximately equal.
    Approximately equal is defined as the number of significant digits
    that agree.

    Parameters
    ----------
    actual : scalar
        The object to check.
    desired : scalar
        The expected object.
    significant : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20,
    ...                                significant=8)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20,
    ...                                significant=8)
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal to 8 significant digits:
     ACTUAL: 1.234567e-21
     DESIRED: 1.2345672e-21

    the evaluated condition that raises the exception is

    >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1)
    True

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    # Both inputs are compared as Python floats.
    actual, desired = float(actual), float(desired)
    if desired == actual:
        return
    # Normalized the numbers to be in range (-10.0,10.0)
    # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual))))))
    with np.errstate(invalid='ignore'):
        mean_magnitude = 0.5*(np.abs(desired) + np.abs(actual))
        scale = np.power(10, np.floor(np.log10(mean_magnitude)))
    try:
        sc_desired = desired/scale
    except ZeroDivisionError:
        sc_desired = 0.0
    try:
        sc_actual = actual/scale
    except ZeroDivisionError:
        sc_actual = 0.0
    msg = build_err_msg(
        [actual, desired], err_msg,
        header='Items are not equal to %d significant digits:' % significant,
        verbose=verbose)
    try:
        # If one of desired/actual is not finite, handle it specially here:
        # check that both are nan if any is a nan, and test for equality
        # otherwise
        if not (gisfinite(desired) and gisfinite(actual)):
            if gisnan(desired) or gisnan(actual):
                mismatch = not (gisnan(desired) and gisnan(actual))
            else:
                mismatch = not desired == actual
            if mismatch:
                raise AssertionError(msg)
            return
    except (TypeError, NotImplementedError):
        pass
    tolerance = np.power(10., -(significant-1))
    if np.abs(sc_desired - sc_actual) >= tolerance:
        raise AssertionError(msg)

695 

696 

def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
                         header='', precision=6, equal_nan=True,
                         equal_inf=True):
    """
    Compare two array_like objects with `comparison` and raise on mismatch.

    Both inputs are converted to arrays (subclasses are kept).  Unless one
    of them is zero-dimensional, their shapes must be equal.  Positions
    holding NaN (or NaT for matching datetime/timedelta types) and
    +/-inf must coincide in both inputs; those positions are then left out
    of the comparison.

    Parameters
    ----------
    comparison : callable
        Called as ``comparison(x, y)``; returns a bool or a boolean array.
    x, y : array_like
        The objects to compare.
    err_msg : str, optional
        Extra text for the failure message.
    verbose : bool, optional
        If True, the compared values are appended to the failure message.
    header : str, optional
        First line of the failure message.
    precision : int, optional
        Precision used when formatting arrays in the failure message.
    equal_nan : bool, optional
        If True, NaN/NaT at the same positions are treated as equal.
    equal_inf : bool, optional
        If True, infinities of the same sign at the same positions are
        treated as equal.

    Raises
    ------
    AssertionError
        On shape mismatch, nan/inf location mismatch, or when `comparison`
        is not true everywhere.
    ValueError
        If a ValueError occurs while comparing; it is re-raised with the
        formatted traceback and the compared values in its message.

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import (
        array, array2string, asanyarray, isnan, inf, bool_, errstate, all,
        max, object_)

    # asanyarray converts without copying when possible and keeps
    # subclasses; unlike array(..., copy=False) it never refuses to copy.
    x = asanyarray(x)
    y = asanyarray(y)

    # original array for output formatting
    ox, oy = x, y

    def isnumber(x):
        return x.dtype.char in '?bhilqpBHILQPefdgFDG'

    def istime(x):
        return x.dtype.char in "Mm"

    def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
        """Handling nan/inf.

        Combine results of running func on x and y, checking that they are True
        at the same locations.

        """
        x_id = func(x)
        y_id = func(y)
        # We include work-arounds here to handle three types of slightly
        # pathological ndarray subclasses:
        # (1) all() on `masked` array scalars can return masked arrays, so we
        #     use != True
        # (2) __eq__ on some ndarray subclasses returns Python booleans
        #     instead of element-wise comparisons, so we cast to bool_() and
        #     use isinstance(..., bool) checks
        # (3) subclasses with bare-bones __array_function__ implementations may
        #     not implement np.all(), so favor using the .all() method
        # We are not committed to supporting such subclasses, but it's nice to
        # support them if possible.
        if bool_(x_id == y_id).all() != True:
            msg = build_err_msg([x, y],
                                err_msg + '\nx and y %s location mismatch:'
                                % (hasval), verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)
        # If there is a scalar, then here we know the array has the same
        # flag as it everywhere, so we should return the scalar flag.
        if isinstance(x_id, bool) or x_id.ndim == 0:
            return bool_(x_id)
        elif isinstance(y_id, bool) or y_id.ndim == 0:
            # Test y_id itself here: it may be a plain Python bool, which
            # has no ``ndim`` attribute.
            return bool_(y_id)
        else:
            return y_id

    try:
        cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
        if not cond:
            msg = build_err_msg([x, y],
                                err_msg
                                + '\n(shapes %s, %s mismatch)' % (x.shape,
                                                                  y.shape),
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)

        flagged = bool_(False)
        if isnumber(x) and isnumber(y):
            if equal_nan:
                flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')

            if equal_inf:
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == +inf,
                                                hasval='+inf')
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == -inf,
                                                hasval='-inf')

        elif istime(x) and istime(y):
            # If one is datetime64 and the other timedelta64 there is no point
            if equal_nan and x.dtype.type == y.dtype.type:
                flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT")

        if flagged.ndim > 0:
            x, y = x[~flagged], y[~flagged]
            # Only do the comparison if actual values are left
            if x.size == 0:
                return
        elif flagged:
            # no sense doing comparison if everything is flagged.
            return

        val = comparison(x, y)

        if isinstance(val, bool):
            cond = val
            reduced = array([val])
        else:
            reduced = val.ravel()
            cond = reduced.all()

        # The below comparison is a hack to ensure that fully masked
        # results, for which val.ravel().all() returns np.ma.masked,
        # do not trigger a failure (np.ma.masked != True evaluates as
        # np.ma.masked, which is falsy).
        if cond != True:
            n_mismatch = reduced.size - reduced.sum(dtype=intp)
            n_elements = flagged.size if flagged.ndim != 0 else reduced.size
            percent_mismatch = 100 * n_mismatch / n_elements
            remarks = [
                'Mismatched elements: {} / {} ({:.3g}%)'.format(
                    n_mismatch, n_elements, percent_mismatch)]

            with errstate(invalid='ignore', divide='ignore'):
                # ignore errors for non-numeric types
                with contextlib.suppress(TypeError):
                    error = abs(x - y)
                    max_abs_error = max(error)
                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append('Max absolute difference: '
                                       + str(max_abs_error))
                    else:
                        remarks.append('Max absolute difference: '
                                       + array2string(max_abs_error))

                    # note: this definition of relative error matches that one
                    # used by assert_allclose (found in np.isclose)
                    # Filter values where the divisor would be zero
                    nonzero = bool_(y != 0)
                    if all(~nonzero):
                        max_rel_error = array(inf)
                    else:
                        max_rel_error = max(error[nonzero] / abs(y[nonzero]))
                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append('Max relative difference: '
                                       + str(max_rel_error))
                    else:
                        remarks.append('Max relative difference: '
                                       + array2string(max_rel_error))

            err_msg += '\n' + '\n'.join(remarks)
            msg = build_err_msg([ox, oy], err_msg,
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)
    except ValueError:
        import traceback
        efmt = traceback.format_exc()
        header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header)

        msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
                            names=('x', 'y'), precision=precision)
        raise ValueError(msg)

849 

850 

def assert_array_equal(x, y, err_msg='', verbose=True):
    """
    Raises an AssertionError if two array_like objects are not equal.

    Given two array_like objects, check that the shape is equal and all
    elements of these objects are equal (but see the Notes for the special
    handling of a scalar). An exception is raised at shape mismatch or
    conflicting values. In contrast to the standard usage in numpy, NaNs
    are compared like numbers, no assertion is raised if both objects have
    NaNs in the same positions.

    The usual caution for verifying equality with floating point numbers is
    advised.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired objects are not equal.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Notes
    -----
    When one of `x` and `y` is a scalar and the other is array_like, the
    function checks that each element of the array_like object is equal to
    the scalar.

    Examples
    --------
    The first assert does not raise an exception:

    >>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
    ...                               [np.exp(0),2.33333, np.nan])

    Assert fails with numerical imprecision with floats:

    >>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
    ...                               [1, np.sqrt(np.pi)**2, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 4.4408921e-16
    Max relative difference: 1.41357986e-16
     x: array([1.      , 3.141593,      nan])
     y: array([1.      , 3.141593,      nan])

    Use `assert_allclose` or one of the nulp (number of floating point values)
    functions for these cases instead:

    >>> np.testing.assert_allclose([1.0,np.pi,np.nan],
    ...                            [1, np.sqrt(np.pi)**2, np.nan],
    ...                            rtol=1e-10, atol=0)

    As mentioned in the Notes section, `assert_array_equal` has special
    handling for scalars. Here the test checks that each value in `x` is 3:

    >>> x = np.full((2, 5), fill_value=3)
    >>> np.testing.assert_array_equal(x, 3)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Element-wise ``==`` is the comparison; everything else is delegated.
    assert_array_compare(
        operator.eq, x, y,
        err_msg=err_msg,
        verbose=verbose,
        header='Arrays are not equal',
    )

932 

933 

def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal up to desired
    precision.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    The test verifies identical shapes and that the elements of ``actual``
    and ``desired`` satisfy

        ``abs(desired-actual) < 1.5 * 10**(-decimal)``

    That is a looser test than originally documented, but agrees with what
    the actual implementation did up to rounding vagaries. An exception is
    raised at shape mismatch or conflicting values. In contrast to the
    standard usage in numpy, NaNs are compared like numbers, no assertion is
    raised if both objects have NaNs in the same positions.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    decimal : int, optional
        Desired precision, default is 6.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    the first assert does not raise an exception

    >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan],
    ...                                      [1.0,2.333,np.nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33339,np.nan], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 6.e-05
    Max relative difference: 2.57136612e-05
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33339,     nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33333, 5], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    x and y nan location mismatch:
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33333, 5.     ])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import number, float_, result_type, array
    from numpy.core.numerictypes import issubdtype
    from numpy.core.fromnumeric import any as npany

    def compare(x, y):
        # Infinities need special care: they must sit at the same positions
        # in both operands and are then excluded from the numeric comparison.
        try:
            if npany(gisinf(x)) or npany(gisinf(y)):
                x_inf_mask = gisinf(x)
                y_inf_mask = gisinf(y)
                if not (x_inf_mask == y_inf_mask).all():
                    return False
                # if one item, x and y is +- inf
                if x.size == y.size == 1:
                    return x == y
                x = x[~x_inf_mask]
                y = y[~y_inf_mask]
        except (TypeError, NotImplementedError):
            # operands for which the infinity test is not defined are simply
            # compared numerically below
            pass

        # make sure y is an inexact type to avoid abs(MIN_INT); will cause
        # casting of x later.
        inexact_dtype = result_type(y, 1.)
        y = array(y, dtype=inexact_dtype, copy=False, subok=True)
        distance = abs(x - y)

        if not issubdtype(distance.dtype, number):
            distance = distance.astype(float_)  # handle object arrays

        return distance < 1.5 * 10.0**(-decimal)

    header = 'Arrays are not almost equal to %d decimals' % decimal
    assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
                         header=header, precision=decimal)

1045 

1046 

def assert_array_less(x, y, err_msg='', verbose=True):
    """
    Raises an AssertionError if two array_like objects are not ordered by less
    than.

    Given two array_like objects, check that the shape is equal and all
    elements of the first object are strictly smaller than those of the
    second object. An exception is raised at shape mismatch or incorrectly
    ordered values. Shape mismatch does not raise if an object has zero
    dimension. In contrast to the standard usage in numpy, NaNs are
    compared, no assertion is raised if both objects have NaNs in the same
    positions.

    Parameters
    ----------
    x : array_like
      The smaller object to check.
    y : array_like
      The larger object to compare.
    err_msg : string
      The error message to be printed in case of failure.
    verbose : bool
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If the elements of `x` are not all strictly smaller than those of `y`.

    See Also
    --------
    assert_array_equal: tests objects for equality
    assert_array_almost_equal: test objects for equality up to precision

    Examples
    --------
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan])
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 1.
    Max relative difference: 0.5
     x: array([ 1.,  1., nan])
     y: array([ 1.,  2., nan])

    >>> np.testing.assert_array_less([1.0, 4.0], 3)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 2.
    Max relative difference: 0.66666667
     x: array([1., 4.])
     y: array(3)

    >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    (shapes (3,), (1,) mismatch)
     x: array([1., 2., 3.])
     y: array([4])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # equal_inf=False is forwarded as-is to assert_array_compare.
    assert_array_compare(
        operator.lt, x, y,
        err_msg=err_msg,
        verbose=verbose,
        header='Arrays are not less-ordered',
        equal_inf=False,
    )

1128 

1129 

def runstring(astr, dict):
    """Execute the source string `astr` using `dict` as its namespace."""
    namespace = dict
    exec(astr, namespace)

1132 

1133 

def assert_string_equal(actual, desired):
    """
    Test if two strings are equal.

    If the given strings are equal, `assert_string_equal` does nothing.
    If they are not equal, an AssertionError is raised, and the diff
    between the strings is shown.

    Parameters
    ----------
    actual : str
        The string to test for equality against the expected string.
    desired : str
        The expected string.

    Raises
    ------
    AssertionError
        If either argument is not a ``str`` (the message is the repr of the
        offending type), or if the two strings differ.

    Examples
    --------
    >>> np.testing.assert_string_equal('abc', 'abc')
    >>> np.testing.assert_string_equal('abc', 'abcd')
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    ...
    AssertionError: Differences in strings:
    - abc+ abcd?    +

    """
    # delay import of difflib to reduce startup time
    __tracebackhide__ = True  # Hide traceback for py.test
    import difflib
    from collections import deque

    if not isinstance(actual, str):
        raise AssertionError(repr(type(actual)))
    if not isinstance(desired, str):
        raise AssertionError(repr(type(desired)))
    if desired == actual:
        return

    # A deque gives O(1) pops from the front while walking the diff.
    pending = deque(difflib.Differ().compare(actual.splitlines(True),
                                             desired.splitlines(True)))
    diff_list = []
    while pending:
        d1 = pending.popleft()
        if d1.startswith('  '):
            # line common to both strings
            continue
        if d1.startswith('- '):
            chunk = [d1]
            d2 = pending.popleft() if pending else None
            if d2 is not None and d2.startswith('? '):
                chunk.append(d2)
                d2 = pending.popleft() if pending else None
            if d2 is None:
                # A removed line at the very end of the diff has nothing
                # following it; report it instead of failing on an empty
                # queue.
                diff_list.extend(chunk)
                break
            if not d2.startswith('+ '):
                raise AssertionError(repr(d2))
            chunk.append(d2)
            # an optional '? ' hint line may follow the '+ ' line
            if pending and pending[0].startswith('? '):
                chunk.append(pending.popleft())
            if d2[2:] == d1[2:]:
                continue
            diff_list.extend(chunk)
            continue
        raise AssertionError(repr(d1))
    if not diff_list:
        return
    msg = 'Differences in strings:\n%s' % (''.join(diff_list)).rstrip()
    # the equal case returned early above, so reaching here means a mismatch
    raise AssertionError(msg)

1202 

1203 

def rundocs(filename=None, raise_on_error=True):
    """
    Run doctests found in the given file.

    By default `rundocs` raises an AssertionError on failure.

    Parameters
    ----------
    filename : str
        The path to the file for which the doctests are run. When omitted,
        the ``__file__`` of the calling module is used.
    raise_on_error : bool
        Whether to raise an AssertionError when a doctest fails. Default is
        True.

    Notes
    -----
    The doctests can be run by the user/developer by adding the ``doctests``
    argument to the ``test()`` call. For example, to run all tests (including
    doctests) for `numpy.lib`:

    >>> np.lib.test(doctests=True)  # doctest: +SKIP
    """
    from numpy.compat import npy_load_module
    import doctest

    if filename is None:
        # fall back to the file of the module that called us
        filename = sys._getframe(1).f_globals['__file__']
    module_name = os.path.splitext(os.path.basename(filename))[0]
    module = npy_load_module(module_name, filename)

    found_tests = doctest.DocTestFinder().find(module)
    runner = doctest.DocTestRunner(verbose=False)

    # Failure reports are collected only when they will be raised later;
    # otherwise the runner's default output is used.
    msg = []
    out = msg.append if raise_on_error else None

    for doc_test in found_tests:
        runner.run(doc_test, out=out)

    if raise_on_error and runner.failures > 0:
        raise AssertionError("Some doctests failed:\n%s" % "\n".join(msg))

1248 

1249 

def raises(*args):
    """Decorator to check for raised exceptions.

    The decorated test function must raise one of the passed exceptions to
    pass.  If you want to test many assertions about exceptions in a single
    test, you may want to use `assert_raises` instead.

    .. warning::
       This decorator is nose specific, do not use it if you are using a
       different test framework.

    Parameters
    ----------
    args : exceptions
        The test passes if any of the passed exceptions is raised.

    Raises
    ------
    AssertionError

    Examples
    --------

    Usage::

        @raises(TypeError, ValueError)
        def test_raises_type_error():
            raise TypeError("This test passes")

        @raises(Exception)
        def test_that_fails_by_passing():
            pass

    """
    # nose is imported on demand through the module's helper
    return import_nose().tools.raises(*args)

1286 

1287# 

1288# assert_raises and assert_raises_regex are taken from unittest. 

1289# 

1290import unittest 

1291 

1292 

class _Dummy(unittest.TestCase):
    """Minimal TestCase used only to borrow unittest's assertion methods."""

    def nop(self):
        """Do nothing; exists so the TestCase can be instantiated by name."""
        return None


# Shared instance whose assertRaises/assertRaisesRegex are reused below.
_d = _Dummy('nop')

1298 

def assert_raises(*args, **kwargs):
    """
    assert_raises(exception_class, callable, *args, **kwargs)
    assert_raises(exception_class)

    Fail unless an exception of class exception_class is thrown
    by callable when invoked with arguments args and keyword
    arguments kwargs. If a different type of exception is
    thrown, it will not be caught, and the test case will be
    deemed to have suffered an error, exactly as for an
    unexpected exception.

    Alternatively, `assert_raises` can be used as a context manager:

    >>> from numpy.testing import assert_raises
    >>> with assert_raises(ZeroDivisionError):
    ...     1 / 0

    is equivalent to

    >>> def div(x, y):
    ...     return x / y
    >>> assert_raises(ZeroDivisionError, div, 1, 0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Delegate to unittest's implementation on the shared dummy TestCase.
    checker = _d.assertRaises
    return checker(*args, **kwargs)

1326 

1327 

def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs):
    """
    assert_raises_regex(exception_class, expected_regexp, callable, *args,
                        **kwargs)
    assert_raises_regex(exception_class, expected_regexp)

    Fail unless an exception of class exception_class and with message that
    matches expected_regexp is thrown by callable when invoked with arguments
    args and keyword arguments kwargs.

    Alternatively, can be used as a context manager like `assert_raises`.

    Name of this function adheres to Python 3.2+ reference, but should work in
    all versions down to 2.6.

    Notes
    -----
    .. versionadded:: 1.9.0

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Delegate to unittest's implementation on the shared dummy TestCase.
    checker = _d.assertRaisesRegex
    return checker(exception_class, expected_regexp, *args, **kwargs)

1350 

1351 

1352def decorate_methods(cls, decorator, testmatch=None): 

1353 """ 

1354 Apply a decorator to all methods in a class matching a regular expression. 

1355 

1356 The given decorator is applied to all public methods of `cls` that are 

1357 matched by the regular expression `testmatch` 

1358 (``testmatch.search(methodname)``). Methods that are private, i.e. start 

1359 with an underscore, are ignored. 

1360 

1361 Parameters 

1362 ---------- 

1363 cls : class 

1364 Class whose methods to decorate. 

1365 decorator : function 

1366 Decorator to apply to methods 

1367 testmatch : compiled regexp or str, optional 

1368 The regular expression. Default value is None, in which case the 

1369 nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``) 

1370 is used. 

1371 If `testmatch` is a string, it is compiled to a regular expression 

1372 first. 

1373 

1374 """ 

1375 if testmatch is None: 

1376 testmatch = re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep) 

1377 else: 

1378 testmatch = re.compile(testmatch) 

1379 cls_attr = cls.__dict__ 

1380 

1381 # delayed import to reduce startup time 

1382 from inspect import isfunction 

1383 

1384 methods = [_m for _m in cls_attr.values() if isfunction(_m)] 

1385 for function in methods: 

1386 try: 

1387 if hasattr(function, 'compat_func_name'): 

1388 funcname = function.compat_func_name 

1389 else: 

1390 funcname = function.__name__ 

1391 except AttributeError: 

1392 # not a function 

1393 continue 

1394 if testmatch.search(funcname) and not funcname.startswith('_'): 

1395 setattr(cls, funcname, decorator(function)) 

1396 return 

1397 

1398 

def measure(code_str, times=1, label=None):
    """
    Return elapsed time for executing code in the namespace of the caller.

    The supplied code string is compiled with the Python builtin ``compile``.
    The precision of the timing is 10 milli-seconds. If the code will execute
    fast on this timescale, it can be executed many times to get reasonable
    timing accuracy.

    Parameters
    ----------
    code_str : str
        The code to be timed.
    times : int, optional
        The number of times the code is executed. Default is 1. The code is
        only compiled once.
    label : str, optional
        A label to identify `code_str` with. This is passed into ``compile``
        as the second argument (for run-time error messages).

    Returns
    -------
    elapsed : float
        Total elapsed time in seconds for executing `code_str` `times` times.

    Examples
    --------
    >>> times = 10
    >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)', times=times)
    >>> print("Time for a single execution : ", etime / times, "s")  # doctest: +SKIP
    Time for a single execution :  0.005 s

    """
    # The code runs in the caller's namespace, one frame up the stack.
    caller = sys._getframe(1)
    caller_locals, caller_globals = caller.f_locals, caller.f_globals

    compiled = compile(code_str, 'Test name: %s ' % label, 'exec')

    runs_done = 0
    started = jiffies()
    while runs_done < times:
        exec(compiled, caller_globals, caller_locals)
        runs_done += 1
    ticks = jiffies() - started
    # the jiffies difference is scaled by 0.01 to give seconds
    return 0.01*ticks

1445 

1446 

def _assert_valid_refcount(op):
    """
    Check that ufuncs don't mishandle refcount of object `1`.
    Used in a few regression tests.

    Returns True immediately when reference counts are unavailable
    (``HAS_REFCOUNT`` is false); otherwise returns None after the check.
    """
    if not HAS_REFCOUNT:
        return True

    import gc
    import numpy as np

    lhs = np.arange(100*100).reshape(100, 100)
    rhs = lhs
    one = 1

    # Keep the collector off while reference counts are being compared.
    gc.disable()
    try:
        baseline = sys.getrefcount(one)
        for _ in range(15):
            d = op(lhs, rhs)
        assert_(sys.getrefcount(one) >= baseline)
    finally:
        gc.enable()
    del d  # for pyflakes

1471 

1472 

def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
                    err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal up to desired
    tolerance.

    The test is equivalent to ``allclose(actual, desired, rtol, atol)`` (note
    that ``allclose`` has different default values). It compares the difference
    between `actual` and `desired` to ``atol + rtol * abs(desired)``.

    .. versionadded:: 1.5.0

    Parameters
    ----------
    actual : array_like
        Array obtained.
    desired : array_like
        Array desired.
    rtol : float, optional
        Relative tolerance.
    atol : float, optional
        Absolute tolerance.
    equal_nan : bool, optional.
        If True, NaNs will compare equal.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_array_almost_equal_nulp, assert_array_max_ulp

    Examples
    --------
    >>> x = [1e-5, 1e-3, 1e-1]
    >>> y = np.arccos(np.cos(x))
    >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    def _within_tolerance(a, b):
        # element-wise closeness test using the tolerances of this call
        return np.core.numeric.isclose(a, b, rtol=rtol, atol=atol,
                                       equal_nan=equal_nan)

    actual = np.asanyarray(actual)
    desired = np.asanyarray(desired)
    header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
    assert_array_compare(_within_tolerance, actual, desired,
                         err_msg=str(err_msg), verbose=verbose,
                         header=header, equal_nan=equal_nan)

1529 

1530 

def assert_array_almost_equal_nulp(x, y, nulp=1):
    """
    Compare two arrays relatively to their spacing.

    This is a relatively robust method to compare two arrays whose amplitude
    is variable.

    Parameters
    ----------
    x, y : array_like
        Input arrays. Plain sequences such as lists are accepted and are
        converted with ``np.asanyarray``.
    nulp : int, optional
        The maximum number of unit in the last place for tolerance (see Notes).
        Default is 1.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If the spacing between `x` and `y` for one or more elements is larger
        than `nulp`.

    See Also
    --------
    assert_array_max_ulp : Check that all items of arrays differ in at most
        N Units in the Last Place.
    spacing : Return the distance between x and the nearest adjacent number.

    Notes
    -----
    An assertion is raised if the following condition is not met::

        abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y)))

    Examples
    --------
    >>> x = np.array([1., 1e-10, 1e-20])
    >>> eps = np.finfo(x.dtype).eps
    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x)

    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x)
    Traceback (most recent call last):
      ...
    AssertionError: X and Y are not equal to 1 ULP (max is 2)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np
    # Convert up front so that ``x - y`` below also works for plain Python
    # sequences; asanyarray keeps ndarray subclasses intact.
    x = np.asanyarray(x)
    y = np.asanyarray(y)
    ax = np.abs(x)
    ay = np.abs(y)
    ref = nulp * np.spacing(np.where(ax > ay, ax, ay))
    if not np.all(np.abs(x-y) <= ref):
        if np.iscomplexobj(x) or np.iscomplexobj(y):
            # nulp_diff does not support complex input, so no maximum is
            # reported in that case
            msg = "X and Y are not equal to %d ULP" % nulp
        else:
            max_nulp = np.max(nulp_diff(x, y))
            msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp)
        raise AssertionError(msg)

1592 

1593 

def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
    """
    Check that all items of arrays differ in at most N Units in the Last Place.

    Parameters
    ----------
    a, b : array_like
        Input arrays to be compared.
    maxulp : int, optional
        The maximum number of units in the last place that elements of `a` and
        `b` can differ. Default is 1.
    dtype : dtype, optional
        Data-type to convert `a` and `b` to if given. Default is None.

    Returns
    -------
    ret : ndarray
        Array containing number of representable floating point numbers between
        items in `a` and `b`.

    Raises
    ------
    AssertionError
        If one or more elements differ by more than `maxulp`.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    See Also
    --------
    assert_array_almost_equal_nulp : Compare two arrays relatively to their
        spacing.

    Examples
    --------
    >>> a = np.linspace(0., 1., 100)
    >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a)))

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np
    ulp_distance = nulp_diff(a, b, dtype)
    if np.all(ulp_distance <= maxulp):
        return ulp_distance
    raise AssertionError("Arrays are not almost equal up to %g "
                         "ULP (max difference is %g ULP)" %
                         (maxulp, np.max(ulp_distance)))

1644 

1645 

def nulp_diff(x, y, dtype=None):
    """For each item in x and y, return the number of representable floating
    points between them.

    Parameters
    ----------
    x : array_like
        first input array
    y : array_like
        second input array
    dtype : dtype, optional
        Data-type to convert `x` and `y` to if given. Default is None.

    Returns
    -------
    nulp : array_like
        number of representable floating point numbers between each item in x
        and y.

    Raises
    ------
    NotImplementedError
        If either input is complex.
    ValueError
        If `x` and `y` do not have the same shape.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    Examples
    --------
    By definition, epsilon is the smallest number such as 1 + eps != 1, so
    there should be exactly one ULP between 1 and 1 + eps

    >>> nulp_diff(1.0, 1.0 + np.finfo(np.float64).eps)
    array([1.])
    """
    import numpy as np
    # Test against None explicitly: np.dtype instances without fields are
    # falsy (their len() is 0), so a plain truth test would silently ignore
    # a dtype passed as an ``np.dtype`` object.
    if dtype is not None:
        x = np.array(x, dtype=dtype)
        y = np.array(y, dtype=dtype)
    else:
        x = np.array(x)
        y = np.array(y)

    t = np.common_type(x, y)
    if np.iscomplexobj(x) or np.iscomplexobj(y):
        raise NotImplementedError("_nulp not implemented for complex array")

    # Fresh copies in the common float type; they are modified in place below.
    x = np.array([x], dtype=t)
    y = np.array([y], dtype=t)

    # Replace every NaN with the canonical np.nan so that different NaN bit
    # patterns compare as zero ULP apart.
    x[np.isnan(x)] = np.nan
    y[np.isnan(y)] = np.nan

    if not x.shape == y.shape:
        raise ValueError("x and y do not have the same shape: %s - %s" %
                         (x.shape, y.shape))

    def _diff(rx, ry, vdt):
        # absolute difference of the integer representations, as type vdt
        diff = np.array(rx-ry, dtype=vdt)
        return np.abs(diff)

    rx = integer_repr(x)
    ry = integer_repr(y)
    return _diff(rx, ry, t)

1707 

1708 

def _integer_repr(x, vdt, comp):
    """Reinterpret the bits of float array `x` as sign-magnitude integers.

    `x` is viewed as integer type `vdt`; negative values are replaced by
    ``comp - value``. For arrays whose size is not 1 the replacement is done
    in place on the view.
    """
    # Reinterpret binary representation of the float as sign-magnitude:
    # take into account two-complement representation
    # See also
    # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
    bits = x.view(vdt)
    if bits.size == 1:
        if bits < 0:
            bits = comp - bits
        return bits

    negative = bits < 0
    bits[negative] = comp - bits[negative]
    return bits

1722 

1723 

def integer_repr(x):
    """Return the signed-magnitude interpretation of the binary representation of
    x.

    Supports float16, float32 and float64 arrays; any other dtype raises
    ValueError.
    """
    import numpy as np
    # (float type, same-width integer type, most negative value of that width)
    layouts = (
        (np.float16, np.int16, -2**15),
        (np.float32, np.int32, -2**31),
        (np.float64, np.int64, -2**63),
    )
    for float_type, int_type, lowest in layouts:
        if x.dtype == float_type:
            return _integer_repr(x, int_type, int_type(lowest))
    raise ValueError("Unsupported dtype %s" % x.dtype)

1736 

1737 

@contextlib.contextmanager
def _assert_warns_context(warning_class, name=None):
    """Context manager failing if no `warning_class` warning is recorded.

    `name`, when given, is mentioned in the failure message.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    with suppress_warnings() as sup:
        recorded = sup.record(warning_class)
        yield
        if len(recorded) > 0:
            return
        name_str = "" if name is None else " when calling %s" % name
        raise AssertionError("No warning raised" + name_str)

1747 

1748 

def assert_warns(warning_class, *args, **kwargs):
    """
    Fail unless the given callable throws the specified warning.

    A warning of class warning_class should be thrown by the callable when
    invoked with arguments args and keyword arguments kwargs.
    If a different type of warning is thrown, it will not be caught.

    If called with all arguments other than the warning class omitted, may be
    used as a context manager:

        with assert_warns(SomeWarning):
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.4.0

    Parameters
    ----------
    warning_class : class
        The class defining the warning that `func` is expected to throw.
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    """
    if args:
        # first positional argument is the callable, the rest are its args
        func, *call_args = args
        with _assert_warns_context(warning_class, name=func.__name__):
            return func(*call_args, **kwargs)
    # no callable supplied: hand back the context manager itself
    return _assert_warns_context(warning_class)

1790 

1791 

@contextlib.contextmanager
def _assert_no_warnings_context(name=None):
    """Context manager failing if any warning is emitted inside it.

    `name`, when given, is mentioned in the failure message.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    with warnings.catch_warnings(record=True) as caught:
        # record every warning, regardless of previously installed filters
        warnings.simplefilter('always')
        yield
        if len(caught) > 0:
            name_str = "" if name is None else " when calling %s" % name
            raise AssertionError("Got warnings%s: %s" % (name_str, caught))

1801 

1802 

def assert_no_warnings(*args, **kwargs):
    """
    Fail if the given callable produces any warnings.

    If called with all arguments omitted, may be used as a context manager:

        with assert_no_warnings():
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.7.0

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    """
    if args:
        # first positional argument is the callable, the rest are its args
        func, *call_args = args
        with _assert_no_warnings_context(name=func.__name__):
            return func(*call_args, **kwargs)
    # no callable supplied: hand back the context manager itself
    return _assert_no_warnings_context()

1837 

1838 

def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
    """
    generator producing data with different alignment and offsets
    to test simd vectorization

    Parameters
    ----------
    dtype : dtype
        data type to produce
    type : string
        'unary': create data for unary operations, creates one input
                 and output array
        'binary': create data for binary operations, creates two input
                 and output array
    max_size : integer
        maximum size of data to produce

    Returns
    -------
    if type is 'unary' yields one output, one input array and a message
    containing information on the data
    if type is 'binary' yields one output array, two input array and a message
    containing information on the data

    """
    ufmt = 'unary offset=(%d, %d), size=%d, dtype=%r, %s'
    bfmt = 'binary offset=(%d, %d, %d), size=%d, dtype=%r, %s'

    for o in range(3):
        for s in range(o + 2, max(o + 3, max_size)):

            def fresh():
                # new array of s elements with the first o sliced off
                return arange(s, dtype=dtype)[o:]

            def umsg(out_off, in_off, size, how):
                return ufmt % (out_off, in_off, size, dtype, how)

            def bmsg(out_off, in1_off, in2_off, size, how):
                return bfmt % (out_off, in1_off, in2_off, size, dtype, how)

            if type == 'unary':
                out = empty((s,), dtype=dtype)[o:]
                yield out, fresh(), umsg(o, o, s, 'out of place')
                d = fresh()
                yield d, d, umsg(o, o, s, 'in place')
                yield out[1:], fresh()[:-1], umsg(o + 1, o, s - 1,
                                                  'out of place')
                yield out[:-1], fresh()[1:], umsg(o, o + 1, s - 1,
                                                  'out of place')
                yield fresh()[:-1], fresh()[1:], umsg(o, o + 1, s - 1,
                                                      'aliased')
                yield fresh()[1:], fresh()[:-1], umsg(o + 1, o, s - 1,
                                                      'aliased')
            if type == 'binary':
                out = empty((s,), dtype=dtype)[o:]
                yield out, fresh(), fresh(), bmsg(o, o, o, s, 'out of place')
                d = fresh()
                yield d, d, fresh(), bmsg(o, o, o, s, 'in place1')
                d = fresh()
                yield d, fresh(), d, bmsg(o, o, o, s, 'in place2')
                yield out[1:], fresh()[:-1], fresh()[:-1], bmsg(
                    o + 1, o, o, s - 1, 'out of place')
                yield out[:-1], fresh()[1:], fresh()[:-1], bmsg(
                    o, o + 1, o, s - 1, 'out of place')
                yield out[:-1], fresh()[:-1], fresh()[1:], bmsg(
                    o, o, o + 1, s - 1, 'out of place')
                yield fresh()[1:], fresh()[:-1], fresh()[:-1], bmsg(
                    o + 1, o, o, s - 1, 'aliased')
                yield fresh()[:-1], fresh()[1:], fresh()[:-1], bmsg(
                    o, o + 1, o, s - 1, 'aliased')
                yield fresh()[:-1], fresh()[:-1], fresh()[1:], bmsg(
                    o, o, o + 1, s - 1, 'aliased')

1906 

1907 

class IgnoreException(Exception):
    """Ignoring this exception due to disabled feature"""

1911 

1912 

@contextlib.contextmanager
def tempdir(*args, **kwargs):
    """Context manager to provide a temporary test folder.

    All arguments are passed unchanged to the underlying tempfile.mkdtemp
    function. The folder and its contents are removed when the context is
    left, whether or not an exception occurred.

    """
    folder = mkdtemp(*args, **kwargs)
    try:
        yield folder
    finally:
        shutil.rmtree(folder)

1926 

1927 

@contextlib.contextmanager
def temppath(*args, **kwargs):
    """Context manager yielding the path of a closed temporary file.

    All arguments are handed straight to ``tempfile.mkstemp``.  The file
    descriptor returned by ``mkstemp`` is closed immediately, so only the
    path is exposed.  The file is removed when the context exits, which
    means the caller must have closed any handle it opened on the path by
    then.

    Windows refuses to open a temporary file that is already open, which
    is why the descriptor is closed up front and why the caller has to
    close the file before opening it again.

    """
    descriptor, file_path = mkstemp(*args, **kwargs)
    os.close(descriptor)
    try:
        yield file_path
    finally:
        os.remove(file_path)

1948 

1949 

class clear_and_catch_warnings(warnings.catch_warnings):
    """ Context manager that resets warning registry for catching warnings

    Warnings can be slippery, because, whenever a warning is triggered, Python
    adds a ``__warningregistry__`` member to the *calling* module.  This makes
    it impossible to retrigger the warning in this module, whatever you put in
    the warnings filters.  This context manager accepts a sequence of `modules`
    as a keyword argument to its constructor and:

    * stores and removes any ``__warningregistry__`` entries in given `modules`
      on entry;
    * resets ``__warningregistry__`` to its previous state on exit.

    This makes it possible to trigger any warning afresh inside the context
    manager without disturbing the state of warnings outside.

    For compatibility with Python 3.0, please consider all arguments to be
    keyword-only.

    Parameters
    ----------
    record : bool, optional
        Specifies whether warnings should be captured by a custom
        implementation of ``warnings.showwarning()`` and be appended to a list
        returned by the context manager. Otherwise None is returned by the
        context manager. The objects appended to the list are arguments whose
        attributes mirror the arguments to ``showwarning()``.
    modules : sequence, optional
        Sequence of modules for which to reset warnings registry on entry and
        restore on exit. To work correctly, all 'ignore' filters should
        filter by one of these modules.

    Examples
    --------
    >>> import warnings
    >>> with np.testing.clear_and_catch_warnings(
    ...         modules=[np.core.fromnumeric]):
    ...     warnings.simplefilter('always')
    ...     warnings.filterwarnings('ignore', module='np.core.fromnumeric')
    ...     # do something that raises a warning but ignore those in
    ...     # np.core.fromnumeric
    """
    # Subclasses may list modules here that are always handled, in addition
    # to the ones passed to the constructor.
    class_modules = ()

    def __init__(self, record=False, modules=()):
        self.modules = set(modules).union(self.class_modules)
        # Maps module -> snapshot of its registry taken in ``__enter__``.
        self._warnreg_copies = {}
        super().__init__(record=record)

    def __enter__(self):
        for module in self.modules:
            try:
                registry = module.__warningregistry__
            except AttributeError:
                # Module has not triggered a warning yet: nothing to save.
                continue
            self._warnreg_copies[module] = registry.copy()
            registry.clear()
        return super().__enter__()

    def __exit__(self, *exc_info):
        super().__exit__(*exc_info)
        for module in self.modules:
            # Drop whatever was registered while the context was active ...
            if hasattr(module, '__warningregistry__'):
                module.__warningregistry__.clear()
            # ... and put back the entries that existed on entry.
            if module in self._warnreg_copies:
                saved = self._warnreg_copies[module]
                module.__warningregistry__.update(saved)

2014 

2015 

class suppress_warnings:
    """
    Context manager and decorator doing much the same as
    ``warnings.catch_warnings``.

    However, it also provides a filter mechanism to work around
    https://bugs.python.org/issue4180.

    This bug causes Python before 3.4 to not reliably show warnings again
    after they have been ignored once (even within catch_warnings). It
    means that no "ignore" filter can be used easily, since following
    tests might need to see the warning. Additionally it allows easier
    specificity for testing warnings and can be nested.

    Parameters
    ----------
    forwarding_rule : str, optional
        One of "always", "once", "module", or "location". Analogous to
        the usual warnings module filter mode, it is useful to reduce
        noise mostly on the outmost level. Unsuppressed and unrecorded
        warnings will be forwarded based on this rule. Defaults to "always".
        "location" is equivalent to the warnings "default", match by exact
        location the warning originated from.

    Raises
    ------
    ValueError
        If `forwarding_rule` is not one of the supported values.

    Notes
    -----
    Filters added inside the context manager will be discarded again
    when leaving it. Upon entering all filters defined outside a
    context will be applied automatically.

    When a recording filter is added, matching warnings are stored in the
    ``log`` attribute as well as in the list returned by ``record``.

    If filters are added and the ``module`` keyword is given, the
    warning registry of this module will additionally be cleared when
    applying it, entering the context, or exiting it. This could cause
    warnings to appear a second time after leaving the context if they
    were configured to be printed once (default) and were already
    printed before the context was entered.

    Nesting this context manager will work as expected when the
    forwarding rule is "always" (default). Unfiltered and unrecorded
    warnings will be passed out and be matched by the outer level.
    On the outmost level they will be printed (or caught by another
    warnings context). The forwarding rule argument can modify this
    behaviour.

    Like ``catch_warnings`` this context manager is not threadsafe.

    Examples
    --------

    With a context manager::

        with np.testing.suppress_warnings() as sup:
            sup.filter(DeprecationWarning, "Some text")
            sup.filter(module=np.ma.core)
            log = sup.record(FutureWarning, "Does this occur?")
            command_giving_warnings()
            # The FutureWarning was given once, the filtered warnings were
            # ignored. All other warnings abide outside settings (may be
            # printed/error)
            assert_(len(log) == 1)
            assert_(len(sup.log) == 1)  # also stored in log attribute

    Or as a decorator::

        sup = np.testing.suppress_warnings()
        sup.filter(module=np.ma.core)  # module must match exactly
        @sup
        def some_function():
            # do something which causes a warning in np.ma.core
            pass
    """
    def __init__(self, forwarding_rule="always"):
        self._entered = False

        # Suppressions are either instance or defined inside one with block:
        self._suppressions = []

        if forwarding_rule not in {"always", "module", "once", "location"}:
            raise ValueError("unsupported forwarding rule.")
        self._forwarding_rule = forwarding_rule

    def _clear_registries(self):
        """Invalidate cached per-module warning decisions."""
        if hasattr(warnings, "_filters_mutated"):
            # clearing the registry should not be necessary on new pythons,
            # instead the filters should be mutated.
            warnings._filters_mutated()
            return
        # Simply clear the registry, this should normally be harmless,
        # note that on new pythons it would be invalidated anyway.
        for module in self._tmp_modules:
            if hasattr(module, "__warningregistry__"):
                module.__warningregistry__.clear()

    def _install_filter(self, category, message, module):
        """Add an "always" filter to ``warnings`` for one suppression.

        When `module` is given the filter is restricted to exactly that
        module name and the module is remembered in ``_tmp_modules`` so its
        registry can be cleared later.  Only valid while entered.
        """
        if module is None:
            warnings.filterwarnings(
                "always", category=category, message=message)
            return
        # Anchor the escaped dotted name so submodules do not match.
        module_regex = module.__name__.replace('.', r'\.') + '$'
        warnings.filterwarnings(
            "always", category=category, message=message,
            module=module_regex)
        self._tmp_modules.add(module)

    def _filter(self, category=Warning, message="", module=None, record=False):
        """Shared implementation of `filter` and `record`.

        Returns the list that will collect matching warnings when `record`
        is true, otherwise None.
        """
        # The log where to store warnings (None means "just suppress").
        record = [] if record else None
        suppression = (
            category, message, re.compile(message, re.I), module, record)

        if self._entered:
            self._install_filter(category, message, module)
            if module is not None:
                self._clear_registries()
            # Only valid until the context is left again.
            self._tmp_suppressions.append(suppression)
        else:
            self._suppressions.append(suppression)

        return record

    def filter(self, category=Warning, message="", module=None):
        """
        Add a new suppressing filter or apply it if the state is entered.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        self._filter(category=category, message=message, module=module,
                     record=False)

    def record(self, category=Warning, message="", module=None):
        """
        Append a new recording filter or apply it if the state is entered.

        All warnings matching will be appended to the ``log`` attribute.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Returns
        -------
        log : list
            A list which will be filled with all matched warnings.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        return self._filter(category=category, message=message, module=module,
                            record=True)

    def __enter__(self):
        if self._entered:
            raise RuntimeError("cannot enter suppress_warnings twice.")

        self._orig_show = warnings.showwarning
        self._filters = warnings.filters
        # Work on a copy so the outer filter list can be restored untouched.
        warnings.filters = self._filters[:]

        self._entered = True
        self._tmp_suppressions = []
        self._tmp_modules = set()
        self._forwarded = set()

        self.log = []  # reset global log (no need to keep same list)

        for cat, mess, _, mod, log in self._suppressions:
            if log is not None:
                del log[:]  # clear the log
            self._install_filter(cat, mess, mod)
        warnings.showwarning = self._showwarning
        self._clear_registries()

        return self

    def __exit__(self, *exc_info):
        warnings.showwarning = self._orig_show
        warnings.filters = self._filters
        self._clear_registries()
        self._entered = False
        del self._orig_show
        del self._filters

    def _forward(self, message, category, filename, lineno, args, kwargs,
                 use_warnmsg):
        """Hand an unmatched warning to the handler active before entering.

        ``_orig_showmsg`` is not set anywhere in this class, so it is looked
        up defensively; without it the warning always goes through the saved
        ``showwarning`` function instead of raising AttributeError.
        """
        orig_showmsg = getattr(self, "_orig_showmsg", None)
        if use_warnmsg is None or orig_showmsg is None:
            self._orig_show(message, category, filename, lineno,
                            *args, **kwargs)
        else:
            orig_showmsg(use_warnmsg)

    def _showwarning(self, message, category, filename, lineno,
                     *args, use_warnmsg=None, **kwargs):
        """Replacement for ``warnings.showwarning`` while entered.

        A warning matching a suppression is dropped (and logged if that
        suppression records); any other warning is forwarded to the
        previous handler according to the forwarding rule.
        """
        # ``message`` is normally a Warning instance; fall back gracefully
        # for plain strings, and for warnings whose first argument is
        # missing or not a string.
        msg_args = getattr(message, "args", None)
        if msg_args is None:
            msg_args = (str(message),)
        text = str(msg_args[0]) if msg_args else ""

        # Later suppressions take precedence, hence the reversed order.
        for cat, _, pattern, mod, rec in reversed(
                self._suppressions + self._tmp_suppressions):
            if not issubclass(category, cat) or pattern.match(text) is None:
                continue
            # Use startswith, because warnings strips the c or o from
            # .pyc/.pyo files.
            if mod is not None and not mod.__file__.startswith(filename):
                continue
            # Message, category (and module, if given) match: the warning
            # is either recorded or silently ignored.
            if rec is not None:
                msg = WarningMessage(message, category, filename,
                                     lineno, **kwargs)
                self.log.append(msg)
                rec.append(msg)
            return

        # There is no filter in place, so pass to the outside handler
        # unless we should only pass it once
        if self._forwarding_rule != "always":
            if self._forwarding_rule == "once":
                signature = (msg_args, category)
            elif self._forwarding_rule == "module":
                signature = (msg_args, category, filename)
            else:  # "location"
                signature = (msg_args, category, filename, lineno)

            if signature in self._forwarded:
                return
            self._forwarded.add(signature)

        self._forward(message, category, filename, lineno, args, kwargs,
                      use_warnmsg)

    def __call__(self, func):
        """
        Function decorator to apply certain suppressions to a whole
        function.
        """
        @wraps(func)
        def new_func(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return new_func

2292 

2293 

@contextlib.contextmanager
def _assert_no_gc_cycles_context(name=None):
    """Context manager failing if its body leaves reference cycles behind.

    When ``HAS_REFCOUNT`` is false the body simply runs unchecked.
    Otherwise the automatic collector is switched off, pending garbage is
    collected, the body runs with ``gc.DEBUG_SAVEALL`` set, and a final
    ``gc.collect()`` counts the unreachable objects the body produced.  An
    AssertionError listing those objects is raised if any were found; `name`
    (if given) is mentioned in that message.
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    # not meaningful to test if there is no refcounting
    if not HAS_REFCOUNT:
        yield
        return

    assert_(gc.isenabled())
    gc.disable()
    saved_debug_flags = gc.get_debug()
    try:
        # Start from a clean slate; give up after 100 attempts.
        for _ in range(100):
            if gc.collect() == 0:
                break
        else:
            raise RuntimeError(
                "Unable to fully collect garbage - perhaps a __del__ method is "
                "creating more reference cycles?")

        gc.set_debug(gc.DEBUG_SAVEALL)
        yield
        # gc.collect returns the number of unreachable objects in cycles that
        # were found -- we are checking that no cycles were created in the
        # context
        n_objects_in_cycles = gc.collect()
        objects_in_cycles = gc.garbage[:]
    finally:
        del gc.garbage[:]
        gc.set_debug(saved_debug_flags)
        gc.enable()

    if not n_objects_in_cycles:
        return

    name_str = "" if name is None else " when calling %s" % name
    descriptions = [
        "\n {} object with id={}:\n {}".format(
            type(obj).__name__,
            id(obj),
            pprint.pformat(obj).replace('\n', '\n ')
        )
        for obj in objects_in_cycles
    ]
    raise AssertionError(
        "Reference cycles were found{}: {} objects were collected, "
        "of which {} are shown below:{}"
        .format(
            name_str,
            n_objects_in_cycles,
            len(objects_in_cycles),
            ''.join(descriptions)
        )
    )

2344 

2345 

def assert_no_gc_cycles(*args, **kwargs):
    """
    Fail if the given callable produces any reference cycles.

    If called with all arguments omitted, may be used as a context manager:

        with assert_no_gc_cycles():
            do_something()

    .. versionadded:: 1.15.0

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    Nothing. The result is deliberately discarded to ensure that all cycles
    are found.

    """
    # No positional arguments at all: behave as a context manager.
    if not args:
        return _assert_no_gc_cycles_context()

    # Otherwise the first positional argument is the callable under test.
    func, *call_args = args
    with _assert_no_gc_cycles_context(name=func.__name__):
        func(*call_args, **kwargs)

2379 

def break_cycles():
    """
    Break reference cycles by calling gc.collect

    Objects can call other objects' methods (for instance, another object's
    __del__) inside their own __del__. On PyPy, the interpreter only runs
    between calls to gc.collect, so multiple calls are needed to completely
    release all cycles.
    """
    # One pass always.  On PyPy a second pass lets the interpreter run the
    # __del__ methods of deleted objects, and a third is added just to make
    # sure.
    passes = 3 if IS_PYPY else 1
    for _ in range(passes):
        gc.collect()

2395 

2396 

def requires_memory(free_bytes):
    """Decorator to skip a test if not enough memory is available

    The check runs each time the decorated test is called: if
    ``check_free_memory(free_bytes)`` returns a message, the test is
    skipped with it.  A ``MemoryError`` raised by the test itself is
    reported as an expected failure rather than an error.
    """
    import pytest

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            skip_reason = check_free_memory(free_bytes)
            if skip_reason is not None:
                pytest.skip(skip_reason)

            try:
                result = func(*args, **kwargs)
            except MemoryError:
                # Probably ran out of memory regardless: don't regard as failure
                pytest.xfail("MemoryError raised")
            else:
                return result

        return wrapper

    return decorator

2417 

2418 

def check_free_memory(free_bytes):
    """
    Check whether `free_bytes` amount of memory is currently free.

    The amount considered available comes from the ``NPY_AVAILABLE_MEM``
    environment variable when it is set (parsed with ``_parse_size``), and
    from ``_get_mem_available()`` otherwise.  If neither yields a value the
    check fails.

    Returns: None if enough memory available, otherwise error message

    Raises: ValueError if ``NPY_AVAILABLE_MEM`` is set but cannot be parsed
    """
    env_var = 'NPY_AVAILABLE_MEM'
    env_value = os.environ.get(env_var)

    if env_value is None:
        mem_free = _get_mem_available()
        if mem_free is not None:
            msg = '{0} GB memory required, but {1} GB available'.format(
                free_bytes/1e9, mem_free/1e9)
        else:
            msg = ("Could not determine available memory; set NPY_AVAILABLE_MEM "
                   "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run "
                   "the test.")
            # Guarantees the comparison below reports a shortage.
            mem_free = -1
    else:
        try:
            mem_free = _parse_size(env_value)
        except ValueError as exc:
            raise ValueError('Invalid environment variable {}: {!s}'.format(
                env_var, exc))

        msg = ('{0} GB memory required, but environment variable '
               'NPY_AVAILABLE_MEM={1} set'.format(
                   free_bytes/1e9, env_value))

    if mem_free < free_bytes:
        return msg
    return None

2449 

2450 

2451def _parse_size(size_str): 

2452 """Convert memory size strings ('12 GB' etc.) to float""" 

2453 suffixes = {'': 1, 'b': 1, 

2454 'k': 1000, 'm': 1000**2, 'g': 1000**3, 't': 1000**4, 

2455 'kb': 1000, 'mb': 1000**2, 'gb': 1000**3, 'tb': 1000**4, 

2456 'kib': 1024, 'mib': 1024**2, 'gib': 1024**3, 'tib': 1024**4} 

2457 

2458 size_re = re.compile(r'^\s*(\d+|\d+\.\d+)\s*({0})\s*$'.format( 

2459 '|'.join(suffixes.keys())), re.I) 

2460 

2461 m = size_re.match(size_str.lower()) 

2462 if not m or m.group(2) not in suffixes: 

2463 raise ValueError("value {!r} not a valid size".format(size_str)) 

2464 return int(float(m.group(1)) * suffixes[m.group(2)]) 

2465 

2466 

2467def _get_mem_available(): 

2468 """Return available memory in bytes, or None if unknown.""" 

2469 try: 

2470 import psutil 

2471 return psutil.virtual_memory().available 

2472 except (ImportError, AttributeError): 

2473 pass 

2474 

2475 if sys.platform.startswith('linux'): 

2476 info = {} 

2477 with open('/proc/meminfo', 'r') as f: 

2478 for line in f: 

2479 p = line.split() 

2480 info[p[0].strip(':').lower()] = int(p[1]) * 1024 

2481 

2482 if 'memavailable' in info: 

2483 # Linux >= 3.14 

2484 return info['memavailable'] 

2485 else: 

2486 return info['memfree'] + info['cached'] 

2487 

2488 return None 

2489 

2490 

2491def _no_tracing(func): 

2492 """ 

2493 Decorator to temporarily turn off tracing for the duration of a test. 

2494 Needed in tests that check refcounting, otherwise the tracing itself 

2495 influences the refcounts 

2496 """ 

2497 if not hasattr(sys, 'gettrace'): 

2498 return func 

2499 else: 

2500 @wraps(func) 

2501 def wrapper(*args, **kwargs): 

2502 original_trace = sys.gettrace() 

2503 try: 

2504 sys.settrace(None) 

2505 return func(*args, **kwargs) 

2506 finally: 

2507 sys.settrace(original_trace) 

2508 return wrapper 

2509