test_contextlib.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. """Unit tests for contextlib.py, and other context managers."""
  2. import sys
  3. import tempfile
  4. import unittest
  5. from contextlib import * # Tests __all__
  6. from test import test_support
  7. try:
  8. import threading
  9. except ImportError:
  10. threading = None
  11. class ContextManagerTestCase(unittest.TestCase):
  12. def test_contextmanager_plain(self):
  13. state = []
  14. @contextmanager
  15. def woohoo():
  16. state.append(1)
  17. yield 42
  18. state.append(999)
  19. with woohoo() as x:
  20. self.assertEqual(state, [1])
  21. self.assertEqual(x, 42)
  22. state.append(x)
  23. self.assertEqual(state, [1, 42, 999])
  24. def test_contextmanager_finally(self):
  25. state = []
  26. @contextmanager
  27. def woohoo():
  28. state.append(1)
  29. try:
  30. yield 42
  31. finally:
  32. state.append(999)
  33. with self.assertRaises(ZeroDivisionError):
  34. with woohoo() as x:
  35. self.assertEqual(state, [1])
  36. self.assertEqual(x, 42)
  37. state.append(x)
  38. raise ZeroDivisionError()
  39. self.assertEqual(state, [1, 42, 999])
  40. def test_contextmanager_no_reraise(self):
  41. @contextmanager
  42. def whee():
  43. yield
  44. ctx = whee()
  45. ctx.__enter__()
  46. # Calling __exit__ should not result in an exception
  47. self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
  48. def test_contextmanager_trap_yield_after_throw(self):
  49. @contextmanager
  50. def whoo():
  51. try:
  52. yield
  53. except:
  54. yield
  55. ctx = whoo()
  56. ctx.__enter__()
  57. self.assertRaises(
  58. RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
  59. )
  60. def test_contextmanager_except(self):
  61. state = []
  62. @contextmanager
  63. def woohoo():
  64. state.append(1)
  65. try:
  66. yield 42
  67. except ZeroDivisionError, e:
  68. state.append(e.args[0])
  69. self.assertEqual(state, [1, 42, 999])
  70. with woohoo() as x:
  71. self.assertEqual(state, [1])
  72. self.assertEqual(x, 42)
  73. state.append(x)
  74. raise ZeroDivisionError(999)
  75. self.assertEqual(state, [1, 42, 999])
  76. def _create_contextmanager_attribs(self):
  77. def attribs(**kw):
  78. def decorate(func):
  79. for k,v in kw.items():
  80. setattr(func,k,v)
  81. return func
  82. return decorate
  83. @contextmanager
  84. @attribs(foo='bar')
  85. def baz(spam):
  86. """Whee!"""
  87. return baz
  88. def test_contextmanager_attribs(self):
  89. baz = self._create_contextmanager_attribs()
  90. self.assertEqual(baz.__name__,'baz')
  91. self.assertEqual(baz.foo, 'bar')
  92. @unittest.skipIf(sys.flags.optimize >= 2,
  93. "Docstrings are omitted with -O2 and above")
  94. def test_contextmanager_doc_attrib(self):
  95. baz = self._create_contextmanager_attribs()
  96. self.assertEqual(baz.__doc__, "Whee!")
  97. def test_keywords(self):
  98. # Ensure no keyword arguments are inhibited
  99. @contextmanager
  100. def woohoo(self, func, args, kwds):
  101. yield (self, func, args, kwds)
  102. with woohoo(self=11, func=22, args=33, kwds=44) as target:
  103. self.assertEqual(target, (11, 22, 33, 44))
  104. class NestedTestCase(unittest.TestCase):
  105. # XXX This needs more work
  106. def test_nested(self):
  107. @contextmanager
  108. def a():
  109. yield 1
  110. @contextmanager
  111. def b():
  112. yield 2
  113. @contextmanager
  114. def c():
  115. yield 3
  116. with nested(a(), b(), c()) as (x, y, z):
  117. self.assertEqual(x, 1)
  118. self.assertEqual(y, 2)
  119. self.assertEqual(z, 3)
  120. def test_nested_cleanup(self):
  121. state = []
  122. @contextmanager
  123. def a():
  124. state.append(1)
  125. try:
  126. yield 2
  127. finally:
  128. state.append(3)
  129. @contextmanager
  130. def b():
  131. state.append(4)
  132. try:
  133. yield 5
  134. finally:
  135. state.append(6)
  136. with self.assertRaises(ZeroDivisionError):
  137. with nested(a(), b()) as (x, y):
  138. state.append(x)
  139. state.append(y)
  140. 1 // 0
  141. self.assertEqual(state, [1, 4, 2, 5, 6, 3])
  142. def test_nested_right_exception(self):
  143. @contextmanager
  144. def a():
  145. yield 1
  146. class b(object):
  147. def __enter__(self):
  148. return 2
  149. def __exit__(self, *exc_info):
  150. try:
  151. raise Exception()
  152. except:
  153. pass
  154. with self.assertRaises(ZeroDivisionError):
  155. with nested(a(), b()) as (x, y):
  156. 1 // 0
  157. self.assertEqual((x, y), (1, 2))
  158. def test_nested_b_swallows(self):
  159. @contextmanager
  160. def a():
  161. yield
  162. @contextmanager
  163. def b():
  164. try:
  165. yield
  166. except:
  167. # Swallow the exception
  168. pass
  169. try:
  170. with nested(a(), b()):
  171. 1 // 0
  172. except ZeroDivisionError:
  173. self.fail("Didn't swallow ZeroDivisionError")
  174. def test_nested_break(self):
  175. @contextmanager
  176. def a():
  177. yield
  178. state = 0
  179. while True:
  180. state += 1
  181. with nested(a(), a()):
  182. break
  183. state += 10
  184. self.assertEqual(state, 1)
  185. def test_nested_continue(self):
  186. @contextmanager
  187. def a():
  188. yield
  189. state = 0
  190. while state < 3:
  191. state += 1
  192. with nested(a(), a()):
  193. continue
  194. state += 10
  195. self.assertEqual(state, 3)
  196. def test_nested_return(self):
  197. @contextmanager
  198. def a():
  199. try:
  200. yield
  201. except:
  202. pass
  203. def foo():
  204. with nested(a(), a()):
  205. return 1
  206. return 10
  207. self.assertEqual(foo(), 1)
  208. class ClosingTestCase(unittest.TestCase):
  209. # XXX This needs more work
  210. def test_closing(self):
  211. state = []
  212. class C:
  213. def close(self):
  214. state.append(1)
  215. x = C()
  216. self.assertEqual(state, [])
  217. with closing(x) as y:
  218. self.assertEqual(x, y)
  219. self.assertEqual(state, [1])
  220. def test_closing_error(self):
  221. state = []
  222. class C:
  223. def close(self):
  224. state.append(1)
  225. x = C()
  226. self.assertEqual(state, [])
  227. with self.assertRaises(ZeroDivisionError):
  228. with closing(x) as y:
  229. self.assertEqual(x, y)
  230. 1 // 0
  231. self.assertEqual(state, [1])
  232. class FileContextTestCase(unittest.TestCase):
  233. def testWithOpen(self):
  234. tfn = tempfile.mktemp()
  235. try:
  236. f = None
  237. with open(tfn, "w") as f:
  238. self.assertFalse(f.closed)
  239. f.write("Booh\n")
  240. self.assertTrue(f.closed)
  241. f = None
  242. with self.assertRaises(ZeroDivisionError):
  243. with open(tfn, "r") as f:
  244. self.assertFalse(f.closed)
  245. self.assertEqual(f.read(), "Booh\n")
  246. 1 // 0
  247. self.assertTrue(f.closed)
  248. finally:
  249. test_support.unlink(tfn)
  250. @unittest.skipUnless(threading, 'Threading required for this test.')
  251. class LockContextTestCase(unittest.TestCase):
  252. def boilerPlate(self, lock, locked):
  253. self.assertFalse(locked())
  254. with lock:
  255. self.assertTrue(locked())
  256. self.assertFalse(locked())
  257. with self.assertRaises(ZeroDivisionError):
  258. with lock:
  259. self.assertTrue(locked())
  260. 1 // 0
  261. self.assertFalse(locked())
  262. def testWithLock(self):
  263. lock = threading.Lock()
  264. self.boilerPlate(lock, lock.locked)
  265. def testWithRLock(self):
  266. lock = threading.RLock()
  267. self.boilerPlate(lock, lock._is_owned)
  268. def testWithCondition(self):
  269. lock = threading.Condition()
  270. def locked():
  271. return lock._is_owned()
  272. self.boilerPlate(lock, locked)
  273. def testWithSemaphore(self):
  274. lock = threading.Semaphore()
  275. def locked():
  276. if lock.acquire(False):
  277. lock.release()
  278. return False
  279. else:
  280. return True
  281. self.boilerPlate(lock, locked)
  282. def testWithBoundedSemaphore(self):
  283. lock = threading.BoundedSemaphore()
  284. def locked():
  285. if lock.acquire(False):
  286. lock.release()
  287. return False
  288. else:
  289. return True
  290. self.boilerPlate(lock, locked)
  291. # This is needed to make the test actually run under regrtest.py!
  292. def test_main():
  293. with test_support.check_warnings(("With-statements now directly support "
  294. "multiple context managers",
  295. DeprecationWarning)):
  296. test_support.run_unittest(__name__)
  297. if __name__ == "__main__":
  298. test_main()