test_richcmp.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. # Tests for rich comparisons
  2. import unittest
  3. from test import test_support
  4. import operator
  5. class Number:
  6. def __init__(self, x):
  7. self.x = x
  8. def __lt__(self, other):
  9. return self.x < other
  10. def __le__(self, other):
  11. return self.x <= other
  12. def __eq__(self, other):
  13. return self.x == other
  14. def __ne__(self, other):
  15. return self.x != other
  16. def __gt__(self, other):
  17. return self.x > other
  18. def __ge__(self, other):
  19. return self.x >= other
  20. def __cmp__(self, other):
  21. raise test_support.TestFailed, "Number.__cmp__() should not be called"
  22. def __repr__(self):
  23. return "Number(%r)" % (self.x, )
  24. class Vector:
  25. def __init__(self, data):
  26. self.data = data
  27. def __len__(self):
  28. return len(self.data)
  29. def __getitem__(self, i):
  30. return self.data[i]
  31. def __setitem__(self, i, v):
  32. self.data[i] = v
  33. __hash__ = None # Vectors cannot be hashed
  34. def __nonzero__(self):
  35. raise TypeError, "Vectors cannot be used in Boolean contexts"
  36. def __cmp__(self, other):
  37. raise test_support.TestFailed, "Vector.__cmp__() should not be called"
  38. def __repr__(self):
  39. return "Vector(%r)" % (self.data, )
  40. def __lt__(self, other):
  41. return Vector([a < b for a, b in zip(self.data, self.__cast(other))])
  42. def __le__(self, other):
  43. return Vector([a <= b for a, b in zip(self.data, self.__cast(other))])
  44. def __eq__(self, other):
  45. return Vector([a == b for a, b in zip(self.data, self.__cast(other))])
  46. def __ne__(self, other):
  47. return Vector([a != b for a, b in zip(self.data, self.__cast(other))])
  48. def __gt__(self, other):
  49. return Vector([a > b for a, b in zip(self.data, self.__cast(other))])
  50. def __ge__(self, other):
  51. return Vector([a >= b for a, b in zip(self.data, self.__cast(other))])
  52. def __cast(self, other):
  53. if isinstance(other, Vector):
  54. other = other.data
  55. if len(self.data) != len(other):
  56. raise ValueError, "Cannot compare vectors of different length"
  57. return other
  58. opmap = {
  59. "lt": (lambda a,b: a< b, operator.lt, operator.__lt__),
  60. "le": (lambda a,b: a<=b, operator.le, operator.__le__),
  61. "eq": (lambda a,b: a==b, operator.eq, operator.__eq__),
  62. "ne": (lambda a,b: a!=b, operator.ne, operator.__ne__),
  63. "gt": (lambda a,b: a> b, operator.gt, operator.__gt__),
  64. "ge": (lambda a,b: a>=b, operator.ge, operator.__ge__)
  65. }
  66. class VectorTest(unittest.TestCase):
  67. def checkfail(self, error, opname, *args):
  68. for op in opmap[opname]:
  69. self.assertRaises(error, op, *args)
  70. def checkequal(self, opname, a, b, expres):
  71. for op in opmap[opname]:
  72. realres = op(a, b)
  73. # can't use assertEqual(realres, expres) here
  74. self.assertEqual(len(realres), len(expres))
  75. for i in xrange(len(realres)):
  76. # results are bool, so we can use "is" here
  77. self.assertTrue(realres[i] is expres[i])
  78. def test_mixed(self):
  79. # check that comparisons involving Vector objects
  80. # which return rich results (i.e. Vectors with itemwise
  81. # comparison results) work
  82. a = Vector(range(2))
  83. b = Vector(range(3))
  84. # all comparisons should fail for different length
  85. for opname in opmap:
  86. self.checkfail(ValueError, opname, a, b)
  87. a = range(5)
  88. b = 5 * [2]
  89. # try mixed arguments (but not (a, b) as that won't return a bool vector)
  90. args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
  91. for (a, b) in args:
  92. self.checkequal("lt", a, b, [True, True, False, False, False])
  93. self.checkequal("le", a, b, [True, True, True, False, False])
  94. self.checkequal("eq", a, b, [False, False, True, False, False])
  95. self.checkequal("ne", a, b, [True, True, False, True, True ])
  96. self.checkequal("gt", a, b, [False, False, False, True, True ])
  97. self.checkequal("ge", a, b, [False, False, True, True, True ])
  98. for ops in opmap.itervalues():
  99. for op in ops:
  100. # calls __nonzero__, which should fail
  101. self.assertRaises(TypeError, bool, op(a, b))
  102. class NumberTest(unittest.TestCase):
  103. def test_basic(self):
  104. # Check that comparisons involving Number objects
  105. # give the same results give as comparing the
  106. # corresponding ints
  107. for a in xrange(3):
  108. for b in xrange(3):
  109. for typea in (int, Number):
  110. for typeb in (int, Number):
  111. if typea==typeb==int:
  112. continue # the combination int, int is useless
  113. ta = typea(a)
  114. tb = typeb(b)
  115. for ops in opmap.itervalues():
  116. for op in ops:
  117. realoutcome = op(a, b)
  118. testoutcome = op(ta, tb)
  119. self.assertEqual(realoutcome, testoutcome)
  120. def checkvalue(self, opname, a, b, expres):
  121. for typea in (int, Number):
  122. for typeb in (int, Number):
  123. ta = typea(a)
  124. tb = typeb(b)
  125. for op in opmap[opname]:
  126. realres = op(ta, tb)
  127. realres = getattr(realres, "x", realres)
  128. self.assertTrue(realres is expres)
  129. def test_values(self):
  130. # check all operators and all comparison results
  131. self.checkvalue("lt", 0, 0, False)
  132. self.checkvalue("le", 0, 0, True )
  133. self.checkvalue("eq", 0, 0, True )
  134. self.checkvalue("ne", 0, 0, False)
  135. self.checkvalue("gt", 0, 0, False)
  136. self.checkvalue("ge", 0, 0, True )
  137. self.checkvalue("lt", 0, 1, True )
  138. self.checkvalue("le", 0, 1, True )
  139. self.checkvalue("eq", 0, 1, False)
  140. self.checkvalue("ne", 0, 1, True )
  141. self.checkvalue("gt", 0, 1, False)
  142. self.checkvalue("ge", 0, 1, False)
  143. self.checkvalue("lt", 1, 0, False)
  144. self.checkvalue("le", 1, 0, False)
  145. self.checkvalue("eq", 1, 0, False)
  146. self.checkvalue("ne", 1, 0, True )
  147. self.checkvalue("gt", 1, 0, True )
  148. self.checkvalue("ge", 1, 0, True )
  149. class MiscTest(unittest.TestCase):
  150. def test_misbehavin(self):
  151. class Misb:
  152. def __lt__(self_, other): return 0
  153. def __gt__(self_, other): return 0
  154. def __eq__(self_, other): return 0
  155. def __le__(self_, other): self.fail("This shouldn't happen")
  156. def __ge__(self_, other): self.fail("This shouldn't happen")
  157. def __ne__(self_, other): self.fail("This shouldn't happen")
  158. def __cmp__(self_, other): raise RuntimeError, "expected"
  159. a = Misb()
  160. b = Misb()
  161. self.assertEqual(a<b, 0)
  162. self.assertEqual(a==b, 0)
  163. self.assertEqual(a>b, 0)
  164. self.assertRaises(RuntimeError, cmp, a, b)
  165. def test_not(self):
  166. # Check that exceptions in __nonzero__ are properly
  167. # propagated by the not operator
  168. import operator
  169. class Exc(Exception):
  170. pass
  171. class Bad:
  172. def __nonzero__(self):
  173. raise Exc
  174. def do(bad):
  175. not bad
  176. for func in (do, operator.not_):
  177. self.assertRaises(Exc, func, Bad())
  178. def test_recursion(self):
  179. # Check that comparison for recursive objects fails gracefully
  180. from UserList import UserList
  181. a = UserList()
  182. b = UserList()
  183. a.append(b)
  184. b.append(a)
  185. self.assertRaises(RuntimeError, operator.eq, a, b)
  186. self.assertRaises(RuntimeError, operator.ne, a, b)
  187. self.assertRaises(RuntimeError, operator.lt, a, b)
  188. self.assertRaises(RuntimeError, operator.le, a, b)
  189. self.assertRaises(RuntimeError, operator.gt, a, b)
  190. self.assertRaises(RuntimeError, operator.ge, a, b)
  191. b.append(17)
  192. # Even recursive lists of different lengths are different,
  193. # but they cannot be ordered
  194. self.assertTrue(not (a == b))
  195. self.assertTrue(a != b)
  196. self.assertRaises(RuntimeError, operator.lt, a, b)
  197. self.assertRaises(RuntimeError, operator.le, a, b)
  198. self.assertRaises(RuntimeError, operator.gt, a, b)
  199. self.assertRaises(RuntimeError, operator.ge, a, b)
  200. a.append(17)
  201. self.assertRaises(RuntimeError, operator.eq, a, b)
  202. self.assertRaises(RuntimeError, operator.ne, a, b)
  203. a.insert(0, 11)
  204. b.insert(0, 12)
  205. self.assertTrue(not (a == b))
  206. self.assertTrue(a != b)
  207. self.assertTrue(a < b)
  208. class DictTest(unittest.TestCase):
  209. def test_dicts(self):
  210. # Verify that __eq__ and __ne__ work for dicts even if the keys and
  211. # values don't support anything other than __eq__ and __ne__ (and
  212. # __hash__). Complex numbers are a fine example of that.
  213. import random
  214. imag1a = {}
  215. for i in range(50):
  216. imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
  217. items = imag1a.items()
  218. random.shuffle(items)
  219. imag1b = {}
  220. for k, v in items:
  221. imag1b[k] = v
  222. imag2 = imag1b.copy()
  223. imag2[k] = v + 1.0
  224. self.assertTrue(imag1a == imag1a)
  225. self.assertTrue(imag1a == imag1b)
  226. self.assertTrue(imag2 == imag2)
  227. self.assertTrue(imag1a != imag2)
  228. for opname in ("lt", "le", "gt", "ge"):
  229. for op in opmap[opname]:
  230. self.assertRaises(TypeError, op, imag1a, imag2)
  231. class ListTest(unittest.TestCase):
  232. def test_coverage(self):
  233. # exercise all comparisons for lists
  234. x = [42]
  235. self.assertIs(x<x, False)
  236. self.assertIs(x<=x, True)
  237. self.assertIs(x==x, True)
  238. self.assertIs(x!=x, False)
  239. self.assertIs(x>x, False)
  240. self.assertIs(x>=x, True)
  241. y = [42, 42]
  242. self.assertIs(x<y, True)
  243. self.assertIs(x<=y, True)
  244. self.assertIs(x==y, False)
  245. self.assertIs(x!=y, True)
  246. self.assertIs(x>y, False)
  247. self.assertIs(x>=y, False)
  248. def test_badentry(self):
  249. # make sure that exceptions for item comparison are properly
  250. # propagated in list comparisons
  251. class Exc(Exception):
  252. pass
  253. class Bad:
  254. def __eq__(self, other):
  255. raise Exc
  256. x = [Bad()]
  257. y = [Bad()]
  258. for op in opmap["eq"]:
  259. self.assertRaises(Exc, op, x, y)
  260. def test_goodentry(self):
  261. # This test exercises the final call to PyObject_RichCompare()
  262. # in Objects/listobject.c::list_richcompare()
  263. class Good:
  264. def __lt__(self, other):
  265. return True
  266. x = [Good()]
  267. y = [Good()]
  268. for op in opmap["lt"]:
  269. self.assertIs(op(x, y), True)
  270. def test_main():
  271. test_support.run_unittest(VectorTest, NumberTest, MiscTest, ListTest)
  272. with test_support.check_py3k_warnings(("dict inequality comparisons "
  273. "not supported in 3.x",
  274. DeprecationWarning)):
  275. test_support.run_unittest(DictTest)
  276. if __name__ == "__main__":
  277. test_main()