userfunctions.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. #-*- coding: ISO-8859-1 -*-
  2. # pysqlite2/test/userfunctions.py: tests for user-defined functions and
  3. # aggregates.
  4. #
  5. # Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
  6. #
  7. # This file is part of pysqlite.
  8. #
  9. # This software is provided 'as-is', without any express or implied
  10. # warranty. In no event will the authors be held liable for any damages
  11. # arising from the use of this software.
  12. #
  13. # Permission is granted to anyone to use this software for any purpose,
  14. # including commercial applications, and to alter it and redistribute it
  15. # freely, subject to the following restrictions:
  16. #
  17. # 1. The origin of this software must not be misrepresented; you must not
  18. # claim that you wrote the original software. If you use this software
  19. # in a product, an acknowledgment in the product documentation would be
  20. # appreciated but is not required.
  21. # 2. Altered source versions must be plainly marked as such, and must not be
  22. # misrepresented as being the original software.
  23. # 3. This notice may not be removed or altered from any source distribution.
  24. import unittest
  25. import sqlite3 as sqlite
  26. from test import test_support
  27. def func_returntext():
  28. return "foo"
  29. def func_returnunicode():
  30. return u"bar"
  31. def func_returnint():
  32. return 42
  33. def func_returnfloat():
  34. return 3.14
  35. def func_returnnull():
  36. return None
  37. def func_returnblob():
  38. with test_support.check_py3k_warnings():
  39. return buffer("blob")
  40. def func_returnlonglong():
  41. return 1<<31
  42. def func_raiseexception():
  43. 5 // 0
  44. def func_isstring(v):
  45. return type(v) is unicode
  46. def func_isint(v):
  47. return type(v) is int
  48. def func_isfloat(v):
  49. return type(v) is float
  50. def func_isnone(v):
  51. return type(v) is type(None)
  52. def func_isblob(v):
  53. return type(v) is buffer
  54. def func_islonglong(v):
  55. return isinstance(v, (int, long)) and v >= 1<<31
  56. class AggrNoStep:
  57. def __init__(self):
  58. pass
  59. def finalize(self):
  60. return 1
  61. class AggrNoFinalize:
  62. def __init__(self):
  63. pass
  64. def step(self, x):
  65. pass
  66. class AggrExceptionInInit:
  67. def __init__(self):
  68. 5 // 0
  69. def step(self, x):
  70. pass
  71. def finalize(self):
  72. pass
  73. class AggrExceptionInStep:
  74. def __init__(self):
  75. pass
  76. def step(self, x):
  77. 5 // 0
  78. def finalize(self):
  79. return 42
  80. class AggrExceptionInFinalize:
  81. def __init__(self):
  82. pass
  83. def step(self, x):
  84. pass
  85. def finalize(self):
  86. 5 // 0
  87. class AggrCheckType:
  88. def __init__(self):
  89. self.val = None
  90. def step(self, whichType, val):
  91. theType = {"str": unicode, "int": int, "float": float, "None": type(None), "blob": buffer}
  92. self.val = int(theType[whichType] is type(val))
  93. def finalize(self):
  94. return self.val
  95. class AggrSum:
  96. def __init__(self):
  97. self.val = 0.0
  98. def step(self, val):
  99. self.val += val
  100. def finalize(self):
  101. return self.val
  102. class FunctionTests(unittest.TestCase):
  103. def setUp(self):
  104. self.con = sqlite.connect(":memory:")
  105. self.con.create_function("returntext", 0, func_returntext)
  106. self.con.create_function("returnunicode", 0, func_returnunicode)
  107. self.con.create_function("returnint", 0, func_returnint)
  108. self.con.create_function("returnfloat", 0, func_returnfloat)
  109. self.con.create_function("returnnull", 0, func_returnnull)
  110. self.con.create_function("returnblob", 0, func_returnblob)
  111. self.con.create_function("returnlonglong", 0, func_returnlonglong)
  112. self.con.create_function("raiseexception", 0, func_raiseexception)
  113. self.con.create_function("isstring", 1, func_isstring)
  114. self.con.create_function("isint", 1, func_isint)
  115. self.con.create_function("isfloat", 1, func_isfloat)
  116. self.con.create_function("isnone", 1, func_isnone)
  117. self.con.create_function("isblob", 1, func_isblob)
  118. self.con.create_function("islonglong", 1, func_islonglong)
  119. def tearDown(self):
  120. self.con.close()
  121. def CheckFuncErrorOnCreate(self):
  122. try:
  123. self.con.create_function("bla", -100, lambda x: 2*x)
  124. self.fail("should have raised an OperationalError")
  125. except sqlite.OperationalError:
  126. pass
  127. def CheckFuncRefCount(self):
  128. def getfunc():
  129. def f():
  130. return 1
  131. return f
  132. f = getfunc()
  133. globals()["foo"] = f
  134. # self.con.create_function("reftest", 0, getfunc())
  135. self.con.create_function("reftest", 0, f)
  136. cur = self.con.cursor()
  137. cur.execute("select reftest()")
  138. def CheckFuncReturnText(self):
  139. cur = self.con.cursor()
  140. cur.execute("select returntext()")
  141. val = cur.fetchone()[0]
  142. self.assertEqual(type(val), unicode)
  143. self.assertEqual(val, "foo")
  144. def CheckFuncReturnUnicode(self):
  145. cur = self.con.cursor()
  146. cur.execute("select returnunicode()")
  147. val = cur.fetchone()[0]
  148. self.assertEqual(type(val), unicode)
  149. self.assertEqual(val, u"bar")
  150. def CheckFuncReturnInt(self):
  151. cur = self.con.cursor()
  152. cur.execute("select returnint()")
  153. val = cur.fetchone()[0]
  154. self.assertEqual(type(val), int)
  155. self.assertEqual(val, 42)
  156. def CheckFuncReturnFloat(self):
  157. cur = self.con.cursor()
  158. cur.execute("select returnfloat()")
  159. val = cur.fetchone()[0]
  160. self.assertEqual(type(val), float)
  161. if val < 3.139 or val > 3.141:
  162. self.fail("wrong value")
  163. def CheckFuncReturnNull(self):
  164. cur = self.con.cursor()
  165. cur.execute("select returnnull()")
  166. val = cur.fetchone()[0]
  167. self.assertEqual(type(val), type(None))
  168. self.assertEqual(val, None)
  169. def CheckFuncReturnBlob(self):
  170. cur = self.con.cursor()
  171. cur.execute("select returnblob()")
  172. val = cur.fetchone()[0]
  173. with test_support.check_py3k_warnings():
  174. self.assertEqual(type(val), buffer)
  175. self.assertEqual(val, buffer("blob"))
  176. def CheckFuncReturnLongLong(self):
  177. cur = self.con.cursor()
  178. cur.execute("select returnlonglong()")
  179. val = cur.fetchone()[0]
  180. self.assertEqual(val, 1<<31)
  181. def CheckFuncException(self):
  182. cur = self.con.cursor()
  183. try:
  184. cur.execute("select raiseexception()")
  185. cur.fetchone()
  186. self.fail("should have raised OperationalError")
  187. except sqlite.OperationalError, e:
  188. self.assertEqual(e.args[0], 'user-defined function raised exception')
  189. def CheckParamString(self):
  190. cur = self.con.cursor()
  191. cur.execute("select isstring(?)", ("foo",))
  192. val = cur.fetchone()[0]
  193. self.assertEqual(val, 1)
  194. def CheckParamInt(self):
  195. cur = self.con.cursor()
  196. cur.execute("select isint(?)", (42,))
  197. val = cur.fetchone()[0]
  198. self.assertEqual(val, 1)
  199. def CheckParamFloat(self):
  200. cur = self.con.cursor()
  201. cur.execute("select isfloat(?)", (3.14,))
  202. val = cur.fetchone()[0]
  203. self.assertEqual(val, 1)
  204. def CheckParamNone(self):
  205. cur = self.con.cursor()
  206. cur.execute("select isnone(?)", (None,))
  207. val = cur.fetchone()[0]
  208. self.assertEqual(val, 1)
  209. def CheckParamBlob(self):
  210. cur = self.con.cursor()
  211. with test_support.check_py3k_warnings():
  212. cur.execute("select isblob(?)", (buffer("blob"),))
  213. val = cur.fetchone()[0]
  214. self.assertEqual(val, 1)
  215. def CheckParamLongLong(self):
  216. cur = self.con.cursor()
  217. cur.execute("select islonglong(?)", (1<<42,))
  218. val = cur.fetchone()[0]
  219. self.assertEqual(val, 1)
  220. class AggregateTests(unittest.TestCase):
  221. def setUp(self):
  222. self.con = sqlite.connect(":memory:")
  223. cur = self.con.cursor()
  224. cur.execute("""
  225. create table test(
  226. t text,
  227. i integer,
  228. f float,
  229. n,
  230. b blob
  231. )
  232. """)
  233. with test_support.check_py3k_warnings():
  234. cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
  235. ("foo", 5, 3.14, None, buffer("blob"),))
  236. self.con.create_aggregate("nostep", 1, AggrNoStep)
  237. self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
  238. self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
  239. self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
  240. self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
  241. self.con.create_aggregate("checkType", 2, AggrCheckType)
  242. self.con.create_aggregate("mysum", 1, AggrSum)
  243. def tearDown(self):
  244. #self.cur.close()
  245. #self.con.close()
  246. pass
  247. def CheckAggrErrorOnCreate(self):
  248. try:
  249. self.con.create_function("bla", -100, AggrSum)
  250. self.fail("should have raised an OperationalError")
  251. except sqlite.OperationalError:
  252. pass
  253. def CheckAggrNoStep(self):
  254. cur = self.con.cursor()
  255. try:
  256. cur.execute("select nostep(t) from test")
  257. self.fail("should have raised an AttributeError")
  258. except AttributeError, e:
  259. self.assertEqual(e.args[0], "AggrNoStep instance has no attribute 'step'")
  260. def CheckAggrNoFinalize(self):
  261. cur = self.con.cursor()
  262. try:
  263. cur.execute("select nofinalize(t) from test")
  264. val = cur.fetchone()[0]
  265. self.fail("should have raised an OperationalError")
  266. except sqlite.OperationalError, e:
  267. self.assertEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")
  268. def CheckAggrExceptionInInit(self):
  269. cur = self.con.cursor()
  270. try:
  271. cur.execute("select excInit(t) from test")
  272. val = cur.fetchone()[0]
  273. self.fail("should have raised an OperationalError")
  274. except sqlite.OperationalError, e:
  275. self.assertEqual(e.args[0], "user-defined aggregate's '__init__' method raised error")
  276. def CheckAggrExceptionInStep(self):
  277. cur = self.con.cursor()
  278. try:
  279. cur.execute("select excStep(t) from test")
  280. val = cur.fetchone()[0]
  281. self.fail("should have raised an OperationalError")
  282. except sqlite.OperationalError, e:
  283. self.assertEqual(e.args[0], "user-defined aggregate's 'step' method raised error")
  284. def CheckAggrExceptionInFinalize(self):
  285. cur = self.con.cursor()
  286. try:
  287. cur.execute("select excFinalize(t) from test")
  288. val = cur.fetchone()[0]
  289. self.fail("should have raised an OperationalError")
  290. except sqlite.OperationalError, e:
  291. self.assertEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")
  292. def CheckAggrCheckParamStr(self):
  293. cur = self.con.cursor()
  294. cur.execute("select checkType('str', ?)", ("foo",))
  295. val = cur.fetchone()[0]
  296. self.assertEqual(val, 1)
  297. def CheckAggrCheckParamInt(self):
  298. cur = self.con.cursor()
  299. cur.execute("select checkType('int', ?)", (42,))
  300. val = cur.fetchone()[0]
  301. self.assertEqual(val, 1)
  302. def CheckAggrCheckParamFloat(self):
  303. cur = self.con.cursor()
  304. cur.execute("select checkType('float', ?)", (3.14,))
  305. val = cur.fetchone()[0]
  306. self.assertEqual(val, 1)
  307. def CheckAggrCheckParamNone(self):
  308. cur = self.con.cursor()
  309. cur.execute("select checkType('None', ?)", (None,))
  310. val = cur.fetchone()[0]
  311. self.assertEqual(val, 1)
  312. def CheckAggrCheckParamBlob(self):
  313. cur = self.con.cursor()
  314. with test_support.check_py3k_warnings():
  315. cur.execute("select checkType('blob', ?)", (buffer("blob"),))
  316. val = cur.fetchone()[0]
  317. self.assertEqual(val, 1)
  318. def CheckAggrCheckAggrSum(self):
  319. cur = self.con.cursor()
  320. cur.execute("delete from test")
  321. cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
  322. cur.execute("select mysum(i) from test")
  323. val = cur.fetchone()[0]
  324. self.assertEqual(val, 60)
  325. class AuthorizerTests(unittest.TestCase):
  326. @staticmethod
  327. def authorizer_cb(action, arg1, arg2, dbname, source):
  328. if action != sqlite.SQLITE_SELECT:
  329. return sqlite.SQLITE_DENY
  330. if arg2 == 'c2' or arg1 == 't2':
  331. return sqlite.SQLITE_DENY
  332. return sqlite.SQLITE_OK
  333. def setUp(self):
  334. self.con = sqlite.connect(":memory:")
  335. self.con.executescript("""
  336. create table t1 (c1, c2);
  337. create table t2 (c1, c2);
  338. insert into t1 (c1, c2) values (1, 2);
  339. insert into t2 (c1, c2) values (4, 5);
  340. """)
  341. # For our security test:
  342. self.con.execute("select c2 from t2")
  343. self.con.set_authorizer(self.authorizer_cb)
  344. def tearDown(self):
  345. pass
  346. def test_table_access(self):
  347. try:
  348. self.con.execute("select * from t2")
  349. except sqlite.DatabaseError, e:
  350. if not e.args[0].endswith("prohibited"):
  351. self.fail("wrong exception text: %s" % e.args[0])
  352. return
  353. self.fail("should have raised an exception due to missing privileges")
  354. def test_column_access(self):
  355. try:
  356. self.con.execute("select c2 from t1")
  357. except sqlite.DatabaseError, e:
  358. if not e.args[0].endswith("prohibited"):
  359. self.fail("wrong exception text: %s" % e.args[0])
  360. return
  361. self.fail("should have raised an exception due to missing privileges")
  362. class AuthorizerRaiseExceptionTests(AuthorizerTests):
  363. @staticmethod
  364. def authorizer_cb(action, arg1, arg2, dbname, source):
  365. if action != sqlite.SQLITE_SELECT:
  366. raise ValueError
  367. if arg2 == 'c2' or arg1 == 't2':
  368. raise ValueError
  369. return sqlite.SQLITE_OK
  370. class AuthorizerIllegalTypeTests(AuthorizerTests):
  371. @staticmethod
  372. def authorizer_cb(action, arg1, arg2, dbname, source):
  373. if action != sqlite.SQLITE_SELECT:
  374. return 0.0
  375. if arg2 == 'c2' or arg1 == 't2':
  376. return 0.0
  377. return sqlite.SQLITE_OK
  378. class AuthorizerLargeIntegerTests(AuthorizerTests):
  379. @staticmethod
  380. def authorizer_cb(action, arg1, arg2, dbname, source):
  381. if action != sqlite.SQLITE_SELECT:
  382. return 2**32
  383. if arg2 == 'c2' or arg1 == 't2':
  384. return 2**32
  385. return sqlite.SQLITE_OK
  386. def suite():
  387. function_suite = unittest.makeSuite(FunctionTests, "Check")
  388. aggregate_suite = unittest.makeSuite(AggregateTests, "Check")
  389. authorizer_suite = unittest.makeSuite(AuthorizerTests)
  390. return unittest.TestSuite((
  391. function_suite,
  392. aggregate_suite,
  393. authorizer_suite,
  394. unittest.makeSuite(AuthorizerRaiseExceptionTests),
  395. unittest.makeSuite(AuthorizerIllegalTypeTests),
  396. unittest.makeSuite(AuthorizerLargeIntegerTests),
  397. ))
  398. def test():
  399. runner = unittest.TextTestRunner()
  400. runner.run(suite())
# Run the combined suite when this file is executed as a script.
if __name__ == "__main__":
    test()