diff options
-rw-r--r-- | slack/task.py | 118 | ||||
-rw-r--r-- | tests/test_task_runner.py | 8 |
2 files changed, 101 insertions, 25 deletions
diff --git a/slack/task.py b/slack/task.py index 622e35f..74c5fc1 100644 --- a/slack/task.py +++ b/slack/task.py @@ -5,6 +5,7 @@ from typing import ( TYPE_CHECKING, Any, Awaitable, + Callable, Coroutine, Dict, Generator, @@ -26,41 +27,120 @@ from slack.shared import shared from slack.util import get_callback_name if TYPE_CHECKING: - from typing_extensions import Literal + from typing_extensions import Literal, Self T = TypeVar("T") +class CancelledError(Exception): + pass + + +class InvalidStateError(Exception): + pass + + +# Heavily inspired by https://github.com/python/cpython/blob/3.11/Lib/asyncio/futures.py class Future(Awaitable[T]): def __init__(self, future_id: Optional[str] = None): if future_id is None: self.id = str(uuid4()) else: self.id = future_id - self._finished = False + self._state: Literal["PENDING", "CANCELLED", "FINISHED"] = "PENDING" self._result: Optional[T] = None + self._exception: Optional[BaseException] = None + self._cancel_message = None + self._callbacks: List[Callable[[Self], object]] = [] def __repr__(self) -> str: return f"{self.__class__.__name__}('{self.id}')" def __await__(self) -> Generator[Future[T], T, T]: + if self.cancelled(): + raise self._make_cancelled_error() result = yield self if isinstance(result, BaseException): + self.set_exception(result) raise result self.set_result(result) return result - @property - def finished(self): - return self._finished + def _make_cancelled_error(self): + if self._cancel_message is None: + return CancelledError() + else: + return CancelledError(self._cancel_message) + + def __schedule_callbacks(self): + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + callback(self) - @property def result(self): + exc = self.exception() + if exc is not None: + raise exc return self._result def set_result(self, result: T): + if self.done(): + raise InvalidStateError(f"{self._state}: {self!r}") self._result = result - self._finished = True + self._state = "FINISHED" + self.__schedule_callbacks() + + def set_exception(self, exception: BaseException): + if self.done(): + raise InvalidStateError(f"{self._state}: {self!r}") + if isinstance(exception, type): + exception = exception() + if type(exception) is StopIteration: + raise TypeError( + "StopIteration interacts badly with generators " + "and cannot be raised into a Future" + ) + self._exception = exception + self._state = "FINISHED" + self.__schedule_callbacks() + + def done(self): + return self._state != "PENDING" + + def cancelled(self): + return self._state == "CANCELLED" + + def add_done_callback(self, callback: Callable[[Self], object]) -> None: + if self.done(): + callback(self) + else: + self._callbacks.append(callback) + + def remove_done_callback(self, callback: Callable[[Self], object]) -> int: + filtered_callbacks = [cb for cb in self._callbacks if cb != callback] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + def cancel(self, msg: Optional[str] = None): + if self._state != "PENDING": + return False + self._state = "CANCELLED" + self._cancel_message = msg + self.__schedule_callbacks() + return True + + def exception(self): + if self.cancelled(): + raise self._make_cancelled_error() + elif not self.done(): + raise InvalidStateError("Exception is not set.") + return self._exception class FutureProcess(Future[Tuple[str, int, str, str]]): @@ -75,22 +155,15 @@ class Task(Future[T]): def __init__(self, coroutine: Coroutine[Future[Any], Any, T]): super().__init__() self.coroutine = coroutine - self._cancelled = False def __repr__(self) -> str: return f"{self.__class__.__name__}('{self.id}', coroutine={self.coroutine.__qualname__})" - def __await__(self) -> Generator[Future[T], T, T]: - if self.cancelled(): - raise RuntimeError("cannot await a cancelled task") - return super().__await__() - - def cancel(self): - self._cancelled = True + def cancel(self, msg: Optional[str] = None): + if not super().cancel(msg): + return False self.coroutine.close() - - def cancelled(self): - return self._cancelled + return True def weechat_task_cb(data: str, *args: object) -> int: @@ -103,7 +176,10 @@ def weechat_task_cb(data: str, *args: object) -> int: def process_ended_task(task: Task[Any], response: object): - task.set_result(response) + if isinstance(response, BaseException): + task.set_exception(response) + else: + task.set_result(response) if task.id in shared.active_tasks: tasks = shared.active_tasks.pop(task.id) for active_task in tasks: @@ -147,8 +223,8 @@ def task_runner(task: Task[Any], response: object): return - if future.finished: - response = future.result + if future.done(): + response = future.result() else: shared.active_tasks[future.id].append(task) shared.active_futures[future.id] = future diff --git a/tests/test_task_runner.py b/tests/test_task_runner.py index 33b74a6..2aaa9f9 100644 --- a/tests/test_task_runner.py +++ b/tests/test_task_runner.py @@ -18,7 +18,7 @@ def test_run_single_task(): assert not shared.active_tasks assert not shared.active_futures - assert task.result == ("awaitable", ("data",)) + assert task.result() == ("awaitable", ("data",)) def test_run_nested_task(): @@ -39,7 +39,7 @@ def test_run_nested_task(): assert not shared.active_tasks assert not shared.active_futures - assert task.result == ("awaitable2", ("awaitable1", ("data",))) + assert task.result() == ("awaitable2", ("awaitable1", ("data",))) def test_run_two_tasks_concurrently(): @@ -59,8 +59,8 @@ def test_run_two_tasks_concurrently(): assert not shared.active_tasks assert not shared.active_futures - assert task1.result == ("awaitable", ("data1",)) - assert task2.result == ("awaitable", ("data2",)) + assert task1.result() == ("awaitable", ("data1",)) + assert task2.result() == ("awaitable", ("data2",)) def test_task_without_await(): |