test_wsgiref.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. from unittest import TestCase
  2. from wsgiref.util import setup_testing_defaults
  3. from wsgiref.headers import Headers
  4. from wsgiref.handlers import BaseHandler, BaseCGIHandler
  5. from wsgiref import util
  6. from wsgiref.validate import validator
  7. from wsgiref.simple_server import WSGIServer, WSGIRequestHandler
  8. from wsgiref.simple_server import make_server
  9. from StringIO import StringIO
  10. from SocketServer import BaseServer
  11. import os
  12. import re
  13. import sys
  14. from test import test_support
  15. class MockServer(WSGIServer):
  16. """Non-socket HTTP server"""
  17. def __init__(self, server_address, RequestHandlerClass):
  18. BaseServer.__init__(self, server_address, RequestHandlerClass)
  19. self.server_bind()
  20. def server_bind(self):
  21. host, port = self.server_address
  22. self.server_name = host
  23. self.server_port = port
  24. self.setup_environ()
  25. class MockHandler(WSGIRequestHandler):
  26. """Non-socket HTTP handler"""
  27. def setup(self):
  28. self.connection = self.request
  29. self.rfile, self.wfile = self.connection
  30. def finish(self):
  31. pass
  32. def hello_app(environ,start_response):
  33. start_response("200 OK", [
  34. ('Content-Type','text/plain'),
  35. ('Date','Mon, 05 Jun 2006 18:49:54 GMT')
  36. ])
  37. return ["Hello, world!"]
  38. def run_amock(app=hello_app, data="GET / HTTP/1.0\n\n"):
  39. server = make_server("", 80, app, MockServer, MockHandler)
  40. inp, out, err, olderr = StringIO(data), StringIO(), StringIO(), sys.stderr
  41. sys.stderr = err
  42. try:
  43. server.finish_request((inp,out), ("127.0.0.1",8888))
  44. finally:
  45. sys.stderr = olderr
  46. return out.getvalue(), err.getvalue()
  47. def compare_generic_iter(make_it,match):
  48. """Utility to compare a generic 2.1/2.2+ iterator with an iterable
  49. If running under Python 2.2+, this tests the iterator using iter()/next(),
  50. as well as __getitem__. 'make_it' must be a function returning a fresh
  51. iterator to be tested (since this may test the iterator twice)."""
  52. it = make_it()
  53. n = 0
  54. for item in match:
  55. if not it[n]==item: raise AssertionError
  56. n+=1
  57. try:
  58. it[n]
  59. except IndexError:
  60. pass
  61. else:
  62. raise AssertionError("Too many items from __getitem__",it)
  63. try:
  64. iter, StopIteration
  65. except NameError:
  66. pass
  67. else:
  68. # Only test iter mode under 2.2+
  69. it = make_it()
  70. if not iter(it) is it: raise AssertionError
  71. for item in match:
  72. if not it.next()==item: raise AssertionError
  73. try:
  74. it.next()
  75. except StopIteration:
  76. pass
  77. else:
  78. raise AssertionError("Too many items from .next()",it)
  79. class IntegrationTests(TestCase):
  80. def check_hello(self, out, has_length=True):
  81. self.assertEqual(out,
  82. "HTTP/1.0 200 OK\r\n"
  83. "Server: WSGIServer/0.1 Python/"+sys.version.split()[0]+"\r\n"
  84. "Content-Type: text/plain\r\n"
  85. "Date: Mon, 05 Jun 2006 18:49:54 GMT\r\n" +
  86. (has_length and "Content-Length: 13\r\n" or "") +
  87. "\r\n"
  88. "Hello, world!"
  89. )
  90. def test_plain_hello(self):
  91. out, err = run_amock()
  92. self.check_hello(out)
  93. def test_request_length(self):
  94. out, err = run_amock(data="GET " + ("x" * 65537) + " HTTP/1.0\n\n")
  95. self.assertEqual(out.splitlines()[0],
  96. "HTTP/1.0 414 Request-URI Too Long")
  97. def test_validated_hello(self):
  98. out, err = run_amock(validator(hello_app))
  99. # the middleware doesn't support len(), so content-length isn't there
  100. self.check_hello(out, has_length=False)
  101. def test_simple_validation_error(self):
  102. def bad_app(environ,start_response):
  103. start_response("200 OK", ('Content-Type','text/plain'))
  104. return ["Hello, world!"]
  105. out, err = run_amock(validator(bad_app))
  106. self.assertTrue(out.endswith(
  107. "A server error occurred. Please contact the administrator."
  108. ))
  109. self.assertEqual(
  110. err.splitlines()[-2],
  111. "AssertionError: Headers (('Content-Type', 'text/plain')) must"
  112. " be of type list: <type 'tuple'>"
  113. )
  114. class UtilityTests(TestCase):
  115. def checkShift(self,sn_in,pi_in,part,sn_out,pi_out):
  116. env = {'SCRIPT_NAME':sn_in,'PATH_INFO':pi_in}
  117. util.setup_testing_defaults(env)
  118. self.assertEqual(util.shift_path_info(env),part)
  119. self.assertEqual(env['PATH_INFO'],pi_out)
  120. self.assertEqual(env['SCRIPT_NAME'],sn_out)
  121. return env
  122. def checkDefault(self, key, value, alt=None):
  123. # Check defaulting when empty
  124. env = {}
  125. util.setup_testing_defaults(env)
  126. if isinstance(value, StringIO):
  127. self.assertIsInstance(env[key], StringIO)
  128. else:
  129. self.assertEqual(env[key], value)
  130. # Check existing value
  131. env = {key:alt}
  132. util.setup_testing_defaults(env)
  133. self.assertIs(env[key], alt)
  134. def checkCrossDefault(self,key,value,**kw):
  135. util.setup_testing_defaults(kw)
  136. self.assertEqual(kw[key],value)
  137. def checkAppURI(self,uri,**kw):
  138. util.setup_testing_defaults(kw)
  139. self.assertEqual(util.application_uri(kw),uri)
  140. def checkReqURI(self,uri,query=1,**kw):
  141. util.setup_testing_defaults(kw)
  142. self.assertEqual(util.request_uri(kw,query),uri)
  143. def checkFW(self,text,size,match):
  144. def make_it(text=text,size=size):
  145. return util.FileWrapper(StringIO(text),size)
  146. compare_generic_iter(make_it,match)
  147. it = make_it()
  148. self.assertFalse(it.filelike.closed)
  149. for item in it:
  150. pass
  151. self.assertFalse(it.filelike.closed)
  152. it.close()
  153. self.assertTrue(it.filelike.closed)
  154. def testSimpleShifts(self):
  155. self.checkShift('','/', '', '/', '')
  156. self.checkShift('','/x', 'x', '/x', '')
  157. self.checkShift('/','', None, '/', '')
  158. self.checkShift('/a','/x/y', 'x', '/a/x', '/y')
  159. self.checkShift('/a','/x/', 'x', '/a/x', '/')
  160. def testNormalizedShifts(self):
  161. self.checkShift('/a/b', '/../y', '..', '/a', '/y')
  162. self.checkShift('', '/../y', '..', '', '/y')
  163. self.checkShift('/a/b', '//y', 'y', '/a/b/y', '')
  164. self.checkShift('/a/b', '//y/', 'y', '/a/b/y', '/')
  165. self.checkShift('/a/b', '/./y', 'y', '/a/b/y', '')
  166. self.checkShift('/a/b', '/./y/', 'y', '/a/b/y', '/')
  167. self.checkShift('/a/b', '///./..//y/.//', '..', '/a', '/y/')
  168. self.checkShift('/a/b', '///', '', '/a/b/', '')
  169. self.checkShift('/a/b', '/.//', '', '/a/b/', '')
  170. self.checkShift('/a/b', '/x//', 'x', '/a/b/x', '/')
  171. self.checkShift('/a/b', '/.', None, '/a/b', '')
  172. def testDefaults(self):
  173. for key, value in [
  174. ('SERVER_NAME','127.0.0.1'),
  175. ('SERVER_PORT', '80'),
  176. ('SERVER_PROTOCOL','HTTP/1.0'),
  177. ('HTTP_HOST','127.0.0.1'),
  178. ('REQUEST_METHOD','GET'),
  179. ('SCRIPT_NAME',''),
  180. ('PATH_INFO','/'),
  181. ('wsgi.version', (1,0)),
  182. ('wsgi.run_once', 0),
  183. ('wsgi.multithread', 0),
  184. ('wsgi.multiprocess', 0),
  185. ('wsgi.input', StringIO("")),
  186. ('wsgi.errors', StringIO()),
  187. ('wsgi.url_scheme','http'),
  188. ]:
  189. self.checkDefault(key,value)
  190. def testCrossDefaults(self):
  191. self.checkCrossDefault('HTTP_HOST',"foo.bar",SERVER_NAME="foo.bar")
  192. self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="on")
  193. self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="1")
  194. self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="yes")
  195. self.checkCrossDefault('wsgi.url_scheme',"http",HTTPS="foo")
  196. self.checkCrossDefault('SERVER_PORT',"80",HTTPS="foo")
  197. self.checkCrossDefault('SERVER_PORT',"443",HTTPS="on")
  198. def testGuessScheme(self):
  199. self.assertEqual(util.guess_scheme({}), "http")
  200. self.assertEqual(util.guess_scheme({'HTTPS':"foo"}), "http")
  201. self.assertEqual(util.guess_scheme({'HTTPS':"on"}), "https")
  202. self.assertEqual(util.guess_scheme({'HTTPS':"yes"}), "https")
  203. self.assertEqual(util.guess_scheme({'HTTPS':"1"}), "https")
  204. def testAppURIs(self):
  205. self.checkAppURI("http://127.0.0.1/")
  206. self.checkAppURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam")
  207. self.checkAppURI("http://127.0.0.1/sp%E4m", SCRIPT_NAME="/sp\xe4m")
  208. self.checkAppURI("http://spam.example.com:2071/",
  209. HTTP_HOST="spam.example.com:2071", SERVER_PORT="2071")
  210. self.checkAppURI("http://spam.example.com/",
  211. SERVER_NAME="spam.example.com")
  212. self.checkAppURI("http://127.0.0.1/",
  213. HTTP_HOST="127.0.0.1", SERVER_NAME="spam.example.com")
  214. self.checkAppURI("https://127.0.0.1/", HTTPS="on")
  215. self.checkAppURI("http://127.0.0.1:8000/", SERVER_PORT="8000",
  216. HTTP_HOST=None)
  217. def testReqURIs(self):
  218. self.checkReqURI("http://127.0.0.1/")
  219. self.checkReqURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam")
  220. self.checkReqURI("http://127.0.0.1/sp%E4m", SCRIPT_NAME="/sp\xe4m")
  221. self.checkReqURI("http://127.0.0.1/spammity/spam",
  222. SCRIPT_NAME="/spammity", PATH_INFO="/spam")
  223. self.checkReqURI("http://127.0.0.1/spammity/sp%E4m",
  224. SCRIPT_NAME="/spammity", PATH_INFO="/sp\xe4m")
  225. self.checkReqURI("http://127.0.0.1/spammity/spam;ham",
  226. SCRIPT_NAME="/spammity", PATH_INFO="/spam;ham")
  227. self.checkReqURI("http://127.0.0.1/spammity/spam;cookie=1234,5678",
  228. SCRIPT_NAME="/spammity", PATH_INFO="/spam;cookie=1234,5678")
  229. self.checkReqURI("http://127.0.0.1/spammity/spam?say=ni",
  230. SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="say=ni")
  231. self.checkReqURI("http://127.0.0.1/spammity/spam?s%E4y=ni",
  232. SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="s%E4y=ni")
  233. self.checkReqURI("http://127.0.0.1/spammity/spam", 0,
  234. SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="say=ni")
  235. def testFileWrapper(self):
  236. self.checkFW("xyz"*50, 120, ["xyz"*40,"xyz"*10])
  237. def testHopByHop(self):
  238. for hop in (
  239. "Connection Keep-Alive Proxy-Authenticate Proxy-Authorization "
  240. "TE Trailers Transfer-Encoding Upgrade"
  241. ).split():
  242. for alt in hop, hop.title(), hop.upper(), hop.lower():
  243. self.assertTrue(util.is_hop_by_hop(alt))
  244. # Not comprehensive, just a few random header names
  245. for hop in (
  246. "Accept Cache-Control Date Pragma Trailer Via Warning"
  247. ).split():
  248. for alt in hop, hop.title(), hop.upper(), hop.lower():
  249. self.assertFalse(util.is_hop_by_hop(alt))
  250. class HeaderTests(TestCase):
  251. def testMappingInterface(self):
  252. test = [('x','y')]
  253. self.assertEqual(len(Headers([])),0)
  254. self.assertEqual(len(Headers(test[:])),1)
  255. self.assertEqual(Headers(test[:]).keys(), ['x'])
  256. self.assertEqual(Headers(test[:]).values(), ['y'])
  257. self.assertEqual(Headers(test[:]).items(), test)
  258. self.assertIsNot(Headers(test).items(), test) # must be copy!
  259. h=Headers([])
  260. del h['foo'] # should not raise an error
  261. h['Foo'] = 'bar'
  262. for m in h.has_key, h.__contains__, h.get, h.get_all, h.__getitem__:
  263. self.assertTrue(m('foo'))
  264. self.assertTrue(m('Foo'))
  265. self.assertTrue(m('FOO'))
  266. self.assertFalse(m('bar'))
  267. self.assertEqual(h['foo'],'bar')
  268. h['foo'] = 'baz'
  269. self.assertEqual(h['FOO'],'baz')
  270. self.assertEqual(h.get_all('foo'),['baz'])
  271. self.assertEqual(h.get("foo","whee"), "baz")
  272. self.assertEqual(h.get("zoo","whee"), "whee")
  273. self.assertEqual(h.setdefault("foo","whee"), "baz")
  274. self.assertEqual(h.setdefault("zoo","whee"), "whee")
  275. self.assertEqual(h["foo"],"baz")
  276. self.assertEqual(h["zoo"],"whee")
  277. def testRequireList(self):
  278. self.assertRaises(TypeError, Headers, "foo")
  279. def testExtras(self):
  280. h = Headers([])
  281. self.assertEqual(str(h),'\r\n')
  282. h.add_header('foo','bar',baz="spam")
  283. self.assertEqual(h['foo'], 'bar; baz="spam"')
  284. self.assertEqual(str(h),'foo: bar; baz="spam"\r\n\r\n')
  285. h.add_header('Foo','bar',cheese=None)
  286. self.assertEqual(h.get_all('foo'),
  287. ['bar; baz="spam"', 'bar; cheese'])
  288. self.assertEqual(str(h),
  289. 'foo: bar; baz="spam"\r\n'
  290. 'Foo: bar; cheese\r\n'
  291. '\r\n'
  292. )
  293. class ErrorHandler(BaseCGIHandler):
  294. """Simple handler subclass for testing BaseHandler"""
  295. # BaseHandler records the OS environment at import time, but envvars
  296. # might have been changed later by other tests, which trips up
  297. # HandlerTests.testEnviron().
  298. os_environ = dict(os.environ.items())
  299. def __init__(self,**kw):
  300. setup_testing_defaults(kw)
  301. BaseCGIHandler.__init__(
  302. self, StringIO(''), StringIO(), StringIO(), kw,
  303. multithread=True, multiprocess=True
  304. )
