test_heapq.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. """Unittests for heapq."""
  2. import sys
  3. import random
  4. from test import test_support
  5. from unittest import TestCase, skipUnless
  6. py_heapq = test_support.import_fresh_module('heapq', blocked=['_heapq'])
  7. c_heapq = test_support.import_fresh_module('heapq', fresh=['_heapq'])
  8. # _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when
  9. # _heapq is imported, so check them there
  10. func_names = ['heapify', 'heappop', 'heappush', 'heappushpop',
  11. 'heapreplace', '_nlargest', '_nsmallest']
  12. class TestModules(TestCase):
  13. def test_py_functions(self):
  14. for fname in func_names:
  15. self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq')
  16. @skipUnless(c_heapq, 'requires _heapq')
  17. def test_c_functions(self):
  18. for fname in func_names:
  19. self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
  20. class TestHeap(TestCase):
  21. module = None
  22. def test_push_pop(self):
  23. # 1) Push 256 random numbers and pop them off, verifying all's OK.
  24. heap = []
  25. data = []
  26. self.check_invariant(heap)
  27. for i in range(256):
  28. item = random.random()
  29. data.append(item)
  30. self.module.heappush(heap, item)
  31. self.check_invariant(heap)
  32. results = []
  33. while heap:
  34. item = self.module.heappop(heap)
  35. self.check_invariant(heap)
  36. results.append(item)
  37. data_sorted = data[:]
  38. data_sorted.sort()
  39. self.assertEqual(data_sorted, results)
  40. # 2) Check that the invariant holds for a sorted array
  41. self.check_invariant(results)
  42. self.assertRaises(TypeError, self.module.heappush, [])
  43. try:
  44. self.assertRaises(TypeError, self.module.heappush, None, None)
  45. self.assertRaises(TypeError, self.module.heappop, None)
  46. except AttributeError:
  47. pass
  48. def check_invariant(self, heap):
  49. # Check the heap invariant.
  50. for pos, item in enumerate(heap):
  51. if pos: # pos 0 has no parent
  52. parentpos = (pos-1) >> 1
  53. self.assertTrue(heap[parentpos] <= item)
  54. def test_heapify(self):
  55. for size in range(30):
  56. heap = [random.random() for dummy in range(size)]
  57. self.module.heapify(heap)
  58. self.check_invariant(heap)
  59. self.assertRaises(TypeError, self.module.heapify, None)
  60. def test_naive_nbest(self):
  61. data = [random.randrange(2000) for i in range(1000)]
  62. heap = []
  63. for item in data:
  64. self.module.heappush(heap, item)
  65. if len(heap) > 10:
  66. self.module.heappop(heap)
  67. heap.sort()
  68. self.assertEqual(heap, sorted(data)[-10:])
  69. def heapiter(self, heap):
  70. # An iterator returning a heap's elements, smallest-first.
  71. try:
  72. while 1:
  73. yield self.module.heappop(heap)
  74. except IndexError:
  75. pass
  76. def test_nbest(self):
  77. # Less-naive "N-best" algorithm, much faster (if len(data) is big
  78. # enough <wink>) than sorting all of data. However, if we had a max
  79. # heap instead of a min heap, it could go faster still via
  80. # heapify'ing all of data (linear time), then doing 10 heappops
  81. # (10 log-time steps).
  82. data = [random.randrange(2000) for i in range(1000)]
  83. heap = data[:10]
  84. self.module.heapify(heap)
  85. for item in data[10:]:
  86. if item > heap[0]: # this gets rarer the longer we run
  87. self.module.heapreplace(heap, item)
  88. self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
  89. self.assertRaises(TypeError, self.module.heapreplace, None)
  90. self.assertRaises(TypeError, self.module.heapreplace, None, None)
  91. self.assertRaises(IndexError, self.module.heapreplace, [], None)
  92. def test_nbest_with_pushpop(self):
  93. data = [random.randrange(2000) for i in range(1000)]
  94. heap = data[:10]
  95. self.module.heapify(heap)
  96. for item in data[10:]:
  97. self.module.heappushpop(heap, item)
  98. self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
  99. self.assertEqual(self.module.heappushpop([], 'x'), 'x')
  100. def test_heappushpop(self):
  101. h = []
  102. x = self.module.heappushpop(h, 10)
  103. self.assertEqual((h, x), ([], 10))
  104. h = [10]
  105. x = self.module.heappushpop(h, 10.0)
  106. self.assertEqual((h, x), ([10], 10.0))
  107. self.assertEqual(type(h[0]), int)
  108. self.assertEqual(type(x), float)
  109. h = [10];
  110. x = self.module.heappushpop(h, 9)
  111. self.assertEqual((h, x), ([10], 9))
  112. h = [10];
  113. x = self.module.heappushpop(h, 11)
  114. self.assertEqual((h, x), ([11], 10))
  115. def test_heapsort(self):
  116. # Exercise everything with repeated heapsort checks
  117. for trial in xrange(100):
  118. size = random.randrange(50)
  119. data = [random.randrange(25) for i in range(size)]
  120. if trial & 1: # Half of the time, use heapify
  121. heap = data[:]
  122. self.module.heapify(heap)
  123. else: # The rest of the time, use heappush
  124. heap = []
  125. for item in data:
  126. self.module.heappush(heap, item)
  127. heap_sorted = [self.module.heappop(heap) for i in range(size)]
  128. self.assertEqual(heap_sorted, sorted(data))
  129. def test_merge(self):
  130. inputs = []
  131. for i in xrange(random.randrange(5)):
  132. row = sorted(random.randrange(1000) for j in range(random.randrange(10)))
  133. inputs.append(row)
  134. self.assertEqual(sorted(chain(*inputs)), list(self.module.merge(*inputs)))
  135. self.assertEqual(list(self.module.merge()), [])
  136. def test_merge_does_not_suppress_index_error(self):
  137. # Issue 19018: Heapq.merge suppresses IndexError from user generator
  138. def iterable():
  139. s = list(range(10))
  140. for i in range(20):
  141. yield s[i] # IndexError when i > 10
  142. with self.assertRaises(IndexError):
  143. list(self.module.merge(iterable(), iterable()))
  144. def test_merge_stability(self):
  145. class Int(int):
  146. pass
  147. inputs = [[], [], [], []]
  148. for i in range(20000):
  149. stream = random.randrange(4)
  150. x = random.randrange(500)
  151. obj = Int(x)
  152. obj.pair = (x, stream)
  153. inputs[stream].append(obj)
  154. for stream in inputs:
  155. stream.sort()
  156. result = [i.pair for i in self.module.merge(*inputs)]
  157. self.assertEqual(result, sorted(result))
  158. def test_nsmallest(self):
  159. data = [(random.randrange(2000), i) for i in range(1000)]
  160. for f in (None, lambda x: x[0] * 547 % 2000):
  161. for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
  162. self.assertEqual(self.module.nsmallest(n, data), sorted(data)[:n])
  163. self.assertEqual(self.module.nsmallest(n, data, key=f),
  164. sorted(data, key=f)[:n])
  165. def test_nlargest(self):
  166. data = [(random.randrange(2000), i) for i in range(1000)]
  167. for f in (None, lambda x: x[0] * 547 % 2000):
  168. for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
  169. self.assertEqual(self.module.nlargest(n, data),
  170. sorted(data, reverse=True)[:n])
  171. self.assertEqual(self.module.nlargest(n, data, key=f),
  172. sorted(data, key=f, reverse=True)[:n])
  173. def test_comparison_operator(self):
  174. # Issue 3051: Make sure heapq works with both __lt__ and __le__
  175. def hsort(data, comp):
  176. data = map(comp, data)
  177. self.module.heapify(data)
  178. return [self.module.heappop(data).x for i in range(len(data))]
  179. class LT:
  180. def __init__(self, x):
  181. self.x = x
  182. def __lt__(self, other):
  183. return self.x > other.x
  184. class LE:
  185. def __init__(self, x):
  186. self.x = x
  187. def __le__(self, other):
  188. return self.x >= other.x
  189. data = [random.random() for i in range(100)]
  190. target = sorted(data, reverse=True)
  191. self.assertEqual(hsort(data, LT), target)
  192. self.assertEqual(hsort(data, LE), target)
  193. class TestHeapPython(TestHeap):
  194. module = py_heapq
  195. @skipUnless(c_heapq, 'requires _heapq')
  196. class TestHeapC(TestHeap):
  197. module = c_heapq
  198. #==============================================================================
  199. class LenOnly:
  200. "Dummy sequence class defining __len__ but not __getitem__."
  201. def __len__(self):
  202. return 10
  203. class GetOnly:
  204. "Dummy sequence class defining __getitem__ but not __len__."
  205. def __getitem__(self, ndx):
  206. return 10
  207. class CmpErr:
  208. "Dummy element that always raises an error during comparison"
  209. def __cmp__(self, other):
  210. raise ZeroDivisionError
  211. def R(seqn):
  212. 'Regular generator'
  213. for i in seqn:
  214. yield i
  215. class G:
  216. 'Sequence using __getitem__'
  217. def __init__(self, seqn):
  218. self.seqn = seqn
  219. def __getitem__(self, i):
  220. return self.seqn[i]
  221. class I:
  222. 'Sequence using iterator protocol'
  223. def __init__(self, seqn):
  224. self.seqn = seqn
  225. self.i = 0
  226. def __iter__(self):
  227. return self
  228. def next(self):
  229. if self.i >= len(self.seqn): raise StopIteration
  230. v = self.seqn[self.i]
  231. self.i += 1
  232. return v
  233. class Ig:
  234. 'Sequence using iterator protocol defined with a generator'
  235. def __init__(self, seqn):
  236. self.seqn = seqn
  237. self.i = 0
  238. def __iter__(self):
  239. for val in self.seqn:
  240. yield val
  241. class X:
  242. 'Missing __getitem__ and __iter__'
  243. def __init__(self, seqn):
  244. self.seqn = seqn
  245. self.i = 0
  246. def next(self):
  247. if self.i >= len(self.seqn): raise StopIteration
  248. v = self.seqn[self.i]
  249. self.i += 1
  250. return v
  251. class N:
  252. 'Iterator missing next()'
  253. def __init__(self, seqn):
  254. self.seqn = seqn
  255. self.i = 0
  256. def __iter__(self):
  257. return self
  258. class E:
  259. 'Test propagation of exceptions'
  260. def __init__(self, seqn):
  261. self.seqn = seqn
  262. self.i = 0
  263. def __iter__(self):
  264. return self
  265. def next(self):
  266. 3 // 0
  267. class S:
  268. 'Test immediate stop'
  269. def __init__(self, seqn):
  270. pass
  271. def __iter__(self):
  272. return self
  273. def next(self):
  274. raise StopIteration
  275. from itertools import chain, imap
  276. def L(seqn):
  277. 'Test multiple tiers of iterators'
  278. return chain(imap(lambda x:x, R(Ig(G(seqn)))))
  279. class SideEffectLT:
  280. def __init__(self, value, heap):
  281. self.value = value
  282. self.heap = heap
  283. def __lt__(self, other):
  284. self.heap[:] = []
  285. return self.value < other.value
  286. class TestErrorHandling(TestCase):
  287. module = None
  288. def test_non_sequence(self):
  289. for f in (self.module.heapify, self.module.heappop):
  290. self.assertRaises((TypeError, AttributeError), f, 10)
  291. for f in (self.module.heappush, self.module.heapreplace,
  292. self.module.nlargest, self.module.nsmallest):
  293. self.assertRaises((TypeError, AttributeError), f, 10, 10)
  294. def test_len_only(self):
  295. for f in (self.module.heapify, self.module.heappop):
  296. self.assertRaises((TypeError, AttributeError), f, LenOnly())
  297. for f in (self.module.heappush, self.module.heapreplace):
  298. self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10)
  299. for f in (self.module.nlargest, self.module.nsmallest):
  300. self.assertRaises(TypeError, f, 2, LenOnly())
  301. def test_get_only(self):
  302. seq = [CmpErr(), CmpErr(), CmpErr()]
  303. for f in (self.module.heapify, self.module.heappop):
  304. self.assertRaises(ZeroDivisionError, f, seq)
  305. for f in (self.module.heappush, self.module.heapreplace):
  306. self.assertRaises(ZeroDivisionError, f, seq, 10)
  307. for f in (self.module.nlargest, self.module.nsmallest):
  308. self.assertRaises(ZeroDivisionError, f, 2, seq)
  309. def test_arg_parsing(self):
  310. for f in (self.module.heapify, self.module.heappop,
  311. self.module.heappush, self.module.heapreplace,
  312. self.module.nlargest, self.module.nsmallest):
  313. self.assertRaises((TypeError, AttributeError), f, 10)
  314. def test_iterable_args(self):
  315. for f in (self.module.nlargest, self.module.nsmallest):
  316. for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
  317. for g in (G, I, Ig, L, R):
  318. with test_support.check_py3k_warnings(
  319. ("comparing unequal types not supported",
  320. DeprecationWarning), quiet=True):
  321. self.assertEqual(f(2, g(s)), f(2,s))
  322. self.assertEqual(f(2, S(s)), [])
  323. self.assertRaises(TypeError, f, 2, X(s))
  324. self.assertRaises(TypeError, f, 2, N(s))
  325. self.assertRaises(ZeroDivisionError, f, 2, E(s))
  326. # Issue #17278: the heap may change size while it's being walked.
  327. def test_heappush_mutating_heap(self):
  328. heap = []
  329. heap.extend(SideEffectLT(i, heap) for i in range(200))
  330. # Python version raises IndexError, C version RuntimeError
  331. with self.assertRaises((IndexError, RuntimeError)):
  332. self.module.heappush(heap, SideEffectLT(5, heap))
  333. def test_heappop_mutating_heap(self):
  334. heap = []
  335. heap.extend(SideEffectLT(i, heap) for i in range(200))
  336. # Python version raises IndexError, C version RuntimeError
  337. with self.assertRaises((IndexError, RuntimeError)):
  338. self.module.heappop(heap)
  339. class TestErrorHandlingPython(TestErrorHandling):
  340. module = py_heapq
  341. @skipUnless(c_heapq, 'requires _heapq')
  342. class TestErrorHandlingC(TestErrorHandling):
  343. module = c_heapq
  344. #==============================================================================
  345. def test_main(verbose=None):
  346. test_classes = [TestModules, TestHeapPython, TestHeapC,
  347. TestErrorHandlingPython, TestErrorHandlingC]
  348. test_support.run_unittest(*test_classes)
  349. # verify reference counting
  350. if verbose and hasattr(sys, "gettotalrefcount"):
  351. import gc
  352. counts = [None] * 5
  353. for i in xrange(len(counts)):
  354. test_support.run_unittest(*test_classes)
  355. gc.collect()
  356. counts[i] = sys.gettotalrefcount()
  357. print counts
  358. if __name__ == "__main__":
  359. test_main(verbose=True)