You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
94 lines
2.7 KiB
Python
94 lines
2.7 KiB
Python
2 years ago
|
"""
|
||
|
Asynchronous progressbar decorator for iterators.
|
||
|
Includes a default `range` iterator printing to `stderr`.
|
||
|
|
||
|
Usage:
|
||
|
>>> from tqdm.asyncio import trange, tqdm
|
||
|
>>> async for i in trange(10):
|
||
|
... ...
|
||
|
"""
|
||
|
import asyncio
|
||
|
from sys import version_info
|
||
|
|
||
|
from .std import tqdm as std_tqdm
|
||
|
|
||
|
__author__ = {"github.com/": ["casperdcl"]}
# Public API of this module: the async-aware bar class, its range shortcut,
# and the conventional tqdm-style aliases for both.
__all__ = [
    'tqdm_asyncio',
    'tarange',
    'tqdm',
    'trange',
]
|
||
|
|
||
|
|
||
|
class tqdm_asyncio(std_tqdm):
    """
    Asynchronous-friendly version of tqdm (Python 3.6+).

    Instances can be consumed with ``async for``. The wrapped iterable may be
    a regular iterable/iterator, an asynchronous iterator (has ``__anext__``),
    or an asynchronous iterable that only implements ``__aiter__``.
    """
    def __init__(self, iterable=None, *args, **kwargs):
        """
        Parameters
        ----------
        iterable  : iterable or async iterable, optional
            Object to wrap. ``*args`` and ``**kwargs`` are forwarded unchanged
            (together with `iterable`) to the base class constructor.
        """
        super(tqdm_asyncio, self).__init__(iterable, *args, **kwargs)
        # True when `iterable_next` returns an awaitable (async protocol).
        self.iterable_awaitable = False
        if iterable is not None:
            if hasattr(iterable, "__anext__"):
                # Already an async iterator: await its `__anext__` directly.
                self.iterable_next = iterable.__anext__
                self.iterable_awaitable = True
            elif hasattr(iterable, "__next__"):
                # Already a sync iterator: advance it in place.
                self.iterable_next = iterable.__next__
            else:
                try:
                    self.iterable_iterator = iter(iterable)
                except TypeError:
                    # Not sync-iterable. Accept async iterables that only
                    # implement `__aiter__` (this case used to raise here).
                    if not hasattr(iterable, "__aiter__"):
                        raise
                    self.iterable_iterator = iterable.__aiter__()
                    self.iterable_next = self.iterable_iterator.__anext__
                    self.iterable_awaitable = True
                else:
                    self.iterable_next = self.iterable_iterator.__next__

    def __aiter__(self):
        """Return `self`; the bar is its own async iterator."""
        return self

    async def __anext__(self):
        """
        Fetch the next item of the wrapped iterable and advance the bar by one.

        Raises
        ------
        StopAsyncIteration
            When the wrapped iterable is exhausted (the bar is closed first).
        BaseException
            Any other exception is re-raised after closing the bar.
        """
        # NOTE(review): if no iterable was given, `iterable_next` is never set
        # and this raises AttributeError (after closing the bar).
        try:
            if self.iterable_awaitable:
                res = await self.iterable_next()
            else:
                res = self.iterable_next()
            self.update()
            return res
        except StopIteration:
            # A sync source is exhausted: translate to the async protocol.
            self.close()
            raise StopAsyncIteration
        except BaseException:
            # e.g. StopAsyncIteration from an async source, errors, cancellation
            self.close()
            raise

    def send(self, *args, **kwargs):
        """
        Forward to ``self.iterable.send`` (e.g. when wrapping a generator).

        This does not call `update`, so the bar is not advanced.
        """
        return self.iterable.send(*args, **kwargs)

    @classmethod
    def as_completed(cls, fs, *, loop=None, timeout=None, total=None, **tqdm_kwargs):
        """
        Wrapper for `asyncio.as_completed`.

        Yields the awaitables produced by `asyncio.as_completed` through a
        progress bar.

        Parameters
        ----------
        fs  : iterable of awaitables
            May be unsized (e.g. a generator). In that case it is materialized
            into a list to compute the default `total`.
        loop  : event loop, optional
            Only forwarded on Python < 3.10.
        timeout  : float, optional
            Forwarded to `asyncio.as_completed`.
        total  : int, optional
            Bar total. Defaults to the number of items in `fs`.
        tqdm_kwargs  : optional
            Forwarded to the bar constructor.
        """
        if total is None:
            try:
                total = len(fs)
            except TypeError:
                # Unsized iterable: materialize it so it can be both counted
                # and passed on intact to `asyncio.as_completed`.
                fs = list(fs)
                total = len(fs)
        kwargs = {}
        if version_info[:2] < (3, 10):
            kwargs['loop'] = loop
        yield from cls(asyncio.as_completed(fs, timeout=timeout, **kwargs),
                       total=total, **tqdm_kwargs)

    @classmethod
    async def gather(cls, *fs, loop=None, timeout=None, total=None, **tqdm_kwargs):
        """
        Wrapper for `asyncio.gather`.

        Awaits every awaitable in `fs` with a progress bar (via `as_completed`).
        Results are returned in the order the awaitables were passed, not the
        order in which they completed. An exception raised by any awaitable
        propagates out of this call.
        """
        async def wrap_awaitable(i, f):
            # Tag each result with its position so the input order can be restored.
            return i, await f

        ifs = [wrap_awaitable(i, f) for i, f in enumerate(fs)]
        res = [await f for f in cls.as_completed(ifs, loop=loop, timeout=timeout,
                                                 total=total, **tqdm_kwargs)]
        # Indices are unique, so the tuples are ordered by index alone and the
        # result values themselves are never compared.
        return [i for _, i in sorted(res)]
|
||
|
|
||
|
|
||
|
def tarange(*args, **kwargs):
    """
    A shortcut for `tqdm.asyncio.tqdm(range(*args), **kwargs)`.

    Positional arguments build the `range`; keyword arguments configure the
    `tqdm_asyncio` bar that wraps it.
    """
    rng = range(*args)
    return tqdm_asyncio(rng, **kwargs)
|
||
|
|
||
|
|
||
|
# Aliases: expose the conventional tqdm names from this module.
tqdm, trange = tqdm_asyncio, tarange
|