testwith.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. import unittest
  2. from warnings import catch_warnings
  3. from unittest.test.testmock.support import is_instance
  4. from unittest.mock import MagicMock, Mock, patch, sentinel, mock_open, call
  5. something = sentinel.Something
  6. something_else = sentinel.SomethingElse
  7. class WithTest(unittest.TestCase):
  8. def test_with_statement(self):
  9. with patch('%s.something' % __name__, sentinel.Something2):
  10. self.assertEqual(something, sentinel.Something2, "unpatched")
  11. self.assertEqual(something, sentinel.Something)
  12. def test_with_statement_exception(self):
  13. try:
  14. with patch('%s.something' % __name__, sentinel.Something2):
  15. self.assertEqual(something, sentinel.Something2, "unpatched")
  16. raise Exception('pow')
  17. except Exception:
  18. pass
  19. else:
  20. self.fail("patch swallowed exception")
  21. self.assertEqual(something, sentinel.Something)
  22. def test_with_statement_as(self):
  23. with patch('%s.something' % __name__) as mock_something:
  24. self.assertEqual(something, mock_something, "unpatched")
  25. self.assertTrue(is_instance(mock_something, MagicMock),
  26. "patching wrong type")
  27. self.assertEqual(something, sentinel.Something)
  28. def test_patch_object_with_statement(self):
  29. class Foo(object):
  30. something = 'foo'
  31. original = Foo.something
  32. with patch.object(Foo, 'something'):
  33. self.assertNotEqual(Foo.something, original, "unpatched")
  34. self.assertEqual(Foo.something, original)
  35. def test_with_statement_nested(self):
  36. with catch_warnings(record=True):
  37. with patch('%s.something' % __name__) as mock_something, patch('%s.something_else' % __name__) as mock_something_else:
  38. self.assertEqual(something, mock_something, "unpatched")
  39. self.assertEqual(something_else, mock_something_else,
  40. "unpatched")
  41. self.assertEqual(something, sentinel.Something)
  42. self.assertEqual(something_else, sentinel.SomethingElse)
  43. def test_with_statement_specified(self):
  44. with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
  45. self.assertEqual(something, mock_something, "unpatched")
  46. self.assertEqual(mock_something, sentinel.Patched, "wrong patch")
  47. self.assertEqual(something, sentinel.Something)
  48. def testContextManagerMocking(self):
  49. mock = Mock()
  50. mock.__enter__ = Mock()
  51. mock.__exit__ = Mock()
  52. mock.__exit__.return_value = False
  53. with mock as m:
  54. self.assertEqual(m, mock.__enter__.return_value)
  55. mock.__enter__.assert_called_with()
  56. mock.__exit__.assert_called_with(None, None, None)
  57. def test_context_manager_with_magic_mock(self):
  58. mock = MagicMock()
  59. with self.assertRaises(TypeError):
  60. with mock:
  61. 'foo' + 3
  62. mock.__enter__.assert_called_with()
  63. self.assertTrue(mock.__exit__.called)
  64. def test_with_statement_same_attribute(self):
  65. with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
  66. self.assertEqual(something, mock_something, "unpatched")
  67. with patch('%s.something' % __name__) as mock_again:
  68. self.assertEqual(something, mock_again, "unpatched")
  69. self.assertEqual(something, mock_something,
  70. "restored with wrong instance")
  71. self.assertEqual(something, sentinel.Something, "not restored")
  72. def test_with_statement_imbricated(self):
  73. with patch('%s.something' % __name__) as mock_something:
  74. self.assertEqual(something, mock_something, "unpatched")
  75. with patch('%s.something_else' % __name__) as mock_something_else:
  76. self.assertEqual(something_else, mock_something_else,
  77. "unpatched")
  78. self.assertEqual(something, sentinel.Something)
  79. self.assertEqual(something_else, sentinel.SomethingElse)
  80. def test_dict_context_manager(self):
  81. foo = {}
  82. with patch.dict(foo, {'a': 'b'}):
  83. self.assertEqual(foo, {'a': 'b'})
  84. self.assertEqual(foo, {})
  85. with self.assertRaises(NameError):
  86. with patch.dict(foo, {'a': 'b'}):
  87. self.assertEqual(foo, {'a': 'b'})
  88. raise NameError('Konrad')
  89. self.assertEqual(foo, {})
  90. class TestMockOpen(unittest.TestCase):
  91. def test_mock_open(self):
  92. mock = mock_open()
  93. with patch('%s.open' % __name__, mock, create=True) as patched:
  94. self.assertIs(patched, mock)
  95. open('foo')
  96. mock.assert_called_once_with('foo')
  97. def test_mock_open_context_manager(self):
  98. mock = mock_open()
  99. handle = mock.return_value
  100. with patch('%s.open' % __name__, mock, create=True):
  101. with open('foo') as f:
  102. f.read()
  103. expected_calls = [call('foo'), call().__enter__(), call().read(),
  104. call().__exit__(None, None, None)]
  105. self.assertEqual(mock.mock_calls, expected_calls)
  106. self.assertIs(f, handle)
  107. def test_mock_open_context_manager_multiple_times(self):
  108. mock = mock_open()
  109. with patch('%s.open' % __name__, mock, create=True):
  110. with open('foo') as f:
  111. f.read()
  112. with open('bar') as f:
  113. f.read()
  114. expected_calls = [
  115. call('foo'), call().__enter__(), call().read(),
  116. call().__exit__(None, None, None),
  117. call('bar'), call().__enter__(), call().read(),
  118. call().__exit__(None, None, None)]
  119. self.assertEqual(mock.mock_calls, expected_calls)
  120. def test_explicit_mock(self):
  121. mock = MagicMock()
  122. mock_open(mock)
  123. with patch('%s.open' % __name__, mock, create=True) as patched:
  124. self.assertIs(patched, mock)
  125. open('foo')
  126. mock.assert_called_once_with('foo')
  127. def test_read_data(self):
  128. mock = mock_open(read_data='foo')
  129. with patch('%s.open' % __name__, mock, create=True):
  130. h = open('bar')
  131. result = h.read()
  132. self.assertEqual(result, 'foo')
  133. def test_readline_data(self):
  134. # Check that readline will return all the lines from the fake file
  135. mock = mock_open(read_data='foo\nbar\nbaz\n')
  136. with patch('%s.open' % __name__, mock, create=True):
  137. h = open('bar')
  138. line1 = h.readline()
  139. line2 = h.readline()
  140. line3 = h.readline()
  141. self.assertEqual(line1, 'foo\n')
  142. self.assertEqual(line2, 'bar\n')
  143. self.assertEqual(line3, 'baz\n')
  144. # Check that we properly emulate a file that doesn't end in a newline
  145. mock = mock_open(read_data='foo')
  146. with patch('%s.open' % __name__, mock, create=True):
  147. h = open('bar')
  148. result = h.readline()
  149. self.assertEqual(result, 'foo')
  150. def test_readlines_data(self):
  151. # Test that emulating a file that ends in a newline character works
  152. mock = mock_open(read_data='foo\nbar\nbaz\n')
  153. with patch('%s.open' % __name__, mock, create=True):
  154. h = open('bar')
  155. result = h.readlines()
  156. self.assertEqual(result, ['foo\n', 'bar\n', 'baz\n'])
  157. # Test that files without a final newline will also be correctly
  158. # emulated
  159. mock = mock_open(read_data='foo\nbar\nbaz')
  160. with patch('%s.open' % __name__, mock, create=True):
  161. h = open('bar')
  162. result = h.readlines()
  163. self.assertEqual(result, ['foo\n', 'bar\n', 'baz'])
  164. def test_read_bytes(self):
  165. mock = mock_open(read_data=b'\xc6')
  166. with patch('%s.open' % __name__, mock, create=True):
  167. with open('abc', 'rb') as f:
  168. result = f.read()
  169. self.assertEqual(result, b'\xc6')
  170. def test_readline_bytes(self):
  171. m = mock_open(read_data=b'abc\ndef\nghi\n')
  172. with patch('%s.open' % __name__, m, create=True):
  173. with open('abc', 'rb') as f:
  174. line1 = f.readline()
  175. line2 = f.readline()
  176. line3 = f.readline()
  177. self.assertEqual(line1, b'abc\n')
  178. self.assertEqual(line2, b'def\n')
  179. self.assertEqual(line3, b'ghi\n')
  180. def test_readlines_bytes(self):
  181. m = mock_open(read_data=b'abc\ndef\nghi\n')
  182. with patch('%s.open' % __name__, m, create=True):
  183. with open('abc', 'rb') as f:
  184. result = f.readlines()
  185. self.assertEqual(result, [b'abc\n', b'def\n', b'ghi\n'])
  186. def test_mock_open_read_with_argument(self):
  187. # At one point calling read with an argument was broken
  188. # for mocks returned by mock_open
  189. some_data = 'foo\nbar\nbaz'
  190. mock = mock_open(read_data=some_data)
  191. self.assertEqual(mock().read(10), some_data)
  192. def test_interleaved_reads(self):
  193. # Test that calling read, readline, and readlines pulls data
  194. # sequentially from the data we preload with
  195. mock = mock_open(read_data='foo\nbar\nbaz\n')
  196. with patch('%s.open' % __name__, mock, create=True):
  197. h = open('bar')
  198. line1 = h.readline()
  199. rest = h.readlines()
  200. self.assertEqual(line1, 'foo\n')
  201. self.assertEqual(rest, ['bar\n', 'baz\n'])
  202. mock = mock_open(read_data='foo\nbar\nbaz\n')
  203. with patch('%s.open' % __name__, mock, create=True):
  204. h = open('bar')
  205. line1 = h.readline()
  206. rest = h.read()
  207. self.assertEqual(line1, 'foo\n')
  208. self.assertEqual(rest, 'bar\nbaz\n')
  209. def test_overriding_return_values(self):
  210. mock = mock_open(read_data='foo')
  211. handle = mock()
  212. handle.read.return_value = 'bar'
  213. handle.readline.return_value = 'bar'
  214. handle.readlines.return_value = ['bar']
  215. self.assertEqual(handle.read(), 'bar')
  216. self.assertEqual(handle.readline(), 'bar')
  217. self.assertEqual(handle.readlines(), ['bar'])
  218. # call repeatedly to check that a StopIteration is not propagated
  219. self.assertEqual(handle.readline(), 'bar')
  220. self.assertEqual(handle.readline(), 'bar')
  221. if __name__ == '__main__':
  222. unittest.main()