
#
# Module providing the `Pool` class for managing a process pool
#
# multiprocessing/pool.py
#
# Copyright (c) 2006-2008, R Oudkerk
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
# 3. Neither the name of author nor the names of any contributors may be
#    used to endorse or promote products derived from this software
#    without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.
#

__all__ = ['Pool']

#
# Imports
#

import threading
import Queue
import itertools
import collections
import time

from multiprocessing import Process, cpu_count, TimeoutError
from multiprocessing.util import Finalize, debug

#
# Constants representing the state of a pool
#

RUN = 0
CLOSE = 1
TERMINATE = 2

#
# Miscellaneous
#

job_counter = itertools.count()

def mapstar(args):
    return map(*args)

#
# Code run by worker processes
#

class MaybeEncodingError(Exception):
    """Wraps possible unpickleable errors, so they can be
    safely sent through the socket."""

    def __init__(self, exc, value):
        self.exc = repr(exc)
        self.value = repr(value)
        super(MaybeEncodingError, self).__init__(self.exc, self.value)

    def __str__(self):
        return "Error sending result: '%s'. Reason: '%s'" % (self.value,
                                                             self.exc)

    def __repr__(self):
        return "<MaybeEncodingError: %s>" % str(self)

def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
    assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
    put = outqueue.put
    get = inqueue.get
    if hasattr(inqueue, '_writer'):
        inqueue._writer.close()
        outqueue._reader.close()

    if initializer is not None:
        initializer(*initargs)

    completed = 0
    while maxtasks is None or (maxtasks and completed < maxtasks):
        try:
            task = get()
        except (EOFError, IOError):
            debug('worker got EOFError or IOError -- exiting')
            break

        if task is None:
            debug('worker got sentinel -- exiting')
            break

        job, i, func, args, kwds = task
        try:
            result = (True, func(*args, **kwds))
        except Exception as e:
            result = (False, e)
        try:
            put((job, i, result))
        except Exception as e:
            wrapped = MaybeEncodingError(e, result[1])
            debug("Possible encoding error while sending result: %s" % (
                wrapped))
            put((job, i, (False, wrapped)))

        completed += 1
    debug('worker exiting after %d tasks' % completed)
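
# Illustrative note (not part of the original module): each task reaches a
# worker as a tuple `(job, i, func, args, kwds)` and the reply travels back
# as `(job, i, (success, value))`.  For one task the loop above effectively
# does:
#
#     job, i, func, args, kwds = (7, 0, pow, (2, 10), {})
#     result = (True, func(*args, **kwds))
#     put((job, i, result))            # sends (7, 0, (True, 1024))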

#
# Class representing a process pool
#

class Pool(object):
    '''
    Class which supports an async version of the `apply()` builtin
    '''
    Process = Process

    def __init__(self, processes=None, initializer=None, initargs=(),
                 maxtasksperchild=None):
        self._setup_queues()
        self._taskqueue = Queue.Queue()
        self._cache = {}
        self._state = RUN
        self._maxtasksperchild = maxtasksperchild
        self._initializer = initializer
        self._initargs = initargs

        if processes is None:
            try:
                processes = cpu_count()
            except NotImplementedError:
                processes = 1
        if processes < 1:
            raise ValueError("Number of processes must be at least 1")

        if initializer is not None and not hasattr(initializer, '__call__'):
            raise TypeError('initializer must be a callable')

        self._processes = processes
        self._pool = []
        self._repopulate_pool()

        self._worker_handler = threading.Thread(
            target=Pool._handle_workers,
            args=(self, )
            )
        self._worker_handler.daemon = True
        self._worker_handler._state = RUN
        self._worker_handler.start()

        self._task_handler = threading.Thread(
            target=Pool._handle_tasks,
            args=(self._taskqueue, self._quick_put, self._outqueue,
                  self._pool, self._cache)
            )
        self._task_handler.daemon = True
        self._task_handler._state = RUN
        self._task_handler.start()

        self._result_handler = threading.Thread(
            target=Pool._handle_results,
            args=(self._outqueue, self._quick_get, self._cache)
            )
        self._result_handler.daemon = True
        self._result_handler._state = RUN
        self._result_handler.start()

        self._terminate = Finalize(
            self, self._terminate_pool,
            args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
                  self._worker_handler, self._task_handler,
                  self._result_handler, self._cache),
            exitpriority=15
            )
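
    # Usage sketch (illustrative): a pool whose workers are recycled after
    # 100 tasks each, with a per-process initializer.  `setup_logging` is a
    # hypothetical module-level function standing in for real setup code.
    #
    #     pool = Pool(processes=4, initializer=setup_logging,
    #                 initargs=('worker',), maxtasksperchild=100)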

    def _join_exited_workers(self):
        """Cleanup after any worker processes which have exited due to reaching
        their specified lifetime.  Returns True if any workers were cleaned up.
        """
        cleaned = False
        for i in reversed(range(len(self._pool))):
            worker = self._pool[i]
            if worker.exitcode is not None:
                # worker exited
                debug('cleaning up worker %d' % i)
                worker.join()
                cleaned = True
                del self._pool[i]
        return cleaned

    def _repopulate_pool(self):
        """Bring the number of pool processes up to the specified number,
        for use after reaping workers which have exited.
        """
        for i in range(self._processes - len(self._pool)):
            w = self.Process(target=worker,
                             args=(self._inqueue, self._outqueue,
                                   self._initializer,
                                   self._initargs, self._maxtasksperchild)
                             )
            self._pool.append(w)
            w.name = w.name.replace('Process', 'PoolWorker')
            w.daemon = True
            w.start()
            debug('added worker')

    def _maintain_pool(self):
        """Clean up any exited workers and start replacements for them.
        """
        if self._join_exited_workers():
            self._repopulate_pool()

    def _setup_queues(self):
        from .queues import SimpleQueue
        self._inqueue = SimpleQueue()
        self._outqueue = SimpleQueue()
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv

    def apply(self, func, args=(), kwds={}):
        '''
        Equivalent of `apply()` builtin
        '''
        assert self._state == RUN
        return self.apply_async(func, args, kwds).get()

    def map(self, func, iterable, chunksize=None):
        '''
        Equivalent of `map()` builtin
        '''
        assert self._state == RUN
        return self.map_async(func, iterable, chunksize).get()
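
    # Usage sketch (illustrative): because tasks are pickled on their way to
    # worker processes, `func` must be importable at module level -- lambdas
    # and nested functions will fail to transfer.
    #
    #     def square(x):                       # defined at module top level
    #         return x * x
    #
    #     pool = Pool(4)
    #     pool.map(square, range(10))          # -> [0, 1, 4, 9, ..., 81]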

    def imap(self, func, iterable, chunksize=1):
        '''
        Equivalent of `itertools.imap()` -- can be MUCH slower than `Pool.map()`
        '''
        assert self._state == RUN
        if chunksize == 1:
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def imap_unordered(self, func, iterable, chunksize=1):
        '''
        Like `imap()` method but ordering of results is arbitrary
        '''
        assert self._state == RUN
        if chunksize == 1:
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)
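
    # Usage sketch (illustrative): imap()/imap_unordered() yield results
    # lazily, so they suit long inputs that map() would buffer in full; a
    # chunksize > 1 amortizes the per-task queue round trip.
    #
    #     for y in pool.imap_unordered(square, xrange(10**6), chunksize=64):
    #         handle(y)                        # `handle` is a placeholder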

    def apply_async(self, func, args=(), kwds={}, callback=None):
        '''
        Asynchronous equivalent of `apply()` builtin
        '''
        assert self._state == RUN
        result = ApplyResult(self._cache, callback)
        self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
        return result
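
    # Usage sketch (illustrative): the callback fires on the result-handler
    # thread as soon as the task completes, so it should be quick and must
    # not block.
    #
    #     results = []
    #     r = pool.apply_async(pow, (2, 10), callback=results.append)
    #     r.wait()                             # results == [1024] afterwards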

    def map_async(self, func, iterable, chunksize=None, callback=None):
        '''
        Asynchronous equivalent of `map()` builtin
        '''
        assert self._state == RUN
        if not hasattr(iterable, '__len__'):
            iterable = list(iterable)

        if chunksize is None:
            chunksize, extra = divmod(len(iterable), len(self._pool) * 4)
            if extra:
                chunksize += 1
        if len(iterable) == 0:
            chunksize = 0

        task_batches = Pool._get_tasks(func, iterable, chunksize)
        result = MapResult(self._cache, chunksize, len(iterable), callback)
        self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                              for i, x in enumerate(task_batches)), None))
        return result
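
    # Worked example of the chunksize heuristic above (illustrative): with
    # len(iterable) == 100 and 4 workers, divmod(100, 4 * 4) gives (6, 4);
    # the nonzero remainder bumps chunksize to 7, so the input is split into
    # 15 batches of at most 7 items each.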

    @staticmethod
    def _handle_workers(pool):
        thread = threading.current_thread()

        # Keep maintaining workers until the cache gets drained, unless the pool
        # is terminated.
        while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
            pool._maintain_pool()
            time.sleep(0.1)
        # send sentinel to stop workers
        pool._taskqueue.put(None)
        debug('worker handler exiting')

    @staticmethod
    def _handle_tasks(taskqueue, put, outqueue, pool, cache):
        thread = threading.current_thread()

        for taskseq, set_length in iter(taskqueue.get, None):
            task = None
            i = -1
            try:
                for i, task in enumerate(taskseq):
                    if thread._state:
                        debug('task handler found thread._state != RUN')
                        break
                    try:
                        put(task)
                    except Exception as e:
                        job, ind = task[:2]
                        try:
                            cache[job]._set(ind, (False, e))
                        except KeyError:
                            pass
                else:
                    if set_length:
                        debug('doing set_length()')
                        set_length(i+1)
                    continue
                break
            except Exception as ex:
                job, ind = task[:2] if task else (0, 0)
                if job in cache:
                    cache[job]._set(ind + 1, (False, ex))
                if set_length:
                    debug('doing set_length()')
                    set_length(i+1)
        else:
            debug('task handler got sentinel')

        try:
            # tell result handler to finish when cache is empty
            debug('task handler sending sentinel to result handler')
            outqueue.put(None)

            # tell workers there is no more work
            debug('task handler sending sentinel to workers')
            for p in pool:
                put(None)
        except IOError:
            debug('task handler got IOError when sending sentinels')

        debug('task handler exiting')

    @staticmethod
    def _handle_results(outqueue, get, cache):
        thread = threading.current_thread()

        while 1:
            try:
                task = get()
            except (IOError, EOFError):
                debug('result handler got EOFError/IOError -- exiting')
                return

            if thread._state:
                assert thread._state == TERMINATE
                debug('result handler found thread._state=TERMINATE')
                break

            if task is None:
                debug('result handler got sentinel')
                break

            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        while cache and thread._state != TERMINATE:
            try:
                task = get()
            except (IOError, EOFError):
                debug('result handler got EOFError/IOError -- exiting')
                return

            if task is None:
                debug('result handler ignoring extra sentinel')
                continue
            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        if hasattr(outqueue, '_reader'):
            debug('ensuring that outqueue is not full')
            # If we don't make room available in outqueue then
            # attempts to add the sentinel (None) to outqueue may
            # block.  There is guaranteed to be no more than 2 sentinels.
            try:
                for i in range(10):
                    if not outqueue._reader.poll():
                        break
                    get()
            except (IOError, EOFError):
                pass

        debug('result handler exiting: len(cache)=%s, thread._state=%s',
              len(cache), thread._state)

    @staticmethod
    def _get_tasks(func, it, size):
        it = iter(it)
        while 1:
            x = tuple(itertools.islice(it, size))
            if not x:
                return
            yield (func, x)
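
    # Illustrative behaviour: with size=3,
    #     list(Pool._get_tasks(f, range(7), 3))
    # yields [(f, (0, 1, 2)), (f, (3, 4, 5)), (f, (6,))] -- the final batch
    # simply comes up short, which mapstar() handles transparently.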

    def __reduce__(self):
        raise NotImplementedError(
              'pool objects cannot be passed between processes or pickled'
              )

    def close(self):
        debug('closing pool')
        if self._state == RUN:
            self._state = CLOSE
            self._worker_handler._state = CLOSE

    def terminate(self):
        debug('terminating pool')
        self._state = TERMINATE
        self._worker_handler._state = TERMINATE
        self._terminate()

    def join(self):
        debug('joining pool')
        assert self._state in (CLOSE, TERMINATE)
        self._worker_handler.join()
        self._task_handler.join()
        self._result_handler.join()
        for p in self._pool:
            p.join()
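
    # Shutdown sketch (illustrative): the supported sequences are close()
    # followed by join() to let queued work drain, or terminate() to discard
    # it; calling join() while the pool is still in the RUN state trips the
    # assert above.
    #
    #     pool.close()     # no new submissions; workers finish queued tasks
    #     pool.join()      # block until handler threads and workers exit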

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # task_handler may be blocked trying to put items on inqueue
        debug('removing tasks from inqueue until task handler finished')
        inqueue._rlock.acquire()
        while task_handler.is_alive() and inqueue._reader.poll():
            inqueue._reader.recv()
            time.sleep(0)

    @classmethod
    def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
                        worker_handler, task_handler, result_handler, cache):
        # this is guaranteed to only be called once
        debug('finalizing pool')

        worker_handler._state = TERMINATE
        task_handler._state = TERMINATE

        debug('helping task handler/workers to finish')
        cls._help_stuff_finish(inqueue, task_handler, len(pool))

        assert result_handler.is_alive() or len(cache) == 0

        result_handler._state = TERMINATE
        outqueue.put(None)                  # sentinel

        # We must wait for the worker handler to exit before terminating
        # workers because we don't want workers to be restarted behind our back.
        debug('joining worker handler')
        if threading.current_thread() is not worker_handler:
            worker_handler.join(1e100)

        # Terminate workers which haven't already finished.
        if pool and hasattr(pool[0], 'terminate'):
            debug('terminating workers')
            for p in pool:
                if p.exitcode is None:
                    p.terminate()

        debug('joining task handler')
        if threading.current_thread() is not task_handler:
            task_handler.join(1e100)

        debug('joining result handler')
        if threading.current_thread() is not result_handler:
            result_handler.join(1e100)

        if pool and hasattr(pool[0], 'terminate'):
            debug('joining pool workers')
            for p in pool:
                if p.is_alive():
                    # worker has not yet exited
                    debug('cleaning up worker %d' % p.pid)
                    p.join()

#
# Class whose instances are returned by `Pool.apply_async()`
#

class ApplyResult(object):

    def __init__(self, cache, callback):
        self._cond = threading.Condition(threading.Lock())
        self._job = job_counter.next()
        self._cache = cache
        self._ready = False
        self._callback = callback
        cache[self._job] = self

    def ready(self):
        return self._ready

    def successful(self):
        assert self._ready
        return self._success

    def wait(self, timeout=None):
        self._cond.acquire()
        try:
            if not self._ready:
                self._cond.wait(timeout)
        finally:
            self._cond.release()

    def get(self, timeout=None):
        self.wait(timeout)
        if not self._ready:
            raise TimeoutError
        if self._success:
            return self._value
        else:
            raise self._value

    def _set(self, i, obj):
        self._success, self._value = obj
        if self._callback and self._success:
            self._callback(self._value)
        self._cond.acquire()
        try:
            self._ready = True
            self._cond.notify()
        finally:
            self._cond.release()
        del self._cache[self._job]

AsyncResult = ApplyResult       # create alias -- see #17805
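
# Usage sketch (illustrative): polling an ApplyResult without blocking
# indefinitely; get(timeout) raises TimeoutError if the task is unfinished.
#
#     r = pool.apply_async(pow, (2, 10))
#     try:
#         value = r.get(timeout=1.0)           # 1024 on success
#     except TimeoutError:
#         value = None                         # task still running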

#
# Class whose instances are returned by `Pool.map_async()`
#

class MapResult(ApplyResult):

    def __init__(self, cache, chunksize, length, callback):
        ApplyResult.__init__(self, cache, callback)
        self._success = True
        self._value = [None] * length
        self._chunksize = chunksize
        if chunksize <= 0:
            self._number_left = 0
            self._ready = True
            del cache[self._job]
        else:
            self._number_left = length//chunksize + bool(length % chunksize)

    def _set(self, i, success_result):
        success, result = success_result
        if success:
            self._value[i*self._chunksize:(i+1)*self._chunksize] = result
            self._number_left -= 1
            if self._number_left == 0:
                if self._callback:
                    self._callback(self._value)
                del self._cache[self._job]
                self._cond.acquire()
                try:
                    self._ready = True
                    self._cond.notify()
                finally:
                    self._cond.release()
        else:
            self._success = False
            self._value = result
            del self._cache[self._job]
            self._cond.acquire()
            try:
                self._ready = True
                self._cond.notify()
            finally:
                self._cond.release()

#
# Class whose instances are returned by `Pool.imap()`
#

class IMapIterator(object):

    def __init__(self, cache):
        self._cond = threading.Condition(threading.Lock())
        self._job = job_counter.next()
        self._cache = cache
        self._items = collections.deque()
        self._index = 0
        self._length = None
        self._unsorted = {}
        cache[self._job] = self

    def __iter__(self):
        return self

    def next(self, timeout=None):
        self._cond.acquire()
        try:
            try:
                item = self._items.popleft()
            except IndexError:
                if self._index == self._length:
                    raise StopIteration
                self._cond.wait(timeout)
                try:
                    item = self._items.popleft()
                except IndexError:
                    if self._index == self._length:
                        raise StopIteration
                    raise TimeoutError
        finally:
            self._cond.release()

        success, value = item
        if success:
            return value
        raise value

    __next__ = next                    # XXX

    def _set(self, i, obj):
        self._cond.acquire()
        try:
            if self._index == i:
                self._items.append(obj)
                self._index += 1
                while self._index in self._unsorted:
                    obj = self._unsorted.pop(self._index)
                    self._items.append(obj)
                    self._index += 1
                self._cond.notify()
            else:
                self._unsorted[i] = obj

            if self._index == self._length:
                del self._cache[self._job]
        finally:
            self._cond.release()

    def _set_length(self, length):
        self._cond.acquire()
        try:
            self._length = length
            if self._index == self._length:
                self._cond.notify()
                del self._cache[self._job]
        finally:
            self._cond.release()

#
# Class whose instances are returned by `Pool.imap_unordered()`
#

class IMapUnorderedIterator(IMapIterator):

    def _set(self, i, obj):
        self._cond.acquire()
        try:
            self._items.append(obj)
            self._index += 1
            self._cond.notify()
            if self._index == self._length:
                del self._cache[self._job]
        finally:
            self._cond.release()

#
#
#

class ThreadPool(Pool):

    from .dummy import Process

    def __init__(self, processes=None, initializer=None, initargs=()):
        Pool.__init__(self, processes, initializer, initargs)

    def _setup_queues(self):
        self._inqueue = Queue.Queue()
        self._outqueue = Queue.Queue()
        self._quick_put = self._inqueue.put
        self._quick_get = self._outqueue.get

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # put sentinels at head of inqueue to make workers finish
        inqueue.not_empty.acquire()
        try:
            inqueue.queue.clear()
            inqueue.queue.extend([None] * size)
            inqueue.not_empty.notify_all()
        finally:
            inqueue.not_empty.release()
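
# Usage sketch (illustrative): ThreadPool exposes the same interface as Pool
# but runs tasks in threads of the current process, so arguments and results
# are never pickled -- convenient for I/O-bound work and for unpicklable
# callables such as lambdas.
#
#     tp = ThreadPool(8)
#     pages = tp.map(lambda url: fetch(url), urls)   # `fetch` and `urls`
#                                                    # are placeholders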