test_functools.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. import copy
  2. import functools
  3. import sys
  4. import unittest
  5. from test import test_support
  6. from weakref import proxy
  7. import pickle
  8. @staticmethod
  9. def PythonPartial(func, *args, **keywords):
  10. 'Pure Python approximation of partial()'
  11. def newfunc(*fargs, **fkeywords):
  12. newkeywords = keywords.copy()
  13. newkeywords.update(fkeywords)
  14. return func(*(args + fargs), **newkeywords)
  15. newfunc.func = func
  16. newfunc.args = args
  17. newfunc.keywords = keywords
  18. return newfunc
  19. def capture(*args, **kw):
  20. """capture all positional and keyword arguments"""
  21. return args, kw
  22. def signature(part):
  23. """ return the signature of a partial object """
  24. return (part.func, part.args, part.keywords, part.__dict__)
  25. class MyTuple(tuple):
  26. pass
  27. class BadTuple(tuple):
  28. def __add__(self, other):
  29. return list(self) + list(other)
  30. class MyDict(dict):
  31. pass
  32. class TestPartial(unittest.TestCase):
  33. thetype = functools.partial
  34. def test_basic_examples(self):
  35. p = self.thetype(capture, 1, 2, a=10, b=20)
  36. self.assertEqual(p(3, 4, b=30, c=40),
  37. ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
  38. p = self.thetype(map, lambda x: x*10)
  39. self.assertEqual(p([1,2,3,4]), [10, 20, 30, 40])
  40. def test_attributes(self):
  41. p = self.thetype(capture, 1, 2, a=10, b=20)
  42. # attributes should be readable
  43. self.assertEqual(p.func, capture)
  44. self.assertEqual(p.args, (1, 2))
  45. self.assertEqual(p.keywords, dict(a=10, b=20))
  46. # attributes should not be writable
  47. self.assertRaises(TypeError, setattr, p, 'func', map)
  48. self.assertRaises(TypeError, setattr, p, 'args', (1, 2))
  49. self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2))
  50. p = self.thetype(hex)
  51. try:
  52. del p.__dict__
  53. except TypeError:
  54. pass
  55. else:
  56. self.fail('partial object allowed __dict__ to be deleted')
  57. def test_argument_checking(self):
  58. self.assertRaises(TypeError, self.thetype) # need at least a func arg
  59. try:
  60. self.thetype(2)()
  61. except TypeError:
  62. pass
  63. else:
  64. self.fail('First arg not checked for callability')
  65. def test_protection_of_callers_dict_argument(self):
  66. # a caller's dictionary should not be altered by partial
  67. def func(a=10, b=20):
  68. return a
  69. d = {'a':3}
  70. p = self.thetype(func, a=5)
  71. self.assertEqual(p(**d), 3)
  72. self.assertEqual(d, {'a':3})
  73. p(b=7)
  74. self.assertEqual(d, {'a':3})
  75. def test_arg_combinations(self):
  76. # exercise special code paths for zero args in either partial
  77. # object or the caller
  78. p = self.thetype(capture)
  79. self.assertEqual(p(), ((), {}))
  80. self.assertEqual(p(1,2), ((1,2), {}))
  81. p = self.thetype(capture, 1, 2)
  82. self.assertEqual(p(), ((1,2), {}))
  83. self.assertEqual(p(3,4), ((1,2,3,4), {}))
  84. def test_kw_combinations(self):
  85. # exercise special code paths for no keyword args in
  86. # either the partial object or the caller
  87. p = self.thetype(capture)
  88. self.assertEqual(p.keywords, {})
  89. self.assertEqual(p(), ((), {}))
  90. self.assertEqual(p(a=1), ((), {'a':1}))
  91. p = self.thetype(capture, a=1)
  92. self.assertEqual(p.keywords, {'a':1})
  93. self.assertEqual(p(), ((), {'a':1}))
  94. self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
  95. # keyword args in the call override those in the partial object
  96. self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))
  97. def test_positional(self):
  98. # make sure positional arguments are captured correctly
  99. for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
  100. p = self.thetype(capture, *args)
  101. expected = args + ('x',)
  102. got, empty = p('x')
  103. self.assertTrue(expected == got and empty == {})
  104. def test_keyword(self):
  105. # make sure keyword arguments are captured correctly
  106. for a in ['a', 0, None, 3.5]:
  107. p = self.thetype(capture, a=a)
  108. expected = {'a':a,'x':None}
  109. empty, got = p(x=None)
  110. self.assertTrue(expected == got and empty == ())
  111. def test_no_side_effects(self):
  112. # make sure there are no side effects that affect subsequent calls
  113. p = self.thetype(capture, 0, a=1)
  114. args1, kw1 = p(1, b=2)
  115. self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2})
  116. args2, kw2 = p()
  117. self.assertTrue(args2 == (0,) and kw2 == {'a':1})
  118. def test_error_propagation(self):
  119. def f(x, y):
  120. x // y
  121. self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
  122. self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
  123. self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
  124. self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)
  125. def test_weakref(self):
  126. f = self.thetype(int, base=16)
  127. p = proxy(f)
  128. self.assertEqual(f.func, p.func)
  129. f = None
  130. self.assertRaises(ReferenceError, getattr, p, 'func')
  131. def test_with_bound_and_unbound_methods(self):
  132. data = map(str, range(10))
  133. join = self.thetype(str.join, '')
  134. self.assertEqual(join(data), '0123456789')
  135. join = self.thetype(''.join)
  136. self.assertEqual(join(data), '0123456789')
  137. def test_pickle(self):
  138. f = self.thetype(signature, ['asdf'], bar=[True])
  139. f.attr = []
  140. for proto in range(pickle.HIGHEST_PROTOCOL + 1):
  141. f_copy = pickle.loads(pickle.dumps(f, proto))
  142. self.assertEqual(signature(f_copy), signature(f))
  143. def test_copy(self):
  144. f = self.thetype(signature, ['asdf'], bar=[True])
  145. f.attr = []
  146. f_copy = copy.copy(f)
  147. self.assertEqual(signature(f_copy), signature(f))
  148. self.assertIs(f_copy.attr, f.attr)
  149. self.assertIs(f_copy.args, f.args)
  150. self.assertIs(f_copy.keywords, f.keywords)
  151. def test_deepcopy(self):
  152. f = self.thetype(signature, ['asdf'], bar=[True])
  153. f.attr = []
  154. f_copy = copy.deepcopy(f)
  155. self.assertEqual(signature(f_copy), signature(f))
  156. self.assertIsNot(f_copy.attr, f.attr)
  157. self.assertIsNot(f_copy.args, f.args)
  158. self.assertIsNot(f_copy.args[0], f.args[0])
  159. self.assertIsNot(f_copy.keywords, f.keywords)
  160. self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])
  161. def test_setstate(self):
  162. f = self.thetype(signature)
  163. f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
  164. self.assertEqual(signature(f),
  165. (capture, (1,), dict(a=10), dict(attr=[])))
  166. self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
  167. f.__setstate__((capture, (1,), dict(a=10), None))
  168. self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
  169. self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
  170. f.__setstate__((capture, (1,), None, None))
  171. #self.assertEqual(signature(f), (capture, (1,), {}, {}))
  172. self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
  173. self.assertEqual(f(2), ((1, 2), {}))
  174. self.assertEqual(f(), ((1,), {}))
  175. f.__setstate__((capture, (), {}, None))
  176. self.assertEqual(signature(f), (capture, (), {}, {}))
  177. self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
  178. self.assertEqual(f(2), ((2,), {}))
  179. self.assertEqual(f(), ((), {}))
  180. def test_setstate_errors(self):
  181. f = self.thetype(signature)
  182. self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
  183. self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
  184. self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
  185. self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
  186. self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
  187. self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
  188. self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))
  189. def test_setstate_subclasses(self):
  190. f = self.thetype(signature)
  191. f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
  192. s = signature(f)
  193. self.assertEqual(s, (capture, (1,), dict(a=10), {}))
  194. self.assertIs(type(s[1]), tuple)
  195. self.assertIs(type(s[2]), dict)
  196. r = f()
  197. self.assertEqual(r, ((1,), {'a': 10}))
  198. self.assertIs(type(r[0]), tuple)
  199. self.assertIs(type(r[1]), dict)
  200. f.__setstate__((capture, BadTuple((1,)), {}, None))
  201. s = signature(f)
  202. self.assertEqual(s, (capture, (1,), {}, {}))
  203. self.assertIs(type(s[1]), tuple)
  204. r = f(2)
  205. self.assertEqual(r, ((1, 2), {}))
  206. self.assertIs(type(r[0]), tuple)
  207. # Issue 6083: Reference counting bug
  208. def test_setstate_refcount(self):
  209. class BadSequence:
  210. def __len__(self):
  211. return 4
  212. def __getitem__(self, key):
  213. if key == 0:
  214. return max
  215. elif key == 1:
  216. return tuple(range(1000000))
  217. elif key in (2, 3):
  218. return {}
  219. raise IndexError
  220. f = self.thetype(object)
  221. self.assertRaises(TypeError, f.__setstate__, BadSequence())
  222. class PartialSubclass(functools.partial):
  223. pass
  224. class TestPartialSubclass(TestPartial):
  225. thetype = PartialSubclass
  226. class TestPythonPartial(TestPartial):
  227. thetype = PythonPartial
  228. # the python version isn't picklable
  229. test_pickle = None
  230. test_setstate = None
  231. test_setstate_errors = None
  232. test_setstate_subclasses = None
  233. test_setstate_refcount = None
  234. # the python version isn't deepcopyable
  235. test_deepcopy = None
  236. # the python version isn't a type
  237. test_attributes = None
  238. class TestUpdateWrapper(unittest.TestCase):
  239. def check_wrapper(self, wrapper, wrapped,
  240. assigned=functools.WRAPPER_ASSIGNMENTS,
  241. updated=functools.WRAPPER_UPDATES):
  242. # Check attributes were assigned
  243. for name in assigned:
  244. self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name))
  245. # Check attributes were updated
  246. for name in updated:
  247. wrapper_attr = getattr(wrapper, name)
  248. wrapped_attr = getattr(wrapped, name)
  249. for key in wrapped_attr:
  250. self.assertTrue(wrapped_attr[key] is wrapper_attr[key])
  251. def _default_update(self):
  252. def f():
  253. """This is a test"""
  254. pass
  255. f.attr = 'This is also a test'
  256. def wrapper():
  257. pass
  258. functools.update_wrapper(wrapper, f)
  259. return wrapper, f
  260. def test_default_update(self):
  261. wrapper, f = self._default_update()
  262. self.check_wrapper(wrapper, f)
  263. self.assertEqual(wrapper.__name__, 'f')
  264. self.assertEqual(wrapper.attr, 'This is also a test')
  265. @unittest.skipIf(sys.flags.optimize >= 2,
  266. "Docstrings are omitted with -O2 and above")
  267. def test_default_update_doc(self):
  268. wrapper, f = self._default_update()
  269. self.assertEqual(wrapper.__doc__, 'This is a test')
  270. def test_no_update(self):
  271. def f():
  272. """This is a test"""
  273. pass
  274. f.attr = 'This is also a test'
  275. def wrapper():
  276. pass
  277. functools.update_wrapper(wrapper, f, (), ())
  278. self.check_wrapper(wrapper, f, (), ())
  279. self.assertEqual(wrapper.__name__, 'wrapper')
  280. self.assertEqual(wrapper.__doc__, None)
  281. self.assertFalse(hasattr(wrapper, 'attr'))
  282. def test_selective_update(self):
  283. def f():
  284. pass
  285. f.attr = 'This is a different test'
  286. f.dict_attr = dict(a=1, b=2, c=3)
  287. def wrapper():
  288. pass
  289. wrapper.dict_attr = {}
  290. assign = ('attr',)
  291. update = ('dict_attr',)
  292. functools.update_wrapper(wrapper, f, assign, update)
  293. self.check_wrapper(wrapper, f, assign, update)
  294. self.assertEqual(wrapper.__name__, 'wrapper')
  295. self.assertEqual(wrapper.__doc__, None)
  296. self.assertEqual(wrapper.attr, 'This is a different test')
  297. self.assertEqual(wrapper.dict_attr, f.dict_attr)
  298. @test_support.requires_docstrings
  299. def test_builtin_update(self):
  300. # Test for bug #1576241
  301. def wrapper():
  302. pass
  303. functools.update_wrapper(wrapper, max)
  304. self.assertEqual(wrapper.__name__, 'max')
  305. self.assertTrue(wrapper.__doc__.startswith('max('))
  306. class TestWraps(TestUpdateWrapper):
  307. def _default_update(self):
  308. def f():
  309. """This is a test"""
  310. pass
  311. f.attr = 'This is also a test'
  312. @functools.wraps(f)
  313. def wrapper():
  314. pass
  315. self.check_wrapper(wrapper, f)
  316. return wrapper
  317. def test_default_update(self):
  318. wrapper = self._default_update()
  319. self.assertEqual(wrapper.__name__, 'f')
  320. self.assertEqual(wrapper.attr, 'This is also a test')
  321. @unittest.skipIf(sys.flags.optimize >= 2,
  322. "Docstrings are omitted with -O2 and above")
  323. def test_default_update_doc(self):
  324. wrapper = self._default_update()
  325. self.assertEqual(wrapper.__doc__, 'This is a test')
  326. def test_no_update(self):
  327. def f():
  328. """This is a test"""
  329. pass
  330. f.attr = 'This is also a test'
  331. @functools.wraps(f, (), ())
  332. def wrapper():
  333. pass
  334. self.check_wrapper(wrapper, f, (), ())
  335. self.assertEqual(wrapper.__name__, 'wrapper')
  336. self.assertEqual(wrapper.__doc__, None)
  337. self.assertFalse(hasattr(wrapper, 'attr'))
  338. def test_selective_update(self):
  339. def f():
  340. pass
  341. f.attr = 'This is a different test'
  342. f.dict_attr = dict(a=1, b=2, c=3)
  343. def add_dict_attr(f):
  344. f.dict_attr = {}
  345. return f
  346. assign = ('attr',)
  347. update = ('dict_attr',)
  348. @functools.wraps(f, assign, update)
  349. @add_dict_attr
  350. def wrapper():
  351. pass
  352. self.check_wrapper(wrapper, f, assign, update)
  353. self.assertEqual(wrapper.__name__, 'wrapper')
  354. self.assertEqual(wrapper.__doc__, None)
  355. self.assertEqual(wrapper.attr, 'This is a different test')
  356. self.assertEqual(wrapper.dict_attr, f.dict_attr)
  357. class TestReduce(unittest.TestCase):
  358. def test_reduce(self):
  359. class Squares:
  360. def __init__(self, max):
  361. self.max = max
  362. self.sofar = []
  363. def __len__(self): return len(self.sofar)
  364. def __getitem__(self, i):
  365. if not 0 <= i < self.max: raise IndexError
  366. n = len(self.sofar)
  367. while n <= i:
  368. self.sofar.append(n*n)
  369. n += 1
  370. return self.sofar[i]
  371. reduce = functools.reduce
  372. self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
  373. self.assertEqual(
  374. reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
  375. ['a','c','d','w']
  376. )
  377. self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040)
  378. self.assertEqual(
  379. reduce(lambda x, y: x*y, range(2,21), 1L),
  380. 2432902008176640000L
  381. )
  382. self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285)
  383. self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285)
  384. self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0)
  385. self.assertRaises(TypeError, reduce)
  386. self.assertRaises(TypeError, reduce, 42, 42)
  387. self.assertRaises(TypeError, reduce, 42, 42, 42)
  388. self.assertEqual(reduce(42, "1"), "1") # func is never called with one item
  389. self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item
  390. self.assertRaises(TypeError, reduce, 42, (42, 42))
  391. class TestCmpToKey(unittest.TestCase):
  392. def test_cmp_to_key(self):
  393. def mycmp(x, y):
  394. return y - x
  395. self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
  396. [4, 3, 2, 1, 0])
  397. def test_hash(self):
  398. def mycmp(x, y):
  399. return y - x
  400. key = functools.cmp_to_key(mycmp)
  401. k = key(10)
  402. self.assertRaises(TypeError, hash(k))
  403. class TestTotalOrdering(unittest.TestCase):
  404. def test_total_ordering_lt(self):
  405. @functools.total_ordering
  406. class A:
  407. def __init__(self, value):
  408. self.value = value
  409. def __lt__(self, other):
  410. return self.value < other.value
  411. def __eq__(self, other):
  412. return self.value == other.value
  413. self.assertTrue(A(1) < A(2))
  414. self.assertTrue(A(2) > A(1))
  415. self.assertTrue(A(1) <= A(2))
  416. self.assertTrue(A(2) >= A(1))
  417. self.assertTrue(A(2) <= A(2))
  418. self.assertTrue(A(2) >= A(2))
  419. def test_total_ordering_le(self):
  420. @functools.total_ordering
  421. class A:
  422. def __init__(self, value):
  423. self.value = value
  424. def __le__(self, other):
  425. return self.value <= other.value
  426. def __eq__(self, other):
  427. return self.value == other.value
  428. self.assertTrue(A(1) < A(2))
  429. self.assertTrue(A(2) > A(1))
  430. self.assertTrue(A(1) <= A(2))
  431. self.assertTrue(A(2) >= A(1))
  432. self.assertTrue(A(2) <= A(2))
  433. self.assertTrue(A(2) >= A(2))
  434. def test_total_ordering_gt(self):
  435. @functools.total_ordering
  436. class A:
  437. def __init__(self, value):
  438. self.value = value
  439. def __gt__(self, other):
  440. return self.value > other.value
  441. def __eq__(self, other):
  442. return self.value == other.value
  443. self.assertTrue(A(1) < A(2))
  444. self.assertTrue(A(2) > A(1))
  445. self.assertTrue(A(1) <= A(2))
  446. self.assertTrue(A(2) >= A(1))
  447. self.assertTrue(A(2) <= A(2))
  448. self.assertTrue(A(2) >= A(2))
  449. def test_total_ordering_ge(self):
  450. @functools.total_ordering
  451. class A:
  452. def __init__(self, value):
  453. self.value = value
  454. def __ge__(self, other):
  455. return self.value >= other.value
  456. def __eq__(self, other):
  457. return self.value == other.value
  458. self.assertTrue(A(1) < A(2))
  459. self.assertTrue(A(2) > A(1))
  460. self.assertTrue(A(1) <= A(2))
  461. self.assertTrue(A(2) >= A(1))
  462. self.assertTrue(A(2) <= A(2))
  463. self.assertTrue(A(2) >= A(2))
  464. def test_total_ordering_no_overwrite(self):
  465. # new methods should not overwrite existing
  466. @functools.total_ordering
  467. class A(str):
  468. pass
  469. self.assertTrue(A("a") < A("b"))
  470. self.assertTrue(A("b") > A("a"))
  471. self.assertTrue(A("a") <= A("b"))
  472. self.assertTrue(A("b") >= A("a"))
  473. self.assertTrue(A("b") <= A("b"))
  474. self.assertTrue(A("b") >= A("b"))
  475. def test_no_operations_defined(self):
  476. with self.assertRaises(ValueError):
  477. @functools.total_ordering
  478. class A:
  479. pass
  480. def test_bug_10042(self):
  481. @functools.total_ordering
  482. class TestTO:
  483. def __init__(self, value):
  484. self.value = value
  485. def __eq__(self, other):
  486. if isinstance(other, TestTO):
  487. return self.value == other.value
  488. return False
  489. def __lt__(self, other):
  490. if isinstance(other, TestTO):
  491. return self.value < other.value
  492. raise TypeError
  493. with self.assertRaises(TypeError):
  494. TestTO(8) <= ()
  495. def test_main(verbose=None):
  496. test_classes = (
  497. TestPartial,
  498. TestPartialSubclass,
  499. TestPythonPartial,
  500. TestUpdateWrapper,
  501. TestTotalOrdering,
  502. TestWraps,
  503. TestReduce,
  504. )
  505. test_support.run_unittest(*test_classes)
  506. # verify reference counting
  507. if verbose and hasattr(sys, "gettotalrefcount"):
  508. import gc
  509. counts = [None] * 5
  510. for i in xrange(len(counts)):
  511. test_support.run_unittest(*test_classes)
  512. gc.collect()
  513. counts[i] = sys.gettotalrefcount()
  514. print counts
if __name__ == '__main__':
    # Run verbosely when executed as a script; verbose also enables the
    # reference-count loop in test_main when sys.gettotalrefcount exists.
    test_main(verbose=True)