You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
131 lines
4.5 KiB
Python
131 lines
4.5 KiB
Python
2 years ago
|
"""
|
||
|
Thin wrappers around `concurrent.futures`.
|
||
|
"""
|
||
|
from __future__ import absolute_import
|
||
|
|
||
|
from contextlib import contextmanager
|
||
|
|
||
|
from ..auto import tqdm as tqdm_auto
|
||
|
from ..std import TqdmWarning
|
||
|
|
||
|
try:
    from operator import length_hint
except ImportError:
    def length_hint(it, default=0):
        """Returns `len(it)`, then `it.__length_hint__()`, falling back to `default`

        Pure-Python stand-in for `operator.length_hint` (PEP 424) on
        interpreters that lack it, so lazy iterators that advertise a
        length hint still give a useful progress-bar `total`.
        """
        try:
            return len(it)
        except TypeError:
            pass
        # like the C implementation, look the hint up on the type
        hint = getattr(type(it), '__length_hint__', None)
        if hint is None:
            return default
        try:
            result = hint(it)
        except TypeError:
            return default
        if result is NotImplemented:
            return default
        return result
|
||
|
try:
    from os import cpu_count as _cpu_count
except ImportError:
    try:
        from multiprocessing import cpu_count as _cpu_count
    except ImportError:
        _cpu_count = None


def cpu_count():
    """Returns the number of CPUs, falling back to 4 if it can't be determined.

    `os.cpu_count` may return `None`, and `multiprocessing.cpu_count` may
    raise `NotImplementedError`, when the count is unknown; both cases map to
    the fallback so callers can safely do arithmetic on the result.
    """
    if _cpu_count is not None:
        try:
            count = _cpu_count()
        except NotImplementedError:
            count = None
        if count:
            return count
    return 4
|
||
|
import sys
|
||
|
|
||
|
# names exported by `from ... import *`
__all__ = ["thread_map", "process_map"]
# author credits, keyed by host
__author__ = {"github.com/": ["casperdcl"]}
|
||
|
|
||
|
|
||
|
@contextmanager
def ensure_lock(tqdm_class, lock_name=""):
    """get (create if necessary) and then restore `tqdm_class`'s lock

    Parameters
    ----------
    tqdm_class : class providing `get_lock()`/`set_lock()` (e.g. `tqdm`).
    lock_name : str, optional
        Attribute of the lock to install instead of the lock itself
        (e.g. "mp_lock"); if absent, the lock itself is used [default: ""].

    Yields
    ------
    The lock installed on `tqdm_class` for the duration of the block.
    The previous state is restored on exit, including when the block raises.
    """
    old_lock = getattr(tqdm_class, '_lock', None)  # don't create a new lock
    lock = old_lock or tqdm_class.get_lock()  # maybe create a new lock
    lock = getattr(lock, lock_name, lock)  # maybe subtype
    tqdm_class.set_lock(lock)
    try:
        yield lock
    finally:
        # restore even if the body raised (e.g. a worker exception re-raised
        # by `Executor.map`), so `tqdm_class` isn't left holding the subtype
        if old_lock is None:
            del tqdm_class._lock
        else:
            tqdm_class.set_lock(old_lock)