class TestHandler(ErrorHandler):
    """Simple handler subclass for testing BaseHandler, w/error passthru"""

    def handle_error(self):
        # Bare ``raise``: re-raise the exception currently being handled
        # instead of handling it here, so it propagates to the test.
        raise   # for testing, we want to see what's happening
  309. class HandlerTests(TestCase):
  310. def checkEnvironAttrs(self, handler):
  311. env = handler.environ
  312. for attr in [
  313. 'version','multithread','multiprocess','run_once','file_wrapper'
  314. ]:
  315. if attr=='file_wrapper' and handler.wsgi_file_wrapper is None:
  316. continue
  317. self.assertEqual(getattr(handler,'wsgi_'+attr),env['wsgi.'+attr])
  318. def checkOSEnviron(self,handler):
  319. empty = {}; setup_testing_defaults(empty)
  320. env = handler.environ
  321. from os import environ
  322. for k,v in environ.items():
  323. if k not in empty:
  324. self.assertEqual(env[k],v)
  325. for k,v in empty.items():
  326. self.assertIn(k, env)
  327. def testEnviron(self):
  328. h = TestHandler(X="Y")
  329. h.setup_environ()
  330. self.checkEnvironAttrs(h)
  331. self.checkOSEnviron(h)
  332. self.assertEqual(h.environ["X"],"Y")
  333. def testCGIEnviron(self):
  334. h = BaseCGIHandler(None,None,None,{})
  335. h.setup_environ()
  336. for key in 'wsgi.url_scheme', 'wsgi.input', 'wsgi.errors':
  337. self.assertIn(key, h.environ)
  338. def testScheme(self):
  339. h=TestHandler(HTTPS="on"); h.setup_environ()
  340. self.assertEqual(h.environ['wsgi.url_scheme'],'https')
  341. h=TestHandler(); h.setup_environ()
  342. self.assertEqual(h.environ['wsgi.url_scheme'],'http')
  343. def testAbstractMethods(self):
  344. h = BaseHandler()
  345. for name in [
  346. '_flush','get_stdin','get_stderr','add_cgi_vars'
  347. ]:
  348. self.assertRaises(NotImplementedError, getattr(h,name))
  349. self.assertRaises(NotImplementedError, h._write, "test")
  350. def testContentLength(self):
  351. # Demo one reason iteration is better than write()... ;)
  352. def trivial_app1(e,s):
  353. s('200 OK',[])
  354. return [e['wsgi.url_scheme']]
  355. def trivial_app2(e,s):
  356. s('200 OK',[])(e['wsgi.url_scheme'])
  357. return []
  358. def trivial_app4(e,s):
  359. # Simulate a response to a HEAD request
  360. s('200 OK',[('Content-Length', '12345')])
  361. return []
  362. h = TestHandler()
  363. h.run(trivial_app1)
  364. self.assertEqual(h.stdout.getvalue(),
  365. "Status: 200 OK\r\n"
  366. "Content-Length: 4\r\n"
  367. "\r\n"
  368. "http")
  369. h = TestHandler()
  370. h.run(trivial_app2)
  371. self.assertEqual(h.stdout.getvalue(),
  372. "Status: 200 OK\r\n"
  373. "\r\n"
  374. "http")
  375. h = TestHandler()
  376. h.run(trivial_app4)
  377. self.assertEqual(h.stdout.getvalue(),
  378. b'Status: 200 OK\r\n'
  379. b'Content-Length: 12345\r\n'
  380. b'\r\n')
  381. def testBasicErrorOutput(self):
  382. def non_error_app(e,s):
  383. s('200 OK',[])
  384. return []
  385. def error_app(e,s):
  386. raise AssertionError("This should be caught by handler")
  387. h = ErrorHandler()
  388. h.run(non_error_app)
  389. self.assertEqual(h.stdout.getvalue(),
  390. "Status: 200 OK\r\n"
  391. "Content-Length: 0\r\n"
  392. "\r\n")
  393. self.assertEqual(h.stderr.getvalue(),"")
  394. h = ErrorHandler()
  395. h.run(error_app)
  396. self.assertEqual(h.stdout.getvalue(),
  397. "Status: %s\r\n"
  398. "Content-Type: text/plain\r\n"
  399. "Content-Length: %d\r\n"
  400. "\r\n%s" % (h.error_status,len(h.error_body),h.error_body))
  401. self.assertNotEqual(h.stderr.getvalue().find("AssertionError"), -1)
  402. def testErrorAfterOutput(self):
  403. MSG = "Some output has been sent"
  404. def error_app(e,s):
  405. s("200 OK",[])(MSG)
  406. raise AssertionError("This should be caught by handler")
  407. h = ErrorHandler()
  408. h.run(error_app)
  409. self.assertEqual(h.stdout.getvalue(),
  410. "Status: 200 OK\r\n"
  411. "\r\n"+MSG)
  412. self.assertNotEqual(h.stderr.getvalue().find("AssertionError"), -1)
  413. def testHeaderFormats(self):
  414. def non_error_app(e,s):
  415. s('200 OK',[])
  416. return []
  417. stdpat = (
  418. r"HTTP/%s 200 OK\r\n"
  419. r"Date: \w{3}, [ 0123]\d \w{3} \d{4} \d\d:\d\d:\d\d GMT\r\n"
  420. r"%s" r"Content-Length: 0\r\n" r"\r\n"
  421. )
  422. shortpat = (
  423. "Status: 200 OK\r\n" "Content-Length: 0\r\n" "\r\n"
  424. )
  425. for ssw in "FooBar/1.0", None:
  426. sw = ssw and "Server: %s\r\n" % ssw or ""
  427. for version in "1.0", "1.1":
  428. for proto in "HTTP/0.9", "HTTP/1.0", "HTTP/1.1":
  429. h = TestHandler(SERVER_PROTOCOL=proto)
  430. h.origin_server = False
  431. h.http_version = version
  432. h.server_software = ssw
  433. h.run(non_error_app)
  434. self.assertEqual(shortpat,h.stdout.getvalue())
  435. h = TestHandler(SERVER_PROTOCOL=proto)
  436. h.origin_server = True
  437. h.http_version = version
  438. h.server_software = ssw
  439. h.run(non_error_app)
  440. if proto=="HTTP/0.9":
  441. self.assertEqual(h.stdout.getvalue(),"")
  442. else:
  443. self.assertTrue(
  444. re.match(stdpat%(version,sw), h.stdout.getvalue()),
  445. (stdpat%(version,sw), h.stdout.getvalue())
  446. )
  447. def testCloseOnError(self):
  448. side_effects = {'close_called': False}
  449. MSG = b"Some output has been sent"
  450. def error_app(e,s):
  451. s("200 OK",[])(MSG)
  452. class CrashyIterable(object):
  453. def __iter__(self):
  454. while True:
  455. yield b'blah'
  456. raise AssertionError("This should be caught by handler")
  457. def close(self):
  458. side_effects['close_called'] = True
  459. return CrashyIterable()
  460. h = ErrorHandler()
  461. h.run(error_app)
  462. self.assertEqual(side_effects['close_called'], True)
def test_main():
    # Hand this module's name to test_support.run_unittest().
    test_support.run_unittest(__name__)


# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    test_main()