utils.py 65 KB


"""
Utility functions to facilitate testing.
"""
  4. from __future__ import division, absolute_import, print_function
  5. import os
  6. import sys
  7. import re
  8. import operator
  9. import warnings
  10. from functools import partial
  11. import shutil
  12. import contextlib
  13. from tempfile import mkdtemp, mkstemp
  14. from .nosetester import import_nose
  15. from numpy.core import float32, empty, arange, array_repr, ndarray
  16. from numpy.lib.utils import deprecate
  17. if sys.version_info[0] >= 3:
  18. from io import StringIO
  19. else:
  20. from StringIO import StringIO
  21. __all__ = ['assert_equal', 'assert_almost_equal', 'assert_approx_equal',
  22. 'assert_array_equal', 'assert_array_less', 'assert_string_equal',
  23. 'assert_array_almost_equal', 'assert_raises', 'build_err_msg',
  24. 'decorate_methods', 'jiffies', 'memusage', 'print_assert_equal',
  25. 'raises', 'rand', 'rundocs', 'runstring', 'verbose', 'measure',
  26. 'assert_', 'assert_array_almost_equal_nulp', 'assert_raises_regex',
  27. 'assert_array_max_ulp', 'assert_warns', 'assert_no_warnings',
  28. 'assert_allclose', 'IgnoreException', 'clear_and_catch_warnings',
  29. 'SkipTest', 'KnownFailureException', 'temppath', 'tempdir']
  30. class KnownFailureException(Exception):
  31. '''Raise this exception to mark a test as a known failing test.'''
  32. pass
  33. KnownFailureTest = KnownFailureException # backwards compat
  34. # nose.SkipTest is unittest.case.SkipTest
  35. # import it into the namespace, so that it's available as np.testing.SkipTest
  36. try:
  37. from unittest.case import SkipTest
  38. except ImportError:
  39. # on py2.6 unittest.case is not available. Ask nose for a replacement.
  40. try:
  41. import nose
  42. SkipTest = nose.SkipTest
  43. except (ImportError, AttributeError):
  44. # if nose is not available, testing won't work anyway
  45. pass
  46. verbose = 0
  47. def assert_(val, msg=''):
  48. """
  49. Assert that works in release mode.
  50. Accepts callable msg to allow deferring evaluation until failure.
  51. The Python built-in ``assert`` does not work when executing code in
  52. optimized mode (the ``-O`` flag) - no byte-code is generated for it.
  53. For documentation on usage, refer to the Python documentation.
  54. """
  55. if not val:
  56. try:
  57. smsg = msg()
  58. except TypeError:
  59. smsg = msg
  60. raise AssertionError(smsg)
  61. def gisnan(x):
  62. """like isnan, but always raise an error if type not supported instead of
  63. returning a TypeError object.
  64. Notes
  65. -----
  66. isnan and other ufunc sometimes return a NotImplementedType object instead
  67. of raising any exception. This function is a wrapper to make sure an
  68. exception is always raised.
  69. This should be removed once this problem is solved at the Ufunc level."""
  70. from numpy.core import isnan
  71. st = isnan(x)
  72. if isinstance(st, type(NotImplemented)):
  73. raise TypeError("isnan not supported for this type")
  74. return st
  75. def gisfinite(x):
  76. """like isfinite, but always raise an error if type not supported instead of
  77. returning a TypeError object.
  78. Notes
  79. -----
  80. isfinite and other ufunc sometimes return a NotImplementedType object instead
  81. of raising any exception. This function is a wrapper to make sure an
  82. exception is always raised.
  83. This should be removed once this problem is solved at the Ufunc level."""
  84. from numpy.core import isfinite, errstate
  85. with errstate(invalid='ignore'):
  86. st = isfinite(x)
  87. if isinstance(st, type(NotImplemented)):
  88. raise TypeError("isfinite not supported for this type")
  89. return st
  90. def gisinf(x):
  91. """like isinf, but always raise an error if type not supported instead of
  92. returning a TypeError object.
  93. Notes
  94. -----
  95. isinf and other ufunc sometimes return a NotImplementedType object instead
  96. of raising any exception. This function is a wrapper to make sure an
  97. exception is always raised.
  98. This should be removed once this problem is solved at the Ufunc level."""
  99. from numpy.core import isinf, errstate
  100. with errstate(invalid='ignore'):
  101. st = isinf(x)
  102. if isinstance(st, type(NotImplemented)):
  103. raise TypeError("isinf not supported for this type")
  104. return st
  105. @deprecate(message="numpy.testing.rand is deprecated in numpy 1.11. "
  106. "Use numpy.random.rand instead.")
  107. def rand(*args):
  108. """Returns an array of random numbers with the given shape.
  109. This only uses the standard library, so it is useful for testing purposes.
  110. """
  111. import random
  112. from numpy.core import zeros, float64
  113. results = zeros(args, float64)
  114. f = results.flat
  115. for i in range(len(f)):
  116. f[i] = random.random()
  117. return results
  118. if os.name == 'nt':
  119. # Code "stolen" from enthought/debug/memusage.py
  120. def GetPerformanceAttributes(object, counter, instance=None,
  121. inum=-1, format=None, machine=None):
  122. # NOTE: Many counters require 2 samples to give accurate results,
  123. # including "% Processor Time" (as by definition, at any instant, a
  124. # thread's CPU usage is either 0 or 100). To read counters like this,
  125. # you should copy this function, but keep the counter open, and call
  126. # CollectQueryData() each time you need to know.
  127. # See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp
  128. # My older explanation for this was that the "AddCounter" process forced
  129. # the CPU to 100%, but the above makes more sense :)
  130. import win32pdh
  131. if format is None:
  132. format = win32pdh.PDH_FMT_LONG
  133. path = win32pdh.MakeCounterPath( (machine, object, instance, None, inum, counter))
  134. hq = win32pdh.OpenQuery()
  135. try:
  136. hc = win32pdh.AddCounter(hq, path)
  137. try:
  138. win32pdh.CollectQueryData(hq)
  139. type, val = win32pdh.GetFormattedCounterValue(hc, format)
  140. return val
  141. finally:
  142. win32pdh.RemoveCounter(hc)
  143. finally:
  144. win32pdh.CloseQuery(hq)
  145. def memusage(processName="python", instance=0):
  146. # from win32pdhutil, part of the win32all package
  147. import win32pdh
  148. return GetPerformanceAttributes("Process", "Virtual Bytes",
  149. processName, instance,
  150. win32pdh.PDH_FMT_LONG, None)
  151. elif sys.platform[:5] == 'linux':
  152. def memusage(_proc_pid_stat='/proc/%s/stat' % (os.getpid())):
  153. """
  154. Return virtual memory size in bytes of the running python.
  155. """
  156. try:
  157. f = open(_proc_pid_stat, 'r')
  158. l = f.readline().split(' ')
  159. f.close()
  160. return int(l[22])
  161. except:
  162. return
  163. else:
  164. def memusage():
  165. """
  166. Return memory usage of running python. [Not implemented]
  167. """
  168. raise NotImplementedError
  169. if sys.platform[:5] == 'linux':
  170. def jiffies(_proc_pid_stat='/proc/%s/stat' % (os.getpid()),
  171. _load_time=[]):
  172. """
  173. Return number of jiffies elapsed.
  174. Return number of jiffies (1/100ths of a second) that this
  175. process has been scheduled in user mode. See man 5 proc.
  176. """
  177. import time
  178. if not _load_time:
  179. _load_time.append(time.time())
  180. try:
  181. f = open(_proc_pid_stat, 'r')
  182. l = f.readline().split(' ')
  183. f.close()
  184. return int(l[13])
  185. except:
  186. return int(100*(time.time()-_load_time[0]))
  187. else:
  188. # os.getpid is not in all platforms available.
  189. # Using time is safe but inaccurate, especially when process
  190. # was suspended or sleeping.
  191. def jiffies(_load_time=[]):
  192. """
  193. Return number of jiffies elapsed.
  194. Return number of jiffies (1/100ths of a second) that this
  195. process has been scheduled in user mode. See man 5 proc.
  196. """
  197. import time
  198. if not _load_time:
  199. _load_time.append(time.time())
  200. return int(100*(time.time()-_load_time[0]))
  201. def build_err_msg(arrays, err_msg, header='Items are not equal:',
  202. verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
  203. msg = ['\n' + header]
  204. if err_msg:
  205. if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header):
  206. msg = [msg[0] + ' ' + err_msg]
  207. else:
  208. msg.append(err_msg)
  209. if verbose:
  210. for i, a in enumerate(arrays):
  211. if isinstance(a, ndarray):
  212. # precision argument is only needed if the objects are ndarrays
  213. r_func = partial(array_repr, precision=precision)
  214. else:
  215. r_func = repr
  216. try:
  217. r = r_func(a)
  218. except:
  219. r = '[repr failed]'
  220. if r.count('\n') > 3:
  221. r = '\n'.join(r.splitlines()[:3])
  222. r += '...'
  223. msg.append(' %s: %s' % (names[i], r))
  224. return '\n'.join(msg)
  225. def assert_equal(actual,desired,err_msg='',verbose=True):
  226. """
  227. Raises an AssertionError if two objects are not equal.
  228. Given two objects (scalars, lists, tuples, dictionaries or numpy arrays),
  229. check that all elements of these objects are equal. An exception is raised
  230. at the first conflicting values.
  231. Parameters
  232. ----------
  233. actual : array_like
  234. The object to check.
  235. desired : array_like
  236. The expected object.
  237. err_msg : str, optional
  238. The error message to be printed in case of failure.
  239. verbose : bool, optional
  240. If True, the conflicting values are appended to the error message.
  241. Raises
  242. ------
  243. AssertionError
  244. If actual and desired are not equal.
  245. Examples
  246. --------
  247. >>> np.testing.assert_equal([4,5], [4,6])
  248. ...
  249. <type 'exceptions.AssertionError'>:
  250. Items are not equal:
  251. item=1
  252. ACTUAL: 5
  253. DESIRED: 6
  254. """
  255. __tracebackhide__ = True # Hide traceback for py.test
  256. if isinstance(desired, dict):
  257. if not isinstance(actual, dict):
  258. raise AssertionError(repr(type(actual)))
  259. assert_equal(len(actual), len(desired), err_msg, verbose)
  260. for k, i in desired.items():
  261. if k not in actual:
  262. raise AssertionError(repr(k))
  263. assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg), verbose)
  264. return
  265. if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
  266. assert_equal(len(actual), len(desired), err_msg, verbose)
  267. for k in range(len(desired)):
  268. assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg), verbose)
  269. return
  270. from numpy.core import ndarray, isscalar, signbit
  271. from numpy.lib import iscomplexobj, real, imag
  272. if isinstance(actual, ndarray) or isinstance(desired, ndarray):
  273. return assert_array_equal(actual, desired, err_msg, verbose)
  274. msg = build_err_msg([actual, desired], err_msg, verbose=verbose)
  275. # Handle complex numbers: separate into real/imag to handle
  276. # nan/inf/negative zero correctly
  277. # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
  278. try:
  279. usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
  280. except ValueError:
  281. usecomplex = False
  282. if usecomplex:
  283. if iscomplexobj(actual):
  284. actualr = real(actual)
  285. actuali = imag(actual)
  286. else:
  287. actualr = actual
  288. actuali = 0
  289. if iscomplexobj(desired):
  290. desiredr = real(desired)
  291. desiredi = imag(desired)
  292. else:
  293. desiredr = desired
  294. desiredi = 0
  295. try:
  296. assert_equal(actualr, desiredr)
  297. assert_equal(actuali, desiredi)
  298. except AssertionError:
  299. raise AssertionError(msg)
  300. # Inf/nan/negative zero handling
  301. try:
  302. # isscalar test to check cases such as [np.nan] != np.nan
  303. if isscalar(desired) != isscalar(actual):
  304. raise AssertionError(msg)
  305. # If one of desired/actual is not finite, handle it specially here:
  306. # check that both are nan if any is a nan, and test for equality
  307. # otherwise
  308. if not (gisfinite(desired) and gisfinite(actual)):
  309. isdesnan = gisnan(desired)
  310. isactnan = gisnan(actual)
  311. if isdesnan or isactnan:
  312. if not (isdesnan and isactnan):
  313. raise AssertionError(msg)
  314. else:
  315. if not desired == actual:
  316. raise AssertionError(msg)
  317. return
  318. elif desired == 0 and actual == 0:
  319. if not signbit(desired) == signbit(actual):
  320. raise AssertionError(msg)
  321. # If TypeError or ValueError raised while using isnan and co, just handle
  322. # as before
  323. except (TypeError, ValueError, NotImplementedError):
  324. pass
  325. # Explicitly use __eq__ for comparison, ticket #2552
  326. if not (desired == actual):
  327. raise AssertionError(msg)
  328. def print_assert_equal(test_string, actual, desired):
  329. """
  330. Test if two objects are equal, and print an error message if test fails.
  331. The test is performed with ``actual == desired``.
  332. Parameters
  333. ----------
  334. test_string : str
  335. The message supplied to AssertionError.
  336. actual : object
  337. The object to test for equality against `desired`.
  338. desired : object
  339. The expected result.
  340. Examples
  341. --------
  342. >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1])
  343. >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2])
  344. Traceback (most recent call last):
  345. ...
  346. AssertionError: Test XYZ of func xyz failed
  347. ACTUAL:
  348. [0, 1]
  349. DESIRED:
  350. [0, 2]
  351. """
  352. __tracebackhide__ = True # Hide traceback for py.test
  353. import pprint
  354. if not (actual == desired):
  355. msg = StringIO()
  356. msg.write(test_string)
  357. msg.write(' failed\nACTUAL: \n')
  358. pprint.pprint(actual, msg)
  359. msg.write('DESIRED: \n')
  360. pprint.pprint(desired, msg)
  361. raise AssertionError(msg.getvalue())
  362. def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
  363. """
  364. Raises an AssertionError if two items are not equal up to desired
  365. precision.
  366. .. note:: It is recommended to use one of `assert_allclose`,
  367. `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
  368. instead of this function for more consistent floating point
  369. comparisons.
  370. The test is equivalent to ``abs(desired-actual) < 0.5 * 10**(-decimal)``.
  371. Given two objects (numbers or ndarrays), check that all elements of these
  372. objects are almost equal. An exception is raised at conflicting values.
  373. For ndarrays this delegates to assert_array_almost_equal
  374. Parameters
  375. ----------
  376. actual : array_like
  377. The object to check.
  378. desired : array_like
  379. The expected object.
  380. decimal : int, optional
  381. Desired precision, default is 7.
  382. err_msg : str, optional
  383. The error message to be printed in case of failure.
  384. verbose : bool, optional
  385. If True, the conflicting values are appended to the error message.
  386. Raises
  387. ------
  388. AssertionError
  389. If actual and desired are not equal up to specified precision.
  390. See Also
  391. --------
  392. assert_allclose: Compare two array_like objects for equality with desired
  393. relative and/or absolute precision.
  394. assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal
  395. Examples
  396. --------
  397. >>> import numpy.testing as npt
  398. >>> npt.assert_almost_equal(2.3333333333333, 2.33333334)
  399. >>> npt.assert_almost_equal(2.3333333333333, 2.33333334, decimal=10)
  400. ...
  401. <type 'exceptions.AssertionError'>:
  402. Items are not equal:
  403. ACTUAL: 2.3333333333333002
  404. DESIRED: 2.3333333399999998
  405. >>> npt.assert_almost_equal(np.array([1.0,2.3333333333333]),
  406. ... np.array([1.0,2.33333334]), decimal=9)
  407. ...
  408. <type 'exceptions.AssertionError'>:
  409. Arrays are not almost equal
  410. <BLANKLINE>
  411. (mismatch 50.0%)
  412. x: array([ 1. , 2.33333333])
  413. y: array([ 1. , 2.33333334])
  414. """
  415. __tracebackhide__ = True # Hide traceback for py.test
  416. from numpy.core import ndarray
  417. from numpy.lib import iscomplexobj, real, imag
  418. # Handle complex numbers: separate into real/imag to handle
  419. # nan/inf/negative zero correctly
  420. # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
  421. try:
  422. usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
  423. except ValueError:
  424. usecomplex = False
  425. def _build_err_msg():
  426. header = ('Arrays are not almost equal to %d decimals' % decimal)
  427. return build_err_msg([actual, desired], err_msg, verbose=verbose,
  428. header=header)
  429. if usecomplex:
  430. if iscomplexobj(actual):
  431. actualr = real(actual)
  432. actuali = imag(actual)
  433. else:
  434. actualr = actual
  435. actuali = 0
  436. if iscomplexobj(desired):
  437. desiredr = real(desired)
  438. desiredi = imag(desired)
  439. else:
  440. desiredr = desired
  441. desiredi = 0
  442. try:
  443. assert_almost_equal(actualr, desiredr, decimal=decimal)
  444. assert_almost_equal(actuali, desiredi, decimal=decimal)
  445. except AssertionError:
  446. raise AssertionError(_build_err_msg())
  447. if isinstance(actual, (ndarray, tuple, list)) \
  448. or isinstance(desired, (ndarray, tuple, list)):
  449. return assert_array_almost_equal(actual, desired, decimal, err_msg)
  450. try:
  451. # If one of desired/actual is not finite, handle it specially here:
  452. # check that both are nan if any is a nan, and test for equality
  453. # otherwise
  454. if not (gisfinite(desired) and gisfinite(actual)):
  455. if gisnan(desired) or gisnan(actual):
  456. if not (gisnan(desired) and gisnan(actual)):
  457. raise AssertionError(_build_err_msg())
  458. else:
  459. if not desired == actual:
  460. raise AssertionError(_build_err_msg())
  461. return
  462. except (NotImplementedError, TypeError):
  463. pass
  464. if round(abs(desired - actual), decimal) != 0:
  465. raise AssertionError(_build_err_msg())
  466. def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
  467. """
  468. Raises an AssertionError if two items are not equal up to significant
  469. digits.
  470. .. note:: It is recommended to use one of `assert_allclose`,
  471. `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
  472. instead of this function for more consistent floating point
  473. comparisons.
  474. Given two numbers, check that they are approximately equal.
  475. Approximately equal is defined as the number of significant digits
  476. that agree.
  477. Parameters
  478. ----------
  479. actual : scalar
  480. The object to check.
  481. desired : scalar
  482. The expected object.
  483. significant : int, optional
  484. Desired precision, default is 7.
  485. err_msg : str, optional
  486. The error message to be printed in case of failure.
  487. verbose : bool, optional
  488. If True, the conflicting values are appended to the error message.
  489. Raises
  490. ------
  491. AssertionError
  492. If actual and desired are not equal up to specified precision.
  493. See Also
  494. --------
  495. assert_allclose: Compare two array_like objects for equality with desired
  496. relative and/or absolute precision.
  497. assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal
  498. Examples
  499. --------
  500. >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20)
  501. >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20,
  502. significant=8)
  503. >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20,
  504. significant=8)
  505. ...
  506. <type 'exceptions.AssertionError'>:
  507. Items are not equal to 8 significant digits:
  508. ACTUAL: 1.234567e-021
  509. DESIRED: 1.2345672000000001e-021
  510. the evaluated condition that raises the exception is
  511. >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1)
  512. True
  513. """
  514. __tracebackhide__ = True # Hide traceback for py.test
  515. import numpy as np
  516. (actual, desired) = map(float, (actual, desired))
  517. if desired == actual:
  518. return
  519. # Normalized the numbers to be in range (-10.0,10.0)
  520. # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual))))))
  521. with np.errstate(invalid='ignore'):
  522. scale = 0.5*(np.abs(desired) + np.abs(actual))
  523. scale = np.power(10, np.floor(np.log10(scale)))
  524. try:
  525. sc_desired = desired/scale
  526. except ZeroDivisionError:
  527. sc_desired = 0.0
  528. try:
  529. sc_actual = actual/scale
  530. except ZeroDivisionError:
  531. sc_actual = 0.0
  532. msg = build_err_msg([actual, desired], err_msg,
  533. header='Items are not equal to %d significant digits:' %
  534. significant,
  535. verbose=verbose)
  536. try:
  537. # If one of desired/actual is not finite, handle it specially here:
  538. # check that both are nan if any is a nan, and test for equality
  539. # otherwise
  540. if not (gisfinite(desired) and gisfinite(actual)):
  541. if gisnan(desired) or gisnan(actual):
  542. if not (gisnan(desired) and gisnan(actual)):
  543. raise AssertionError(msg)
  544. else:
  545. if not desired == actual:
  546. raise AssertionError(msg)
  547. return
  548. except (TypeError, NotImplementedError):
  549. pass
  550. if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)):
  551. raise AssertionError(msg)
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
                         header='', precision=6):
    """
    Shared engine behind the ``assert_array_*`` helpers.

    Converts `x` and `y` to arrays and checks that their shapes agree (a 0-d
    operand is accepted against any shape). For numeric dtypes it also
    checks that NaN, +inf and -inf occur at the same positions in both. The
    remaining elements are passed to ``comparison(x, y)``, and every element
    of the result must be true.

    Parameters
    ----------
    comparison : callable
        Called as ``comparison(x, y)``. It should return a bool or an array
        of bools (e.g. ``operator.__eq__``).
    x, y : array_like
        Objects to compare.
    err_msg : str, optional
        Extra text included in the failure message.
    verbose : bool, optional
        If True, the values of `x` and `y` are included in the failure
        message.
    header : str, optional
        First line of the failure message.
    precision : int, optional
        Precision used when printing ndarrays in the failure message.

    Raises
    ------
    AssertionError
        On shape mismatch, on mismatched NaN/inf positions, or if any
        compared element is false.
    ValueError
        If a ValueError occurs during the comparison. The original traceback
        is embedded in the new message.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import array, isnan, isinf, any, all, inf
    x = array(x, copy=False, subok=True)
    y = array(y, copy=False, subok=True)

    def safe_comparison(*args, **kwargs):
        # There are a number of cases where comparing two arrays hits special
        # cases in array_richcompare, specifically around strings and void
        # dtypes. Basically, we just can't do comparisons involving these
        # types, unless both arrays have exactly the *same* type. So
        # e.g. you can apply == to two string arrays, or two arrays with
        # identical structured dtypes. But if you compare a non-string array
        # to a string array, or two arrays with non-identical structured
        # dtypes, or anything like that, then internally stuff blows up.
        # Currently, when things blow up, we just return a scalar False or
        # True. But we also emit a DeprecationWarning, b/c eventually we
        # should raise an error here. (Ideally we might even make this work
        # properly, but since that will require rewriting a bunch of how
        # ufuncs work then we are not counting on that.)
        #
        # The point of this little function is to let the DeprecationWarning
        # pass (or maybe eventually catch the errors and return False, I
        # dunno, that's a little trickier and we can figure that out when the
        # time comes).
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            return comparison(*args, **kwargs)

    def isnumber(x):
        # dtype codes for bool, signed/unsigned integer, float and complex.
        return x.dtype.char in '?bhilqpBHILQPefdgFDG'

    def chk_same_position(x_id, y_id, hasval='nan'):
        """Handling nan/inf: check that x and y have the nan/inf at the same
        locations."""
        try:
            assert_array_equal(x_id, y_id)
        except AssertionError:
            msg = build_err_msg([x, y],
                                err_msg + '\nx and y %s location mismatch:'
                                % (hasval), verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)

    try:
        # A 0-d operand is accepted against any shape; otherwise the shapes
        # must match exactly.
        cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
        if not cond:
            msg = build_err_msg([x, y],
                                err_msg
                                + '\n(shapes %s, %s mismatch)' % (x.shape,
                                                                 y.shape),
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            # (redundant re-check; cond is already known to be false here)
            if not cond:
                raise AssertionError(msg)

        if isnumber(x) and isnumber(y):
            x_isnan, y_isnan = isnan(x), isnan(y)
            x_isinf, y_isinf = isinf(x), isinf(y)

            # Validate that the special values are in the same place
            if any(x_isnan) or any(y_isnan):
                chk_same_position(x_isnan, y_isnan, hasval='nan')
            if any(x_isinf) or any(y_isinf):
                # Check +inf and -inf separately, since they are different
                chk_same_position(x == +inf, y == +inf, hasval='+inf')
                chk_same_position(x == -inf, y == -inf, hasval='-inf')

            # Combine all the special values
            # (the in-place |= reuses, and modifies, the isnan mask arrays)
            x_id, y_id = x_isnan, y_isnan
            x_id |= x_isinf
            y_id |= y_isinf

            # Only do the comparison if actual values are left
            if all(x_id):
                return

            # The masks were verified to agree above, so only the ordinary
            # (non-nan, non-inf) values are compared here.
            if any(x_id):
                val = safe_comparison(x[~x_id], y[~y_id])
            else:
                val = safe_comparison(x, y)
        else:
            val = safe_comparison(x, y)

        # safe_comparison may yield a plain Python bool (see its comment)
        # rather than an array.
        if isinstance(val, bool):
            cond = val
            reduced = [0]
        else:
            reduced = val.ravel()
            cond = reduced.all()
            reduced = reduced.tolist()
        if not cond:
            # Percentage of compared elements that were not true.
            match = 100-100.0*reduced.count(1)/len(reduced)
            msg = build_err_msg([x, y],
                                err_msg
                                + '\n(mismatch %s%%)' % (match,),
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            # (redundant re-check; cond is already known to be false here)
            if not cond:
                raise AssertionError(msg)
    except ValueError:
        # Re-raise as ValueError, embedding the original traceback in the
        # message header.
        import traceback
        efmt = traceback.format_exc()
        header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header)

        msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
                            names=('x', 'y'), precision=precision)
        raise ValueError(msg)
  650. def assert_array_equal(x, y, err_msg='', verbose=True):
  651. """
  652. Raises an AssertionError if two array_like objects are not equal.
  653. Given two array_like objects, check that the shape is equal and all
  654. elements of these objects are equal. An exception is raised at
  655. shape mismatch or conflicting values. In contrast to the standard usage
  656. in numpy, NaNs are compared like numbers, no assertion is raised if
  657. both objects have NaNs in the same positions.
  658. The usual caution for verifying equality with floating point numbers is
  659. advised.
  660. Parameters
  661. ----------
  662. x : array_like
  663. The actual object to check.
  664. y : array_like
  665. The desired, expected object.
  666. err_msg : str, optional
  667. The error message to be printed in case of failure.
  668. verbose : bool, optional
  669. If True, the conflicting values are appended to the error message.
  670. Raises
  671. ------
  672. AssertionError
  673. If actual and desired objects are not equal.
  674. See Also
  675. --------
  676. assert_allclose: Compare two array_like objects for equality with desired
  677. relative and/or absolute precision.
  678. assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal
  679. Examples
  680. --------
  681. The first assert does not raise an exception:
  682. >>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
  683. ... [np.exp(0),2.33333, np.nan])
  684. Assert fails with numerical inprecision with floats:
  685. >>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
  686. ... [1, np.sqrt(np.pi)**2, np.nan])
  687. ...
  688. <type 'exceptions.ValueError'>:
  689. AssertionError:
  690. Arrays are not equal
  691. <BLANKLINE>
  692. (mismatch 50.0%)
  693. x: array([ 1. , 3.14159265, NaN])
  694. y: array([ 1. , 3.14159265, NaN])
  695. Use `assert_allclose` or one of the nulp (number of floating point values)
  696. functions for these cases instead:
  697. >>> np.testing.assert_allclose([1.0,np.pi,np.nan],
  698. ... [1, np.sqrt(np.pi)**2, np.nan],
  699. ... rtol=1e-10, atol=0)
  700. """
  701. assert_array_compare(operator.__eq__, x, y, err_msg=err_msg,
  702. verbose=verbose, header='Arrays are not equal')
  703. def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
  704. """
  705. Raises an AssertionError if two objects are not equal up to desired
  706. precision.
  707. .. note:: It is recommended to use one of `assert_allclose`,
  708. `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
  709. instead of this function for more consistent floating point
  710. comparisons.
  711. The test verifies identical shapes and verifies values with
  712. ``abs(desired-actual) < 0.5 * 10**(-decimal)``.
  713. Given two array_like objects, check that the shape is equal and all
  714. elements of these objects are almost equal. An exception is raised at
  715. shape mismatch or conflicting values. In contrast to the standard usage
  716. in numpy, NaNs are compared like numbers, no assertion is raised if
  717. both objects have NaNs in the same positions.
  718. Parameters
  719. ----------
  720. x : array_like
  721. The actual object to check.
  722. y : array_like
  723. The desired, expected object.
  724. decimal : int, optional
  725. Desired precision, default is 6.
  726. err_msg : str, optional
  727. The error message to be printed in case of failure.
  728. verbose : bool, optional
  729. If True, the conflicting values are appended to the error message.
  730. Raises
  731. ------
  732. AssertionError
  733. If actual and desired are not equal up to specified precision.
  734. See Also
  735. --------
  736. assert_allclose: Compare two array_like objects for equality with desired
  737. relative and/or absolute precision.
  738. assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal
  739. Examples
  740. --------
  741. the first assert does not raise an exception
  742. >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan],
  743. [1.0,2.333,np.nan])
  744. >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
  745. ... [1.0,2.33339,np.nan], decimal=5)
  746. ...
  747. <type 'exceptions.AssertionError'>:
  748. AssertionError:
  749. Arrays are not almost equal
  750. <BLANKLINE>
  751. (mismatch 50.0%)
  752. x: array([ 1. , 2.33333, NaN])
  753. y: array([ 1. , 2.33339, NaN])
  754. >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
  755. ... [1.0,2.33333, 5], decimal=5)
  756. <type 'exceptions.ValueError'>:
  757. ValueError:
  758. Arrays are not almost equal
  759. x: array([ 1. , 2.33333, NaN])
  760. y: array([ 1. , 2.33333, 5. ])
  761. """
  762. __tracebackhide__ = True # Hide traceback for py.test
  763. from numpy.core import around, number, float_, result_type, array
  764. from numpy.core.numerictypes import issubdtype
  765. from numpy.core.fromnumeric import any as npany
  766. def compare(x, y):
  767. try:
  768. if npany(gisinf(x)) or npany( gisinf(y)):
  769. xinfid = gisinf(x)
  770. yinfid = gisinf(y)
  771. if not xinfid == yinfid:
  772. return False
  773. # if one item, x and y is +- inf
  774. if x.size == y.size == 1:
  775. return x == y
  776. x = x[~xinfid]
  777. y = y[~yinfid]
  778. except (TypeError, NotImplementedError):
  779. pass
  780. # make sure y is an inexact type to avoid abs(MIN_INT); will cause
  781. # casting of x later.
  782. dtype = result_type(y, 1.)
  783. y = array(y, dtype=dtype, copy=False, subok=True)
  784. z = abs(x-y)
  785. if not issubdtype(z.dtype, number):
  786. z = z.astype(float_) # handle object arrays
  787. return around(z, decimal) <= 10.0**(-decimal)
  788. assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
  789. header=('Arrays are not almost equal to %d decimals' % decimal),
  790. precision=decimal)
  791. def assert_array_less(x, y, err_msg='', verbose=True):
  792. """
  793. Raises an AssertionError if two array_like objects are not ordered by less
  794. than.
  795. Given two array_like objects, check that the shape is equal and all
  796. elements of the first object are strictly smaller than those of the
  797. second object. An exception is raised at shape mismatch or incorrectly
  798. ordered values. Shape mismatch does not raise if an object has zero
  799. dimension. In contrast to the standard usage in numpy, NaNs are
  800. compared, no assertion is raised if both objects have NaNs in the same
  801. positions.
  802. Parameters
  803. ----------
  804. x : array_like
  805. The smaller object to check.
  806. y : array_like
  807. The larger object to compare.
  808. err_msg : string
  809. The error message to be printed in case of failure.
  810. verbose : bool
  811. If True, the conflicting values are appended to the error message.
  812. Raises
  813. ------
  814. AssertionError
  815. If actual and desired objects are not equal.
  816. See Also
  817. --------
  818. assert_array_equal: tests objects for equality
  819. assert_array_almost_equal: test objects for equality up to precision
  820. Examples
  821. --------
  822. >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan])
  823. >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan])
  824. ...
  825. <type 'exceptions.ValueError'>:
  826. Arrays are not less-ordered
  827. (mismatch 50.0%)
  828. x: array([ 1., 1., NaN])
  829. y: array([ 1., 2., NaN])
  830. >>> np.testing.assert_array_less([1.0, 4.0], 3)
  831. ...
  832. <type 'exceptions.ValueError'>:
  833. Arrays are not less-ordered
  834. (mismatch 50.0%)
  835. x: array([ 1., 4.])
  836. y: array(3)
  837. >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4])
  838. ...
  839. <type 'exceptions.ValueError'>:
  840. Arrays are not less-ordered
  841. (shapes (3,), (1,) mismatch)
  842. x: array([ 1., 2., 3.])
  843. y: array([4])
  844. """
  845. __tracebackhide__ = True # Hide traceback for py.test
  846. assert_array_compare(operator.__lt__, x, y, err_msg=err_msg,
  847. verbose=verbose,
  848. header='Arrays are not less-ordered')
  849. def runstring(astr, dict):
  850. exec(astr, dict)
  851. def assert_string_equal(actual, desired):
  852. """
  853. Test if two strings are equal.
  854. If the given strings are equal, `assert_string_equal` does nothing.
  855. If they are not equal, an AssertionError is raised, and the diff
  856. between the strings is shown.
  857. Parameters
  858. ----------
  859. actual : str
  860. The string to test for equality against the expected string.
  861. desired : str
  862. The expected string.
  863. Examples
  864. --------
  865. >>> np.testing.assert_string_equal('abc', 'abc')
  866. >>> np.testing.assert_string_equal('abc', 'abcd')
  867. Traceback (most recent call last):
  868. File "<stdin>", line 1, in <module>
  869. ...
  870. AssertionError: Differences in strings:
  871. - abc+ abcd? +
  872. """
  873. # delay import of difflib to reduce startup time
  874. __tracebackhide__ = True # Hide traceback for py.test
  875. import difflib
  876. if not isinstance(actual, str):
  877. raise AssertionError(repr(type(actual)))
  878. if not isinstance(desired, str):
  879. raise AssertionError(repr(type(desired)))
  880. if re.match(r'\A'+desired+r'\Z', actual, re.M):
  881. return
  882. diff = list(difflib.Differ().compare(actual.splitlines(1), desired.splitlines(1)))
  883. diff_list = []
  884. while diff:
  885. d1 = diff.pop(0)
  886. if d1.startswith(' '):
  887. continue
  888. if d1.startswith('- '):
  889. l = [d1]
  890. d2 = diff.pop(0)
  891. if d2.startswith('? '):
  892. l.append(d2)
  893. d2 = diff.pop(0)
  894. if not d2.startswith('+ '):
  895. raise AssertionError(repr(d2))
  896. l.append(d2)
  897. if diff:
  898. d3 = diff.pop(0)
  899. if d3.startswith('? '):
  900. l.append(d3)
  901. else:
  902. diff.insert(0, d3)
  903. if re.match(r'\A'+d2[2:]+r'\Z', d1[2:]):
  904. continue
  905. diff_list.extend(l)
  906. continue
  907. raise AssertionError(repr(d1))
  908. if not diff_list:
  909. return
  910. msg = 'Differences in strings:\n%s' % (''.join(diff_list)).rstrip()
  911. if actual != desired:
  912. raise AssertionError(msg)
  913. def rundocs(filename=None, raise_on_error=True):
  914. """
  915. Run doctests found in the given file.
  916. By default `rundocs` raises an AssertionError on failure.
  917. Parameters
  918. ----------
  919. filename : str
  920. The path to the file for which the doctests are run.
  921. raise_on_error : bool
  922. Whether to raise an AssertionError when a doctest fails. Default is
  923. True.
  924. Notes
  925. -----
  926. The doctests can be run by the user/developer by adding the ``doctests``
  927. argument to the ``test()`` call. For example, to run all tests (including
  928. doctests) for `numpy.lib`:
  929. >>> np.lib.test(doctests=True) #doctest: +SKIP
  930. """
  931. import doctest
  932. import imp
  933. if filename is None:
  934. f = sys._getframe(1)
  935. filename = f.f_globals['__file__']
  936. name = os.path.splitext(os.path.basename(filename))[0]
  937. path = [os.path.dirname(filename)]
  938. file, pathname, description = imp.find_module(name, path)
  939. try:
  940. m = imp.load_module(name, file, pathname, description)
  941. finally:
  942. file.close()
  943. tests = doctest.DocTestFinder().find(m)
  944. runner = doctest.DocTestRunner(verbose=False)
  945. msg = []
  946. if raise_on_error:
  947. out = lambda s: msg.append(s)
  948. else:
  949. out = None
  950. for test in tests:
  951. runner.run(test, out=out)
  952. if runner.failures > 0 and raise_on_error:
  953. raise AssertionError("Some doctests failed:\n%s" % "\n".join(msg))
  954. def raises(*args,**kwargs):
  955. nose = import_nose()
  956. return nose.tools.raises(*args,**kwargs)
  957. def assert_raises(*args,**kwargs):
  958. """
  959. assert_raises(exception_class, callable, *args, **kwargs)
  960. Fail unless an exception of class exception_class is thrown
  961. by callable when invoked with arguments args and keyword
  962. arguments kwargs. If a different type of exception is
  963. thrown, it will not be caught, and the test case will be
  964. deemed to have suffered an error, exactly as for an
  965. unexpected exception.
  966. Alternatively, `assert_raises` can be used as a context manager:
  967. >>> from numpy.testing import assert_raises
  968. >>> with assert_raises(ZeroDivisionError):
  969. ... 1 / 0
  970. is equivalent to
  971. >>> def div(x, y):
  972. ... return x / y
  973. >>> assert_raises(ZeroDivisionError, div, 1, 0)
  974. """
  975. __tracebackhide__ = True # Hide traceback for py.test
  976. nose = import_nose()
  977. return nose.tools.assert_raises(*args,**kwargs)
# Module-level cache for the implementation chosen by `assert_raises_regex`;
# it stays None until that function's first call fills it in.
assert_raises_regex_impl = None
def assert_raises_regex(exception_class, expected_regexp,
                        callable_obj=None, *args, **kwargs):
    """
    Fail unless an exception of class exception_class and with message that
    matches expected_regexp is thrown by callable when invoked with arguments
    args and keyword arguments kwargs.

    The implementation is looked up once and cached in the module-level
    ``assert_raises_regex_impl``. It is ``nose.tools.assert_raises_regex`` if
    available, else ``nose.tools.assert_raises_regexp``, else a local fallback.
    When `callable_obj` is None, the fallback returns its context manager
    instead of calling anything.

    Name of this function adheres to Python 3.2+ reference, but should work in
    all versions down to 2.6.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    nose = import_nose()

    global assert_raises_regex_impl
    if assert_raises_regex_impl is None:
        try:
            # Python 3.2+
            assert_raises_regex_impl = nose.tools.assert_raises_regex
        except AttributeError:
            try:
                # 2.7+
                assert_raises_regex_impl = nose.tools.assert_raises_regexp
            except AttributeError:
                # 2.6

                # This class is copied from Python2.7 stdlib almost verbatim
                class _AssertRaisesContext(object):
                    """A context manager used to implement TestCase.assertRaises* methods."""

                    def __init__(self, expected, expected_regexp=None):
                        self.expected = expected
                        self.expected_regexp = expected_regexp

                    def failureException(self, msg):
                        return AssertionError(msg)

                    def __enter__(self):
                        return self

                    def __exit__(self, exc_type, exc_value, tb):
                        if exc_type is None:
                            # Nothing was raised inside the with-block.
                            try:
                                exc_name = self.expected.__name__
                            except AttributeError:
                                exc_name = str(self.expected)
                            raise self.failureException(
                                "{0} not raised".format(exc_name))
                        if not issubclass(exc_type, self.expected):
                            # let unexpected exceptions pass through
                            return False
                        self.exception = exc_value # store for later retrieval
                        if self.expected_regexp is None:
                            return True

                        expected_regexp = self.expected_regexp
                        # NOTE(review): `basestring` exists only on Python 2;
                        # this fallback is only reached when nose.tools lacks
                        # both names above (the 2.6 case per the comments).
                        if isinstance(expected_regexp, basestring):
                            expected_regexp = re.compile(expected_regexp)
                        # search(), not match(): the pattern may occur
                        # anywhere in the exception message.
                        if not expected_regexp.search(str(exc_value)):
                            raise self.failureException(
                                '"%s" does not match "%s"' %
                                (expected_regexp.pattern, str(exc_value)))
                        # Returning True suppresses the expected exception.
                        return True

                def impl(cls, regex, callable_obj, *a, **kw):
                    mgr = _AssertRaisesContext(cls, regex)
                    if callable_obj is None:
                        # Context-manager usage: the caller runs the code.
                        return mgr
                    with mgr:
                        callable_obj(*a, **kw)
                assert_raises_regex_impl = impl

    return assert_raises_regex_impl(exception_class, expected_regexp,
                                    callable_obj, *args, **kwargs)
  1042. def decorate_methods(cls, decorator, testmatch=None):
  1043. """
  1044. Apply a decorator to all methods in a class matching a regular expression.
  1045. The given decorator is applied to all public methods of `cls` that are
  1046. matched by the regular expression `testmatch`
  1047. (``testmatch.search(methodname)``). Methods that are private, i.e. start
  1048. with an underscore, are ignored.
  1049. Parameters
  1050. ----------
  1051. cls : class
  1052. Class whose methods to decorate.
  1053. decorator : function
  1054. Decorator to apply to methods
  1055. testmatch : compiled regexp or str, optional
  1056. The regular expression. Default value is None, in which case the
  1057. nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``)
  1058. is used.
  1059. If `testmatch` is a string, it is compiled to a regular expression
  1060. first.
  1061. """
  1062. if testmatch is None:
  1063. testmatch = re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)
  1064. else:
  1065. testmatch = re.compile(testmatch)
  1066. cls_attr = cls.__dict__
  1067. # delayed import to reduce startup time
  1068. from inspect import isfunction
  1069. methods = [_m for _m in cls_attr.values() if isfunction(_m)]
  1070. for function in methods:
  1071. try:
  1072. if hasattr(function, 'compat_func_name'):
  1073. funcname = function.compat_func_name
  1074. else:
  1075. funcname = function.__name__
  1076. except AttributeError:
  1077. # not a function
  1078. continue
  1079. if testmatch.search(funcname) and not funcname.startswith('_'):
  1080. setattr(cls, funcname, decorator(function))
  1081. return
  1082. def measure(code_str,times=1,label=None):
  1083. """
  1084. Return elapsed time for executing code in the namespace of the caller.
  1085. The supplied code string is compiled with the Python builtin ``compile``.
  1086. The precision of the timing is 10 milli-seconds. If the code will execute
  1087. fast on this timescale, it can be executed many times to get reasonable
  1088. timing accuracy.
  1089. Parameters
  1090. ----------
  1091. code_str : str
  1092. The code to be timed.
  1093. times : int, optional
  1094. The number of times the code is executed. Default is 1. The code is
  1095. only compiled once.
  1096. label : str, optional
  1097. A label to identify `code_str` with. This is passed into ``compile``
  1098. as the second argument (for run-time error messages).
  1099. Returns
  1100. -------
  1101. elapsed : float
  1102. Total elapsed time in seconds for executing `code_str` `times` times.
  1103. Examples
  1104. --------
  1105. >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)',
  1106. ... times=times)
  1107. >>> print("Time for a single execution : ", etime / times, "s")
  1108. Time for a single execution : 0.005 s
  1109. """
  1110. frame = sys._getframe(1)
  1111. locs, globs = frame.f_locals, frame.f_globals
  1112. code = compile(code_str,
  1113. 'Test name: %s ' % label,
  1114. 'exec')
  1115. i = 0
  1116. elapsed = jiffies()
  1117. while i < times:
  1118. i += 1
  1119. exec(code, globs, locs)
  1120. elapsed = jiffies() - elapsed
  1121. return 0.01*elapsed
  1122. def _assert_valid_refcount(op):
  1123. """
  1124. Check that ufuncs don't mishandle refcount of object `1`.
  1125. Used in a few regression tests.
  1126. """
  1127. import numpy as np
  1128. b = np.arange(100*100).reshape(100, 100)
  1129. c = b
  1130. i = 1
  1131. rc = sys.getrefcount(i)
  1132. for j in range(15):
  1133. d = op(b, c)
  1134. assert_(sys.getrefcount(i) >= rc)
  1135. del d # for pyflakes
  1136. def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=False,
  1137. err_msg='', verbose=True):
  1138. """
  1139. Raises an AssertionError if two objects are not equal up to desired
  1140. tolerance.
  1141. The test is equivalent to ``allclose(actual, desired, rtol, atol)``.
  1142. It compares the difference between `actual` and `desired` to
  1143. ``atol + rtol * abs(desired)``.
  1144. .. versionadded:: 1.5.0
  1145. Parameters
  1146. ----------
  1147. actual : array_like
  1148. Array obtained.
  1149. desired : array_like
  1150. Array desired.
  1151. rtol : float, optional
  1152. Relative tolerance.
  1153. atol : float, optional
  1154. Absolute tolerance.
  1155. equal_nan : bool, optional.
  1156. If True, NaNs will compare equal.
  1157. err_msg : str, optional
  1158. The error message to be printed in case of failure.
  1159. verbose : bool, optional
  1160. If True, the conflicting values are appended to the error message.
  1161. Raises
  1162. ------
  1163. AssertionError
  1164. If actual and desired are not equal up to specified precision.
  1165. See Also
  1166. --------
  1167. assert_array_almost_equal_nulp, assert_array_max_ulp
  1168. Examples
  1169. --------
  1170. >>> x = [1e-5, 1e-3, 1e-1]
  1171. >>> y = np.arccos(np.cos(x))
  1172. >>> assert_allclose(x, y, rtol=1e-5, atol=0)
  1173. """
  1174. __tracebackhide__ = True # Hide traceback for py.test
  1175. import numpy as np
  1176. def compare(x, y):
  1177. return np.core.numeric.isclose(x, y, rtol=rtol, atol=atol,
  1178. equal_nan=equal_nan)
  1179. actual, desired = np.asanyarray(actual), np.asanyarray(desired)
  1180. header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
  1181. assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  1182. verbose=verbose, header=header)
  1183. def assert_array_almost_equal_nulp(x, y, nulp=1):
  1184. """
  1185. Compare two arrays relatively to their spacing.
  1186. This is a relatively robust method to compare two arrays whose amplitude
  1187. is variable.
  1188. Parameters
  1189. ----------
  1190. x, y : array_like
  1191. Input arrays.
  1192. nulp : int, optional
  1193. The maximum number of unit in the last place for tolerance (see Notes).
  1194. Default is 1.
  1195. Returns
  1196. -------
  1197. None
  1198. Raises
  1199. ------
  1200. AssertionError
  1201. If the spacing between `x` and `y` for one or more elements is larger
  1202. than `nulp`.
  1203. See Also
  1204. --------
  1205. assert_array_max_ulp : Check that all items of arrays differ in at most
  1206. N Units in the Last Place.
  1207. spacing : Return the distance between x and the nearest adjacent number.
  1208. Notes
  1209. -----
  1210. An assertion is raised if the following condition is not met::
  1211. abs(x - y) <= nulps * spacing(maximum(abs(x), abs(y)))
  1212. Examples
  1213. --------
  1214. >>> x = np.array([1., 1e-10, 1e-20])
  1215. >>> eps = np.finfo(x.dtype).eps
  1216. >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x)
  1217. >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x)
  1218. Traceback (most recent call last):
  1219. ...
  1220. AssertionError: X and Y are not equal to 1 ULP (max is 2)
  1221. """
  1222. __tracebackhide__ = True # Hide traceback for py.test
  1223. import numpy as np
  1224. ax = np.abs(x)
  1225. ay = np.abs(y)
  1226. ref = nulp * np.spacing(np.where(ax > ay, ax, ay))
  1227. if not np.all(np.abs(x-y) <= ref):
  1228. if np.iscomplexobj(x) or np.iscomplexobj(y):
  1229. msg = "X and Y are not equal to %d ULP" % nulp
  1230. else:
  1231. max_nulp = np.max(nulp_diff(x, y))
  1232. msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp)
  1233. raise AssertionError(msg)
  1234. def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
  1235. """
  1236. Check that all items of arrays differ in at most N Units in the Last Place.
  1237. Parameters
  1238. ----------
  1239. a, b : array_like
  1240. Input arrays to be compared.
  1241. maxulp : int, optional
  1242. The maximum number of units in the last place that elements of `a` and
  1243. `b` can differ. Default is 1.
  1244. dtype : dtype, optional
  1245. Data-type to convert `a` and `b` to if given. Default is None.
  1246. Returns
  1247. -------
  1248. ret : ndarray
  1249. Array containing number of representable floating point numbers between
  1250. items in `a` and `b`.
  1251. Raises
  1252. ------
  1253. AssertionError
  1254. If one or more elements differ by more than `maxulp`.
  1255. See Also
  1256. --------
  1257. assert_array_almost_equal_nulp : Compare two arrays relatively to their
  1258. spacing.
  1259. Examples
  1260. --------
  1261. >>> a = np.linspace(0., 1., 100)
  1262. >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a)))
  1263. """
  1264. __tracebackhide__ = True # Hide traceback for py.test
  1265. import numpy as np
  1266. ret = nulp_diff(a, b, dtype)
  1267. if not np.all(ret <= maxulp):
  1268. raise AssertionError("Arrays are not almost equal up to %g ULP" %
  1269. maxulp)
  1270. return ret
  1271. def nulp_diff(x, y, dtype=None):
  1272. """For each item in x and y, return the number of representable floating
  1273. points between them.
  1274. Parameters
  1275. ----------
  1276. x : array_like
  1277. first input array
  1278. y : array_like
  1279. second input array
  1280. dtype : dtype, optional
  1281. Data-type to convert `x` and `y` to if given. Default is None.
  1282. Returns
  1283. -------
  1284. nulp : array_like
  1285. number of representable floating point numbers between each item in x
  1286. and y.
  1287. Examples
  1288. --------
  1289. # By definition, epsilon is the smallest number such as 1 + eps != 1, so
  1290. # there should be exactly one ULP between 1 and 1 + eps
  1291. >>> nulp_diff(1, 1 + np.finfo(x.dtype).eps)
  1292. 1.0
  1293. """
  1294. import numpy as np
  1295. if dtype:
  1296. x = np.array(x, dtype=dtype)
  1297. y = np.array(y, dtype=dtype)
  1298. else:
  1299. x = np.array(x)
  1300. y = np.array(y)
  1301. t = np.common_type(x, y)
  1302. if np.iscomplexobj(x) or np.iscomplexobj(y):
  1303. raise NotImplementedError("_nulp not implemented for complex array")
  1304. x = np.array(x, dtype=t)
  1305. y = np.array(y, dtype=t)
  1306. if not x.shape == y.shape:
  1307. raise ValueError("x and y do not have the same shape: %s - %s" %
  1308. (x.shape, y.shape))
  1309. def _diff(rx, ry, vdt):
  1310. diff = np.array(rx-ry, dtype=vdt)
  1311. return np.abs(diff)
  1312. rx = integer_repr(x)
  1313. ry = integer_repr(y)
  1314. return _diff(rx, ry, t)
  1315. def _integer_repr(x, vdt, comp):
  1316. # Reinterpret binary representation of the float as sign-magnitude:
  1317. # take into account two-complement representation
  1318. # See also
  1319. # http://www.cygnus-software.com/papers/comparingfloats/comparingfloats.htm
  1320. rx = x.view(vdt)
  1321. if not (rx.size == 1):
  1322. rx[rx < 0] = comp - rx[rx < 0]
  1323. else:
  1324. if rx < 0:
  1325. rx = comp - rx
  1326. return rx
  1327. def integer_repr(x):
  1328. """Return the signed-magnitude interpretation of the binary representation of
  1329. x."""
  1330. import numpy as np
  1331. if x.dtype == np.float32:
  1332. return _integer_repr(x, np.int32, np.int32(-2**31))
  1333. elif x.dtype == np.float64:
  1334. return _integer_repr(x, np.int64, np.int64(-2**63))
  1335. else:
  1336. raise ValueError("Unsupported dtype %s" % x.dtype)
  1337. # The following two classes are copied from python 2.6 warnings module (context
  1338. # manager)
  1339. class WarningMessage(object):
  1340. """
  1341. Holds the result of a single showwarning() call.
  1342. Deprecated in 1.8.0
  1343. Notes
  1344. -----
  1345. `WarningMessage` is copied from the Python 2.6 warnings module,
  1346. so it can be used in NumPy with older Python versions.
  1347. """
  1348. _WARNING_DETAILS = ("message", "category", "filename", "lineno", "file",
  1349. "line")
  1350. def __init__(self, message, category, filename, lineno, file=None,
  1351. line=None):
  1352. local_values = locals()
  1353. for attr in self._WARNING_DETAILS:
  1354. setattr(self, attr, local_values[attr])
  1355. if category:
  1356. self._category_name = category.__name__
  1357. else:
  1358. self._category_name = None
  1359. def __str__(self):
  1360. return ("{message : %r, category : %r, filename : %r, lineno : %s, "
  1361. "line : %r}" % (self.message, self._category_name,
  1362. self.filename, self.lineno, self.line))
  1363. class WarningManager(object):
  1364. """
  1365. A context manager that copies and restores the warnings filter upon
  1366. exiting the context.
  1367. The 'record' argument specifies whether warnings should be captured by a
  1368. custom implementation of ``warnings.showwarning()`` and be appended to a
  1369. list returned by the context manager. Otherwise None is returned by the
  1370. context manager. The objects appended to the list are arguments whose
  1371. attributes mirror the arguments to ``showwarning()``.
  1372. The 'module' argument is to specify an alternative module to the module
  1373. named 'warnings' and imported under that name. This argument is only useful
  1374. when testing the warnings module itself.
  1375. Deprecated in 1.8.0
  1376. Notes
  1377. -----
  1378. `WarningManager` is a copy of the ``catch_warnings`` context manager
  1379. from the Python 2.6 warnings module, with slight modifications.
  1380. It is copied so it can be used in NumPy with older Python versions.
  1381. """
  1382. def __init__(self, record=False, module=None):
  1383. self._record = record
  1384. if module is None:
  1385. self._module = sys.modules['warnings']
  1386. else:
  1387. self._module = module
  1388. self._entered = False
  1389. def __enter__(self):
  1390. if self._entered:
  1391. raise RuntimeError("Cannot enter %r twice" % self)
  1392. self._entered = True
  1393. self._filters = self._module.filters
  1394. self._module.filters = self._filters[:]
  1395. self._showwarning = self._module.showwarning
  1396. if self._record:
  1397. log = []
  1398. def showwarning(*args, **kwargs):
  1399. log.append(WarningMessage(*args, **kwargs))
  1400. self._module.showwarning = showwarning
  1401. return log
  1402. else:
  1403. return None
  1404. def __exit__(self):
  1405. if not self._entered:
  1406. raise RuntimeError("Cannot exit %r without entering first" % self)
  1407. self._module.filters = self._filters
  1408. self._module.showwarning = self._showwarning
  1409. @contextlib.contextmanager
  1410. def _assert_warns_context(warning_class, name=None):
  1411. __tracebackhide__ = True # Hide traceback for py.test
  1412. with warnings.catch_warnings(record=True) as l:
  1413. warnings.simplefilter('always')
  1414. yield
  1415. if not len(l) > 0:
  1416. name_str = " when calling %s" % name if name is not None else ""
  1417. raise AssertionError("No warning raised" + name_str)
  1418. if not l[0].category is warning_class:
  1419. name_str = "%s " % name if name is not None else ""
  1420. raise AssertionError("First warning %sis not a %s (is %s)"
  1421. % (name_str, warning_class, l[0]))
  1422. def assert_warns(warning_class, *args, **kwargs):
  1423. """
  1424. Fail unless the given callable throws the specified warning.
  1425. A warning of class warning_class should be thrown by the callable when
  1426. invoked with arguments args and keyword arguments kwargs.
  1427. If a different type of warning is thrown, it will not be caught, and the
  1428. test case will be deemed to have suffered an error.
  1429. If called with all arguments other than the warning class omitted, may be
  1430. used as a context manager:
  1431. with assert_warns(SomeWarning):
  1432. do_something()
  1433. The ability to be used as a context manager is new in NumPy v1.11.0.
  1434. .. versionadded:: 1.4.0
  1435. Parameters
  1436. ----------
  1437. warning_class : class
  1438. The class defining the warning that `func` is expected to throw.
  1439. func : callable
  1440. The callable to test.
  1441. \\*args : Arguments
  1442. Arguments passed to `func`.
  1443. \\*\\*kwargs : Kwargs
  1444. Keyword arguments passed to `func`.
  1445. Returns
  1446. -------
  1447. The value returned by `func`.
  1448. """
  1449. if not args:
  1450. return _assert_warns_context(warning_class)
  1451. func = args[0]
  1452. args = args[1:]
  1453. with _assert_warns_context(warning_class, name=func.__name__):
  1454. return func(*args, **kwargs)
  1455. @contextlib.contextmanager
  1456. def _assert_no_warnings_context(name=None):
  1457. __tracebackhide__ = True # Hide traceback for py.test
  1458. with warnings.catch_warnings(record=True) as l:
  1459. warnings.simplefilter('always')
  1460. yield
  1461. if len(l) > 0:
  1462. name_str = " when calling %s" % name if name is not None else ""
  1463. raise AssertionError("Got warnings%s: %s" % (name_str, l))
  1464. def assert_no_warnings(*args, **kwargs):
  1465. """
  1466. Fail if the given callable produces any warnings.
  1467. If called with all arguments omitted, may be used as a context manager:
  1468. with assert_no_warnings():
  1469. do_something()
  1470. The ability to be used as a context manager is new in NumPy v1.11.0.
  1471. .. versionadded:: 1.7.0
  1472. Parameters
  1473. ----------
  1474. func : callable
  1475. The callable to test.
  1476. \\*args : Arguments
  1477. Arguments passed to `func`.
  1478. \\*\\*kwargs : Kwargs
  1479. Keyword arguments passed to `func`.
  1480. Returns
  1481. -------
  1482. The value returned by `func`.
  1483. """
  1484. if not args:
  1485. return _assert_no_warnings_context()
  1486. func = args[0]
  1487. args = args[1:]
  1488. with _assert_no_warnings_context(name=func.__name__):
  1489. return func(*args, **kwargs)
  1490. def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
  1491. """
  1492. generator producing data with different alignment and offsets
  1493. to test simd vectorization
  1494. Parameters
  1495. ----------
  1496. dtype : dtype
  1497. data type to produce
  1498. type : string
  1499. 'unary': create data for unary operations, creates one input
  1500. and output array
  1501. 'binary': create data for unary operations, creates two input
  1502. and output array
  1503. max_size : integer
  1504. maximum size of data to produce
  1505. Returns
  1506. -------
  1507. if type is 'unary' yields one output, one input array and a message
  1508. containing information on the data
  1509. if type is 'binary' yields one output array, two input array and a message
  1510. containing information on the data
  1511. """
  1512. ufmt = 'unary offset=(%d, %d), size=%d, dtype=%r, %s'
  1513. bfmt = 'binary offset=(%d, %d, %d), size=%d, dtype=%r, %s'
  1514. for o in range(3):
  1515. for s in range(o + 2, max(o + 3, max_size)):
  1516. if type == 'unary':
  1517. inp = lambda: arange(s, dtype=dtype)[o:]
  1518. out = empty((s,), dtype=dtype)[o:]
  1519. yield out, inp(), ufmt % (o, o, s, dtype, 'out of place')
  1520. yield inp(), inp(), ufmt % (o, o, s, dtype, 'in place')
  1521. yield out[1:], inp()[:-1], ufmt % \
  1522. (o + 1, o, s - 1, dtype, 'out of place')
  1523. yield out[:-1], inp()[1:], ufmt % \
  1524. (o, o + 1, s - 1, dtype, 'out of place')
  1525. yield inp()[:-1], inp()[1:], ufmt % \
  1526. (o, o + 1, s - 1, dtype, 'aliased')
  1527. yield inp()[1:], inp()[:-1], ufmt % \
  1528. (o + 1, o, s - 1, dtype, 'aliased')
  1529. if type == 'binary':
  1530. inp1 = lambda: arange(s, dtype=dtype)[o:]
  1531. inp2 = lambda: arange(s, dtype=dtype)[o:]
  1532. out = empty((s,), dtype=dtype)[o:]
  1533. yield out, inp1(), inp2(), bfmt % \
  1534. (o, o, o, s, dtype, 'out of place')
  1535. yield inp1(), inp1(), inp2(), bfmt % \
  1536. (o, o, o, s, dtype, 'in place1')
  1537. yield inp2(), inp1(), inp2(), bfmt % \
  1538. (o, o, o, s, dtype, 'in place2')
  1539. yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % \
  1540. (o + 1, o, o, s - 1, dtype, 'out of place')
  1541. yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % \
  1542. (o, o + 1, o, s - 1, dtype, 'out of place')
  1543. yield out[:-1], inp1()[:-1], inp2()[1:], bfmt % \
  1544. (o, o, o + 1, s - 1, dtype, 'out of place')
  1545. yield inp1()[1:], inp1()[:-1], inp2()[:-1], bfmt % \
  1546. (o + 1, o, o, s - 1, dtype, 'aliased')
  1547. yield inp1()[:-1], inp1()[1:], inp2()[:-1], bfmt % \
  1548. (o, o + 1, o, s - 1, dtype, 'aliased')
  1549. yield inp1()[:-1], inp1()[:-1], inp2()[1:], bfmt % \
  1550. (o, o, o + 1, s - 1, dtype, 'aliased')
  1551. class IgnoreException(Exception):
  1552. "Ignoring this exception due to disabled feature"
  1553. @contextlib.contextmanager
  1554. def tempdir(*args, **kwargs):
  1555. """Context manager to provide a temporary test folder.
  1556. All arguments are passed as this to the underlying tempfile.mkdtemp
  1557. function.
  1558. """
  1559. tmpdir = mkdtemp(*args, **kwargs)
  1560. try:
  1561. yield tmpdir
  1562. finally:
  1563. shutil.rmtree(tmpdir)
  1564. @contextlib.contextmanager
  1565. def temppath(*args, **kwargs):
  1566. """Context manager for temporary files.
  1567. Context manager that returns the path to a closed temporary file. Its
  1568. parameters are the same as for tempfile.mkstemp and are passed directly
  1569. to that function. The underlying file is removed when the context is
  1570. exited, so it should be closed at that time.
  1571. Windows does not allow a temporary file to be opened if it is already
  1572. open, so the underlying file must be closed after opening before it
  1573. can be opened again.
  1574. """
  1575. fd, path = mkstemp(*args, **kwargs)
  1576. os.close(fd)
  1577. try:
  1578. yield path
  1579. finally:
  1580. os.remove(path)
  1581. class clear_and_catch_warnings(warnings.catch_warnings):
  1582. """ Context manager that resets warning registry for catching warnings
  1583. Warnings can be slippery, because, whenever a warning is triggered, Python
  1584. adds a ``__warningregistry__`` member to the *calling* module. This makes
  1585. it impossible to retrigger the warning in this module, whatever you put in
  1586. the warnings filters. This context manager accepts a sequence of `modules`
  1587. as a keyword argument to its constructor and:
  1588. * stores and removes any ``__warningregistry__`` entries in given `modules`
  1589. on entry;
  1590. * resets ``__warningregistry__`` to its previous state on exit.
  1591. This makes it possible to trigger any warning afresh inside the context
  1592. manager without disturbing the state of warnings outside.
  1593. For compatibility with Python 3.0, please consider all arguments to be
  1594. keyword-only.
  1595. Parameters
  1596. ----------
  1597. record : bool, optional
  1598. Specifies whether warnings should be captured by a custom
  1599. implementation of ``warnings.showwarning()`` and be appended to a list
  1600. returned by the context manager. Otherwise None is returned by the
  1601. context manager. The objects appended to the list are arguments whose
  1602. attributes mirror the arguments to ``showwarning()``.
  1603. modules : sequence, optional
  1604. Sequence of modules for which to reset warnings registry on entry and
  1605. restore on exit
  1606. Examples
  1607. --------
  1608. >>> import warnings
  1609. >>> with clear_and_catch_warnings(modules=[np.core.fromnumeric]):
  1610. ... warnings.simplefilter('always')
  1611. ... # do something that raises a warning in np.core.fromnumeric
  1612. """
  1613. class_modules = ()
  1614. def __init__(self, record=False, modules=()):
  1615. self.modules = set(modules).union(self.class_modules)
  1616. self._warnreg_copies = {}
  1617. super(clear_and_catch_warnings, self).__init__(record=record)
  1618. def __enter__(self):
  1619. for mod in self.modules:
  1620. if hasattr(mod, '__warningregistry__'):
  1621. mod_reg = mod.__warningregistry__
  1622. self._warnreg_copies[mod] = mod_reg.copy()
  1623. mod_reg.clear()
  1624. return super(clear_and_catch_warnings, self).__enter__()
  1625. def __exit__(self, *exc_info):
  1626. super(clear_and_catch_warnings, self).__exit__(*exc_info)
  1627. for mod in self.modules:
  1628. if hasattr(mod, '__warningregistry__'):
  1629. mod.__warningregistry__.clear()
  1630. if mod in self._warnreg_copies:
  1631. mod.__warningregistry__.update(self._warnreg_copies[mod])