from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Coroutine,
    Dict,
    Generator,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    overload,
)
from uuid import uuid4

import weechat

from slack.error import format_exception
from slack.log import print_error
from slack.shared import shared
from slack.util import get_callback_name

if TYPE_CHECKING:
    from typing_extensions import Literal, Self

T = TypeVar("T")


class CancelledError(Exception):
    """Raised when the outcome of a cancelled Future is requested."""


class InvalidStateError(Exception):
    """Raised when a Future is used in a way its current state does not allow."""


# Heavily inspired by https://github.com/python/cpython/blob/3.11/Lib/asyncio/futures.py
class Future(Awaitable[T]):
    """A minimal future whose outcome is set externally (e.g. by a WeeChat hook).

    Awaiting a pending future yields the future itself to the driving Task
    (see ``task_runner``), which resumes the coroutine once the future is done.
    """

    def __init__(self, future_id: Optional[str] = None):
        # A random UUID is used unless the caller supplies an id. The id is
        # the key under which the future is tracked in ``shared`` and the
        # data string passed to WeeChat hooks (see ``sleep``).
        self.id = future_id or str(uuid4())
        self._state: Literal["PENDING", "CANCELLED", "FINISHED"] = "PENDING"
        # Only assigned by set_result(); declared here for type checkers.
        self._result: T
        self._exception: Optional[BaseException] = None
        self._cancel_message: Optional[str] = None
        self._callbacks: List[Callable[[Self], object]] = []

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}('{self.id}')"

    def __await__(self) -> Generator[Future[T], None, T]:
        if not self.done():
            yield self  # This tells Task to wait for completion.
        if not self.done():
            raise RuntimeError("await wasn't used with future")
        # Whatever value task_runner sends back in is ignored; the outcome is
        # read from the future itself, so this raises if an exception or a
        # cancellation was recorded.
        return self.result()  # May raise too.

    def _make_cancelled_error(self):
        """Build the CancelledError to raise, carrying the cancel message if any."""
        if self._cancel_message is None:
            return CancelledError()
        else:
            return CancelledError(self._cancel_message)

    def __schedule_callbacks(self):
        """Invoke and clear all registered done callbacks, synchronously."""
        callbacks = self._callbacks[:]
        if not callbacks:
            return

        # Clear before calling so a callback cannot be invoked twice.
        self._callbacks[:] = []
        for callback in callbacks:
            callback(self)

    def result(self):
        """Return the result, or raise the stored exception.

        Raises CancelledError if cancelled and InvalidStateError if still
        pending (both via ``exception()``).
        """
        exc = self.exception()
        if exc is not None:
            raise exc
        return self._result

    def set_result(self, result: T):
        """Mark the future finished with ``result`` and run done callbacks.

        Raises InvalidStateError if the future is already done or cancelled.
        """
        if self.done():
            raise InvalidStateError(f"{self._state}: {self!r}")
        self._result = result
        self._state = "FINISHED"
        self.__schedule_callbacks()

    def set_exception(self, exception: BaseException):
        """Mark the future finished with ``exception`` and run done callbacks.

        An exception class is instantiated with no arguments. StopIteration
        is rejected with TypeError, mirroring asyncio.
        """
        if self.done():
            raise InvalidStateError(f"{self._state}: {self!r}")
        if isinstance(exception, type):
            exception = exception()
        if type(exception) is StopIteration:
            raise TypeError(
                "StopIteration interacts badly with generators "
                "and cannot be raised into a Future"
            )
        self._exception = exception
        self._state = "FINISHED"
        self.__schedule_callbacks()

    def done(self):
        """Return True if the future is finished or cancelled."""
        return self._state != "PENDING"

    def cancelled(self):
        """Return True if the future was cancelled."""
        return self._state == "CANCELLED"

    def add_done_callback(self, callback: Callable[[Self], object]) -> None:
        """Register ``callback``; it is called immediately if already done."""
        if self.done():
            callback(self)
        else:
            self._callbacks.append(callback)

    def remove_done_callback(self, callback: Callable[[Self], object]) -> int:
        """Remove every registration of ``callback``; return how many were removed."""
        filtered_callbacks = [cb for cb in self._callbacks if cb != callback]
        removed_count = len(self._callbacks) - len(filtered_callbacks)
        if removed_count:
            self._callbacks[:] = filtered_callbacks
        return removed_count

    def cancel(self, msg: Optional[str] = None):
        """Cancel a pending future and run done callbacks.

        Returns False (and does nothing) if the future is already done.
        """
        if self._state != "PENDING":
            return False
        self._state = "CANCELLED"
        self._cancel_message = msg
        self.__schedule_callbacks()
        return True

    def exception(self):
        """Return the stored exception, or None if finished with a result.

        Raises CancelledError if cancelled and InvalidStateError if pending.
        """
        if self.cancelled():
            raise self._make_cancelled_error()
        elif not self.done():
            raise InvalidStateError("Exception is not set.")
        return self._exception


class FutureProcess(Future[Tuple[str, int, str, str]]):
    """Future whose result is a 4-tuple of hook callback arguments."""


class FutureTimer(Future[Tuple[int]]):
    """Future completed by a WeeChat timer hook (see ``sleep``)."""