|
||
|
|
||
|
|
||
|
def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
    """
    Implementation of `thread_map` and `process_map`.

    Runs `fn` over `iterables` via `PoolExecutor.map`, wraps the results
    iterator in a `tqdm_class` bar and returns all results as a list.

    Parameters
    ----------
    PoolExecutor : `concurrent.futures` executor class to instantiate.
    fn : callable applied to each tuple of items (as in `map`).
    *iterables : iterables whose items are passed to `fn`.
    tqdm_class : [default: tqdm.auto.tqdm].
    max_workers : [default: min(32, cpu_count() + 4)].
    chunksize : [default: 1].
    lock_name : [default: "":str].
    **tqdm_kwargs : remaining keywords are passed to `tqdm_class`;
        `total` defaults to the shortest known iterable length.
    """
    kwargs = tqdm_kwargs.copy()
    if "total" not in kwargs:
        # `map` stops at the shortest iterable, so use the smallest known
        # length (an upper bound if some lengths are unknown). Iterables
        # without a length hint report -1 and are skipped; 0 if none known.
        known = [n for n in (length_hint(it, -1) for it in iterables) if n >= 0]
        kwargs["total"] = min(known) if known else 0
    tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
    # same default as `ThreadPoolExecutor`; `cpu_count()` may return None
    max_workers = kwargs.pop("max_workers", min(32, (cpu_count() or 1) + 4))
    chunksize = kwargs.pop("chunksize", 1)
    lock_name = kwargs.pop("lock_name", "")
    with ensure_lock(tqdm_class, lock_name=lock_name) as lk:
        pool_kwargs = {'max_workers': max_workers}
        sys_version = sys.version_info[:2]
        if sys_version >= (3, 7):
            # share lock in case workers are already using `tqdm`
            # (`initializer`/`initargs` were added in Python 3.7)
            pool_kwargs.update(initializer=tqdm_class.set_lock, initargs=(lk,))
        map_args = {}
        # `Executor.map` only accepts `chunksize` from Python 3.5
        if not (3, 0) < sys_version < (3, 5):
            map_args.update(chunksize=chunksize)
        with PoolExecutor(**pool_kwargs) as ex:
            return list(tqdm_class(ex.map(fn, *iterables, **map_args), **kwargs))
|
||
|
|
||
|
|
||
|
def thread_map(fn, *iterables, **tqdm_kwargs):
    """
    Equivalent of `list(map(fn, *iterables))`
    driven by `concurrent.futures.ThreadPoolExecutor`.

    Parameters
    ----------
    tqdm_class : optional
        `tqdm` class to use for bars [default: tqdm.auto.tqdm].
    max_workers : int, optional
        Maximum number of workers to spawn; passed to
        `concurrent.futures.ThreadPoolExecutor.__init__`.
        [default: min(32, cpu_count() + 4)].
    chunksize : int, optional
        Passed to `concurrent.futures.ThreadPoolExecutor.map`, which
        ignores it for threads [default: 1].
    lock_name : str, optional
        Member of `tqdm_class.get_lock()` to use [default: "" (the lock itself)].
    **tqdm_kwargs
        Any other keyword arguments are passed to `tqdm_class`.

    Returns
    -------
    list
        Results of `fn`, in input order.
    """
    from concurrent.futures import ThreadPoolExecutor
    return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
|
||
|
|
||
|
|
||
|
def process_map(fn, *iterables, **tqdm_kwargs):
    """
    Equivalent of `list(map(fn, *iterables))`
    driven by `concurrent.futures.ProcessPoolExecutor`.

    Parameters
    ----------
    tqdm_class : optional
        `tqdm` class to use for bars [default: tqdm.auto.tqdm].
    max_workers : int, optional
        Maximum number of workers to spawn; passed to
        `concurrent.futures.ProcessPoolExecutor.__init__`.
        [default: min(32, cpu_count() + 4)].
    chunksize : int, optional
        Size of chunks sent to worker processes; passed to
        `concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
    lock_name : str, optional
        Member of `tqdm_class.get_lock()` to use [default: mp_lock].
    """
    from concurrent.futures import ProcessPoolExecutor

    if iterables and "chunksize" not in tqdm_kwargs:
        # With the default `chunksize=1`, dispatching every item to a worker
        # dominates the run time for big inputs, so nudge the caller.
        n_longest = max(length_hint(it) for it in iterables)
        if n_longest > 1000:
            from warnings import warn
            msg = ("Iterable length %d > 1000 but `chunksize` is not set."
                   " This may seriously degrade multiprocess performance."
                   " Set `chunksize=1` or more.") % n_longest
            warn(msg, TqdmWarning, stacklevel=2)

    # processes need the multiprocessing-aware member of the lock by default
    options = dict(tqdm_kwargs)
    options.setdefault("lock_name", "mp_lock")
    return _executor_map(ProcessPoolExecutor, fn, *iterables, **options)
|