# Source code for lenstronomy.Sampling.Pool.multiprocessing

"""
This file is taken from schwimmbad (https://github.com/adrn/schwimmbad) and an explicit fork by Aymeric Galan.
It replaces the ``multiprocessing`` dependency with ``multiprocess``, because for multi-threading purposes
``multiprocessing`` supports only pickle (not dill), while dill is required here.

The class is also extended with an ``is_master()`` method.

"""


# Standard library

from __future__ import division, print_function, absolute_import, unicode_literals
import signal
import functools
import multiprocess
from multiprocess.pool import Pool

__all__ = ['MultiPool']


def _initializer_wrapper(actual_initializer, *rest):
    """
    We ignore SIGINT. It's up to our parent to kill us in the typical condition of this arising from ``^C`` on a
    terminal. If someone is manually killing us with that signal, well... nothing will happen.
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    if actual_initializer is not None:
        actual_initializer(*rest)


class CallbackWrapper(object):
    """Adapt a per-result callback so it can consume a whole batch of results."""

    def __init__(self, callback):
        # Callable invoked once per individual task result.
        self.callback = callback

    def __call__(self, tasks):
        """Apply the wrapped callback to each element of ``tasks``."""
        for item in tasks:
            self.callback(item)


class MultiPool(Pool):
    """
    A modified version of :class:`multiprocessing.pool.Pool` that has better
    behavior with regard to ``KeyboardInterrupts`` in the :func:`map` method.

    (Original author: `Peter K. G. Williams <peter@newton.cx>`_)

    Parameters
    ----------
    processes : int, optional
        The number of worker processes to use; defaults to the number of CPUs.
    initializer : callable, optional
        If specified, a callable that will be invoked by each worker process
        when it starts.
    initargs : iterable, optional
        Arguments for ``initializer``; it will be called as
        ``initializer(*initargs)``.
    kwargs:
        Extra arguments passed to the :class:`multiprocessing.pool.Pool`
        superclass.
    """
    # Poll interval (seconds) for r.get(); a finite timeout keeps the master
    # process responsive to KeyboardInterrupt (see map() below).
    wait_timeout = 3600

    def __init__(self, processes=None, initializer=None, initargs=(),
                 **kwargs):
        # Wrap the user initializer so every worker first ignores SIGINT;
        # interrupt handling then remains the master's responsibility.
        new_initializer = functools.partial(_initializer_wrapper, initializer)
        super(MultiPool, self).__init__(processes, new_initializer,
                                        initargs, **kwargs)
        self.size = self._processes
        self.rank = 0

    def is_master(self):
        """Return ``True`` — the process owning the pool is always the master."""
        return self.rank == 0

    def is_worker(self):
        """Return ``True`` if this process is a worker (never the case here)."""
        return self.rank != 0

    @staticmethod
    def enabled():
        """Return ``True``: this pool implementation is always available."""
        return True

    def map(self, func, iterable, chunksize=None, callback=None):
        """
        Equivalent to the built-in ``map()`` function and
        :meth:`multiprocessing.pool.Pool.map()`, without catching
        ``KeyboardInterrupt``.

        Parameters
        ----------
        func : callable
            A function or callable object that is executed on each element of
            the specified ``tasks`` iterable. This object must be picklable
            (i.e. it can't be a function scoped within a function or a
            ``lambda`` function). This should accept a single positional
            argument and return a single object.
        iterable : iterable
            A list or iterable of tasks. Each task can be itself an iterable
            (e.g., tuple) of values or data to pass in to the worker function.
        callback : callable, optional
            An optional callback function (or callable) that is called with
            the result from each worker run and is executed on the master
            process. This is useful for, e.g., saving results to a file, since
            the callback is only called on the master thread.

        Returns
        -------
        results : list
            A list of results from the output of each ``worker()`` call.
        """
        if callback is None:
            callbackwrapper = None
        else:
            # map_async delivers the whole result list to its callback at
            # once; wrap the per-item callback so it is applied element-wise.
            callbackwrapper = CallbackWrapper(callback)

        # The key magic is that we must call r.get() with a timeout, because
        # a Condition.wait() without a timeout swallows KeyboardInterrupts.
        r = self.map_async(func, iterable, chunksize=chunksize,
                           callback=callbackwrapper)

        while True:
            try:
                return r.get(self.wait_timeout)
            except multiprocess.TimeoutError:
                # Not done yet; loop and wait again so interrupts stay live.
                pass
            except KeyboardInterrupt:
                self.terminate()
                self.join()
                raise