class Task(Future[T]):
    """A Future that drives ``coroutine`` to completion via ``task_runner``."""

    def __init__(self, coroutine: Coroutine[Future[Any], Any, T]):
        super().__init__()
        self.coroutine = coroutine

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}('{self.id}', coroutine={self.coroutine.__qualname__})"

    def cancel(self, msg: Optional[str] = None):
        """Cancel the task and close its coroutine.

        Returns False if the task was already done, True otherwise. Done
        callbacks run (inside ``super().cancel``) before the coroutine is
        closed.
        """
        if not super().cancel(msg):
            return False
        # NOTE(review): tasks awaiting this task (registered under self.id in
        # shared.active_tasks) are not resumed here, so they will wait forever.
        self.coroutine.close()
        return True


def weechat_task_cb(data: str, *args: object) -> int:
    """WeeChat hook callback: complete the future whose id is ``data``.

    The hook arguments tuple becomes the future's result, and every task
    waiting on that future is resumed.
    """
    future = shared.active_futures.pop(data)
    # The future may have been cancelled while the hook was pending. Calling
    # set_result() would then raise InvalidStateError and leave the waiting
    # tasks stuck. They are still resumed so that Future.__await__ raises the
    # CancelledError inside them.
    if not future.done():
        future.set_result(args)
    tasks = shared.active_tasks.pop(data)
    for task in tasks:
        task_runner(task, args)
    return weechat.WEECHAT_RC_OK


def process_ended_task(task: Task[Any], response: object):
    """Record a finished coroutine's outcome on ``task`` and wake its waiters.

    ``response`` is either the coroutine's return value or the exception it
    raised.
    """
    if isinstance(response, BaseException):
        task.set_exception(response)
    else:
        task.set_result(response)

    # Resume tasks that were awaiting this task. The value sent is ignored by
    # Future.__await__, which re-reads the task's outcome.
    if task.id in shared.active_tasks:
        tasks = shared.active_tasks.pop(task.id)
        for active_task in tasks:
            task_runner(active_task, response)
    if task.id in shared.active_futures:
        del shared.active_futures[task.id]


def task_runner(task: Task[Any], response: object):
    """Advance ``task``'s coroutine until it blocks on a pending future or ends."""
    while True:
        if task.cancelled():
            return
        try:
            future = task.coroutine.send(response)
        except BaseException as e:
            # StopIteration carries the coroutine's return value; any other
            # exception becomes the task's exception.
            result = e.value if isinstance(e, StopIteration) else e
            process_ended_task(task, result)
            return

        if future.done():
            response = future.result()
        else:
            # Park the task until the future completes. Presumably
            # shared.active_tasks is a defaultdict(list) — verify in
            # slack.shared.
            shared.active_tasks[future.id].append(task)
            shared.active_futures[future.id] = future
            break


def create_task(coroutine: Coroutine[Future[Any], Any, T]) -> Task[T]:
    """Wrap ``coroutine`` in a Task, start running it immediately, and return it."""
    task = Task(coroutine)
    task_runner(task, None)
    return task


def _async_task_done(task: Task[object]):
    """Done callback for ``run_async``: report a failed task via print_error.

    Cancelled tasks are skipped. Calling ``exception()`` on them raises
    CancelledError, which would escape from Task.cancel() before the
    coroutine gets closed.
    """
    if task.cancelled():
        return
    exception = task.exception()
    if exception:
        print_error(f"{task} failed with: {format_exception(exception)}")


def run_async(coroutine: Coroutine[Future[Any], Any, Any]) -> None:
    """Run ``coroutine`` fire-and-forget, printing any exception it raises."""
    task = Task(coroutine)
    task.add_done_callback(_async_task_done)
    task_runner(task, None)


@overload
async def gather(
    *requests: Union[Future[T], Coroutine[Any, Any, T]],
    return_exceptions: Literal[False] = False,
) -> List[T]:
    ...


@overload
async def gather(
    *requests: Union[Future[T], Coroutine[Any, Any, T]],
    return_exceptions: Literal[True],
) -> List[Union[T, BaseException]]:
    ...


async def gather(
    *requests: Union[Future[T], Coroutine[Any, Any, T]], return_exceptions: bool = False
) -> Sequence[Union[T, BaseException]]:
    """Await all ``requests`` and return their outcomes in argument order.

    Coroutines are wrapped in Tasks, which start running immediately. Every
    request is awaited, even after a failure. If ``return_exceptions`` is
    False, the first exception (in argument order) is then raised; otherwise
    exceptions are returned in place of results.
    """
    # TODO: Should probably propagate first exception
    tasks_map: Dict[int, Future[T]] = {}
    results_map: Dict[int, Union[T, BaseException]] = {}

    for i, request in enumerate(requests):
        if isinstance(request, Coroutine):
            try:
                tasks_map[i] = create_task(request)
            except BaseException as e:
                results_map[i] = e
        else:
            tasks_map[i] = request

    for i, task in tasks_map.items():
        try:
            results_map[i] = await task
        except BaseException as e:
            results_map[i] = e

    results = [results_map[i] for i in sorted(results_map.keys())]

    if not return_exceptions:
        for result in results:
            if isinstance(result, BaseException):
                raise result

    return results


async def sleep(milliseconds: int):
    """Suspend the current task for ``milliseconds`` using a WeeChat timer.

    Non-positive durations are clamped to 1 ms. The timer is hooked with
    align_second=0 and max_calls=1. Returns the timer callback's argument
    tuple.
    """
    future = FutureTimer()
    sleep_ms = milliseconds if milliseconds > 0 else 1
    weechat.hook_timer(sleep_ms, 0, 1, get_callback_name(weechat_task_cb), future.id)
    return await future