import asyncio
|
import contextlib
|
|
# PY3.9: Import Callable from typing until we drop Python 3.9 support
|
# https://github.com/python/cpython/issues/87131
|
from typing import (
|
TYPE_CHECKING,
|
Any,
|
Awaitable,
|
Callable,
|
Iterable,
|
List,
|
Optional,
|
Set,
|
Tuple,
|
TypeVar,
|
Union,
|
)
|
|
_T = TypeVar("_T")
|
|
RE_RAISE_EXCEPTIONS = (SystemExit, KeyboardInterrupt)
|
|
|
def _set_result(wait_next: "asyncio.Future[None]") -> None:
|
"""Set the result of a future if it is not already done."""
|
if not wait_next.done():
|
wait_next.set_result(None)
|
|
|
async def _wait_one(
|
futures: "Iterable[asyncio.Future[Any]]",
|
loop: asyncio.AbstractEventLoop,
|
) -> _T:
|
"""Wait for the first future to complete."""
|
wait_next = loop.create_future()
|
|
def _on_completion(fut: "asyncio.Future[Any]") -> None:
|
if not wait_next.done():
|
wait_next.set_result(fut)
|
|
for f in futures:
|
f.add_done_callback(_on_completion)
|
|
try:
|
return await wait_next
|
finally:
|
for f in futures:
|
f.remove_done_callback(_on_completion)
|
|
|
async def staggered_race(
|
coro_fns: Iterable[Callable[[], Awaitable[_T]]],
|
delay: Optional[float],
|
*,
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]:
|
"""
|
Run coroutines with staggered start times and take the first to finish.
|
|
This method takes an iterable of coroutine functions. The first one is
|
started immediately. From then on, whenever the immediately preceding one
|
fails (raises an exception), or when *delay* seconds has passed, the next
|
coroutine is started. This continues until one of the coroutines complete
|
successfully, in which case all others are cancelled, or until all
|
coroutines fail.
|
|
The coroutines provided should be well-behaved in the following way:
|
|
* They should only ``return`` if completed successfully.
|
|
* They should always raise an exception if they did not complete
|
successfully. In particular, if they handle cancellation, they should
|
probably reraise, like this::
|
|
try:
|
# do work
|
except asyncio.CancelledError:
|
# undo partially completed work
|
raise
|
|
Args:
|
----
|
coro_fns: an iterable of coroutine functions, i.e. callables that
|
return a coroutine object when called. Use ``functools.partial`` or
|
lambdas to pass arguments.
|
|
delay: amount of time, in seconds, between starting coroutines. If
|
``None``, the coroutines will run sequentially.
|
|
loop: the event loop to use. If ``None``, the running loop is used.
|
|
Returns:
|
-------
|
tuple *(winner_result, winner_index, exceptions)* where
|
|
- *winner_result*: the result of the winning coroutine, or ``None``
|
if no coroutines won.
|
|
- *winner_index*: the index of the winning coroutine in
|
``coro_fns``, or ``None`` if no coroutines won. If the winning
|
coroutine may return None on success, *winner_index* can be used
|
to definitively determine whether any coroutine won.
|
|
- *exceptions*: list of exceptions returned by the coroutines.
|
``len(exceptions)`` is equal to the number of coroutines actually
|
started, and the order is the same as in ``coro_fns``. The winning
|
coroutine's entry is ``None``.
|
|
"""
|
loop = loop or asyncio.get_running_loop()
|
exceptions: List[Optional[BaseException]] = []
|
tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set()
|
|
async def run_one_coro(
|
coro_fn: Callable[[], Awaitable[_T]],
|
this_index: int,
|
start_next: "asyncio.Future[None]",
|
) -> Optional[Tuple[_T, int]]:
|
"""
|
Run a single coroutine.
|
|
If the coroutine fails, set the exception in the exceptions list and
|
start the next coroutine by setting the result of the start_next.
|
|
If the coroutine succeeds, return the result and the index of the
|
coroutine in the coro_fns list.
|
|
If SystemExit or KeyboardInterrupt is raised, re-raise it.
|
"""
|
try:
|
result = await coro_fn()
|
except RE_RAISE_EXCEPTIONS:
|
raise
|
except BaseException as e:
|
exceptions[this_index] = e
|
_set_result(start_next) # Kickstart the next coroutine
|
return None
|
|
return result, this_index
|
|
start_next_timer: Optional[asyncio.TimerHandle] = None
|
start_next: Optional[asyncio.Future[None]]
|
task: asyncio.Task[Optional[Tuple[_T, int]]]
|
done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]]
|
coro_iter = iter(coro_fns)
|
this_index = -1
|
try:
|
while True:
|
if coro_fn := next(coro_iter, None):
|
this_index += 1
|
exceptions.append(None)
|
start_next = loop.create_future()
|
task = loop.create_task(run_one_coro(coro_fn, this_index, start_next))
|
tasks.add(task)
|
start_next_timer = (
|
loop.call_later(delay, _set_result, start_next) if delay else None
|
)
|
elif not tasks:
|
# We exhausted the coro_fns list and no tasks are running
|
# so we have no winner and all coroutines failed.
|
break
|
|
while tasks or start_next:
|
done = await _wait_one(
|
(*tasks, start_next) if start_next else tasks, loop
|
)
|
if done is start_next:
|
# The current task has failed or the timer has expired
|
# so we need to start the next task.
|
start_next = None
|
if start_next_timer:
|
start_next_timer.cancel()
|
start_next_timer = None
|
|
# Break out of the task waiting loop to start the next
|
# task.
|
break
|
|
if TYPE_CHECKING:
|
assert isinstance(done, asyncio.Task)
|
|
tasks.remove(done)
|
if winner := done.result():
|
return *winner, exceptions
|
finally:
|
# We either have:
|
# - a winner
|
# - all tasks failed
|
# - a KeyboardInterrupt or SystemExit.
|
|
#
|
# If the timer is still running, cancel it.
|
#
|
if start_next_timer:
|
start_next_timer.cancel()
|
|
#
|
# If there are any tasks left, cancel them and than
|
# wait them so they fill the exceptions list.
|
#
|
for task in tasks:
|
task.cancel()
|
with contextlib.suppress(asyncio.CancelledError):
|
await task
|
|
return None, None, exceptions
|