aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTrygve Aaberge <trygveaa@gmail.com>2023-01-28 16:33:34 +0100
committerTrygve Aaberge <trygveaa@gmail.com>2024-02-18 11:32:53 +0100
commit26d3c52e775a806756bfbfc4d8b39537ed94f6a6 (patch)
tree6fa0a45548c3bc442faec9c018c5bf2117c7ce93
parent73cf96863d6529e98cfa9541da4473029479eae0 (diff)
downloadwee-slack-26d3c52e775a806756bfbfc4d8b39537ed94f6a6.tar.gz
Support return_exceptions in gather
-rw-r--r--slack/task.py57
1 files changed, 51 insertions, 6 deletions
diff --git a/slack/task.py b/slack/task.py
index c35e42f..facc0e2 100644
--- a/slack/task.py
+++ b/slack/task.py
@@ -4,12 +4,16 @@ from typing import (
Any,
Awaitable,
Coroutine,
+ Dict,
Generator,
List,
+ Literal,
Optional,
+ Sequence,
Tuple,
TypeVar,
Union,
+ overload,
)
from uuid import uuid4
@@ -120,13 +124,54 @@ def create_task(coroutine: Coroutine[Future[Any], Any, T]) -> Task[T]:
return task
-async def gather(*requests: Union[Future[T], Coroutine[Any, Any, T]]) -> List[T]:
+@overload
+async def gather(
+ *requests: Union[Future[T], Coroutine[Any, Any, T]],
+ return_exceptions: Literal[False] = False,
+) -> List[T]:
+ ...
+
+
+@overload
+async def gather(
+ *requests: Union[Future[T], Coroutine[Any, Any, T]],
+ return_exceptions: Literal[True],
+) -> List[Union[T, BaseException]]:
+ ...
+
+
+async def gather(
+ *requests: Union[Future[T], Coroutine[Any, Any, T]], return_exceptions: bool = False
+) -> Sequence[Union[T, BaseException]]:
# TODO: Should probably propagate first exception
- tasks = [
- create_task(request) if isinstance(request, Coroutine) else request
- for request in requests
- ]
- return [await task for task in tasks]
+
+ tasks_map: Dict[int, Future[T]] = {}
+ results_map: Dict[int, Union[T, BaseException]] = {}
+
+ for i, request in enumerate(requests):
+ if isinstance(request, Coroutine):
+ try:
+ tasks_map[i] = create_task(request)
+ except BaseException as e:
+ results_map[i] = e
+ else:
+ tasks_map[i] = request
+
+ for i, task in tasks_map.items():
+ try:
+ # print(f"waiting for {task}")
+ results_map[i] = await task
+ except BaseException as e:
+ results_map[i] = e
+
+ results = [results_map[i] for i in sorted(results_map.keys())]
+
+ if not return_exceptions:
+ for result in results:
+ if isinstance(result, BaseException):
+ raise result
+
+ return results
async def sleep(milliseconds: int):