1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import asyncio
import contextlib
 
# PY3.9: Import Callable from typing until we drop Python 3.9 support
# https://github.com/python/cpython/issues/87131
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
)
 
_T = TypeVar("_T")
 
RE_RAISE_EXCEPTIONS = (SystemExit, KeyboardInterrupt)
 
 
def _set_result(wait_next: "asyncio.Future[None]") -> None:
    """Set the result of a future if it is not already done."""
    if not wait_next.done():
        wait_next.set_result(None)
 
 
async def _wait_one(
    futures: "Iterable[asyncio.Future[Any]]",
    loop: asyncio.AbstractEventLoop,
) -> _T:
    """Wait for the first future to complete."""
    wait_next = loop.create_future()
 
    def _on_completion(fut: "asyncio.Future[Any]") -> None:
        if not wait_next.done():
            wait_next.set_result(fut)
 
    for f in futures:
        f.add_done_callback(_on_completion)
 
    try:
        return await wait_next
    finally:
        for f in futures:
            f.remove_done_callback(_on_completion)
 
 
async def staggered_race(
    coro_fns: Iterable[Callable[[], Awaitable[_T]]],
    delay: Optional[float],
    *,
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]:
    """
    Run coroutines with staggered start times and take the first to finish.
 
    This method takes an iterable of coroutine functions. The first one is
    started immediately. From then on, whenever the immediately preceding one
    fails (raises an exception), or when *delay* seconds has passed, the next
    coroutine is started. This continues until one of the coroutines complete
    successfully, in which case all others are cancelled, or until all
    coroutines fail.
 
    The coroutines provided should be well-behaved in the following way:
 
    * They should only ``return`` if completed successfully.
 
    * They should always raise an exception if they did not complete
      successfully. In particular, if they handle cancellation, they should
      probably reraise, like this::
 
        try:
            # do work
        except asyncio.CancelledError:
            # undo partially completed work
            raise
 
    Args:
    ----
        coro_fns: an iterable of coroutine functions, i.e. callables that
            return a coroutine object when called. Use ``functools.partial`` or
            lambdas to pass arguments.
 
        delay: amount of time, in seconds, between starting coroutines. If
            ``None``, the coroutines will run sequentially.
 
        loop: the event loop to use. If ``None``, the running loop is used.
 
    Returns:
    -------
        tuple *(winner_result, winner_index, exceptions)* where
 
        - *winner_result*: the result of the winning coroutine, or ``None``
          if no coroutines won.
 
        - *winner_index*: the index of the winning coroutine in
          ``coro_fns``, or ``None`` if no coroutines won. If the winning
          coroutine may return None on success, *winner_index* can be used
          to definitively determine whether any coroutine won.
 
        - *exceptions*: list of exceptions returned by the coroutines.
          ``len(exceptions)`` is equal to the number of coroutines actually
          started, and the order is the same as in ``coro_fns``. The winning
          coroutine's entry is ``None``.
 
    """
    loop = loop or asyncio.get_running_loop()
    exceptions: List[Optional[BaseException]] = []
    tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set()
 
    async def run_one_coro(
        coro_fn: Callable[[], Awaitable[_T]],
        this_index: int,
        start_next: "asyncio.Future[None]",
    ) -> Optional[Tuple[_T, int]]:
        """
        Run a single coroutine.
 
        If the coroutine fails, set the exception in the exceptions list and
        start the next coroutine by setting the result of the start_next.
 
        If the coroutine succeeds, return the result and the index of the
        coroutine in the coro_fns list.
 
        If SystemExit or KeyboardInterrupt is raised, re-raise it.
        """
        try:
            result = await coro_fn()
        except RE_RAISE_EXCEPTIONS:
            raise
        except BaseException as e:
            exceptions[this_index] = e
            _set_result(start_next)  # Kickstart the next coroutine
            return None
 
        return result, this_index
 
    start_next_timer: Optional[asyncio.TimerHandle] = None
    start_next: Optional[asyncio.Future[None]]
    task: asyncio.Task[Optional[Tuple[_T, int]]]
    done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]]
    coro_iter = iter(coro_fns)
    this_index = -1
    try:
        while True:
            if coro_fn := next(coro_iter, None):
                this_index += 1
                exceptions.append(None)
                start_next = loop.create_future()
                task = loop.create_task(run_one_coro(coro_fn, this_index, start_next))
                tasks.add(task)
                start_next_timer = (
                    loop.call_later(delay, _set_result, start_next) if delay else None
                )
            elif not tasks:
                # We exhausted the coro_fns list and no tasks are running
                # so we have no winner and all coroutines failed.
                break
 
            while tasks or start_next:
                done = await _wait_one(
                    (*tasks, start_next) if start_next else tasks, loop
                )
                if done is start_next:
                    # The current task has failed or the timer has expired
                    # so we need to start the next task.
                    start_next = None
                    if start_next_timer:
                        start_next_timer.cancel()
                        start_next_timer = None
 
                    # Break out of the task waiting loop to start the next
                    # task.
                    break
 
                if TYPE_CHECKING:
                    assert isinstance(done, asyncio.Task)
 
                tasks.remove(done)
                if winner := done.result():
                    return *winner, exceptions
    finally:
        # We either have:
        #  - a winner
        #  - all tasks failed
        #  - a KeyboardInterrupt or SystemExit.
 
        #
        # If the timer is still running, cancel it.
        #
        if start_next_timer:
            start_next_timer.cancel()
 
        #
        # If there are any tasks left, cancel them and than
        # wait them so they fill the exceptions list.
        #
        for task in tasks:
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task
 
    return None, None, exceptions