diff options
author | Trygve Aaberge <trygveaa@gmail.com> | 2023-01-28 16:33:34 +0100 |
---|---|---|
committer | Trygve Aaberge <trygveaa@gmail.com> | 2024-02-18 11:32:53 +0100 |
commit | 26d3c52e775a806756bfbfc4d8b39537ed94f6a6 (patch) | |
tree | 6fa0a45548c3bc442faec9c018c5bf2117c7ce93 | |
parent | 73cf96863d6529e98cfa9541da4473029479eae0 (diff) | |
download | wee-slack-26d3c52e775a806756bfbfc4d8b39537ed94f6a6.tar.gz |
Support return_exceptions in gather
-rw-r--r-- | slack/task.py | 57 |
1 files changed, 51 insertions, 6 deletions
diff --git a/slack/task.py b/slack/task.py index c35e42f..facc0e2 100644 --- a/slack/task.py +++ b/slack/task.py @@ -4,12 +4,16 @@ from typing import ( Any, Awaitable, Coroutine, + Dict, Generator, List, + Literal, Optional, + Sequence, Tuple, TypeVar, Union, + overload, ) from uuid import uuid4 @@ -120,13 +124,54 @@ def create_task(coroutine: Coroutine[Future[Any], Any, T]) -> Task[T]: return task -async def gather(*requests: Union[Future[T], Coroutine[Any, Any, T]]) -> List[T]: +@overload +async def gather( + *requests: Union[Future[T], Coroutine[Any, Any, T]], + return_exceptions: Literal[False] = False, +) -> List[T]: + ... + + +@overload +async def gather( + *requests: Union[Future[T], Coroutine[Any, Any, T]], + return_exceptions: Literal[True], +) -> List[Union[T, BaseException]]: + ... + + +async def gather( + *requests: Union[Future[T], Coroutine[Any, Any, T]], return_exceptions: bool = False +) -> Sequence[Union[T, BaseException]]: # TODO: Should probably propagate first exception - tasks = [ - create_task(request) if isinstance(request, Coroutine) else request - for request in requests - ] - return [await task for task in tasks] + + tasks_map: Dict[int, Future[T]] = {} + results_map: Dict[int, Union[T, BaseException]] = {} + + for i, request in enumerate(requests): + if isinstance(request, Coroutine): + try: + tasks_map[i] = create_task(request) + except BaseException as e: + results_map[i] = e + else: + tasks_map[i] = request + + for i, task in tasks_map.items(): + try: + # print(f"waiting for {task}") + results_map[i] = await task + except BaseException as e: + results_map[i] = e + + results = [results_map[i] for i in sorted(results_map.keys())] + + if not return_exceptions: + for result in results: + if isinstance(result, BaseException): + raise result + + return results async def sleep(milliseconds: int): |