You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
83 lines
2.6 KiB
Python
83 lines
2.6 KiB
Python
from collections import defaultdict
|
|
|
|
import attr
|
|
from async_generator import asynccontextmanager
|
|
|
|
from .. import _core
|
|
from .. import _util
|
|
from .. import Event
|
|
|
|
if False:
|
|
from typing import DefaultDict, Set
|
|
|
|
|
|
@attr.s(eq=False, hash=False)
class Sequencer(metaclass=_util.Final):
    """Helper that makes blocks of code living in different tasks execute in
    one fixed, explicitly numbered order.

    Calling an instance with a sequence number returns an async context
    manager. The number states where the guarded block belongs in the overall
    sequence: the block numbered 0 is entered right away, and the block
    numbered N is only entered once the block numbered N-1 has been exited.

    Every sequence number can be claimed a single time; claiming it again
    raises :exc:`RuntimeError`. If a task is cancelled while it is waiting
    for its turn, the sequence is marked as broken: all waiters are woken up
    and fail with :exc:`RuntimeError`, and so does every later attempt to use
    the sequencer.

    Example:
      An extremely elaborate way to print the numbers 0-5, in order::

         async def worker1(seq):
             async with seq(0):
                 print(0)
             async with seq(4):
                 print(4)

         async def worker2(seq):
             async with seq(2):
                 print(2)
             async with seq(5):
                 print(5)

         async def worker3(seq):
             async with seq(1):
                 print(1)
             async with seq(3):
                 print(3)

         async def main():
            seq = trio.testing.Sequencer()
            async with trio.open_nursery() as nursery:
                nursery.start_soon(worker1, seq)
                nursery.start_soon(worker2, seq)
                nursery.start_soon(worker3, seq)

    """

    # Maps a sequence number to the Event that gets set once the preceding
    # block has exited; entries are created lazily on first lookup.
    _sequence_points = attr.ib(
        factory=lambda: defaultdict(Event), init=False
    )  # type: DefaultDict[int, Event]
    # Sequence numbers that have already been handed out.
    _claimed = attr.ib(factory=set, init=False)  # type: Set[int]
    # Becomes True once a waiter has been cancelled; never reset.
    _broken = attr.ib(default=False, init=False)

    @asynccontextmanager
    async def __call__(self, position: int):
        """Return an async context manager guarding sequence slot ``position``.

        The body of the ``async with`` runs only after the block at
        ``position - 1`` has exited (slot 0 runs without waiting). When the
        body exits, normally or with an exception, the slot at
        ``position + 1`` is released.

        Raises:
          RuntimeError: if ``position`` was already claimed, if the sequence
              is broken, or if the wait for this slot is cancelled.
        """
        if position in self._claimed:
            message = "Attempted to re-use sequence point {}".format(position)
            raise RuntimeError(message)
        if self._broken:
            raise RuntimeError("sequence broken!")
        self._claimed.add(position)

        if position != 0:
            gate = self._sequence_points[position]
            try:
                await gate.wait()
            except _core.Cancelled:
                # A cancelled waiter can never release its successor, so give
                # up on the whole sequence and wake everybody who is waiting.
                self._broken = True
                for pending in self._sequence_points.values():
                    pending.set()
                raise RuntimeError("Sequencer wait cancelled -- sequence broken")
            # The except clause above always raises, so getting here means the
            # wait completed; it may have completed only because another
            # waiter broke the sequence and set every event.
            if self._broken:
                raise RuntimeError("sequence broken!")

        try:
            yield
        finally:
            # Release the next slot even if the guarded body raised.
            successor = self._sequence_points[position + 1]
            successor.set()
|