aboutsummaryrefslogtreecommitdiffstats
path: root/slack/task.py
diff options
context:
space:
mode:
authorTrygve Aaberge <trygveaa@gmail.com>2023-01-31 02:17:54 +0100
committerTrygve Aaberge <trygveaa@gmail.com>2024-02-18 11:32:53 +0100
commitfb7bc866464427f402d4855e200d84a32eba59e9 (patch)
tree196f037dc43cc52b301f8c6b9b6ecae7b1e2773d /slack/task.py
parent9a4daa3435254b38133fc9c6cb92c7fb181e7a0e (diff)
downloadwee-slack-fb7bc866464427f402d4855e200d84a32eba59e9.tar.gz
Make Future very similar to asyncio.Future
Diffstat (limited to 'slack/task.py')
-rw-r--r--slack/task.py118
1 files changed, 97 insertions, 21 deletions
diff --git a/slack/task.py b/slack/task.py
index 622e35f..74c5fc1 100644
--- a/slack/task.py
+++ b/slack/task.py
@@ -5,6 +5,7 @@ from typing import (
TYPE_CHECKING,
Any,
Awaitable,
+ Callable,
Coroutine,
Dict,
Generator,
@@ -26,41 +27,120 @@ from slack.shared import shared
from slack.util import get_callback_name
if TYPE_CHECKING:
- from typing_extensions import Literal
+ from typing_extensions import Literal, Self
T = TypeVar("T")
+class CancelledError(Exception):
+ pass
+
+
+class InvalidStateError(Exception):
+ pass
+
+
+# Heavily inspired by https://github.com/python/cpython/blob/3.11/Lib/asyncio/futures.py
class Future(Awaitable[T]):
def __init__(self, future_id: Optional[str] = None):
if future_id is None:
self.id = str(uuid4())
else:
self.id = future_id
- self._finished = False
+ self._state: Literal["PENDING", "CANCELLED", "FINISHED"] = "PENDING"
self._result: Optional[T] = None
+ self._exception: Optional[BaseException] = None
+ self._cancel_message = None
+ self._callbacks: List[Callable[[Self], object]] = []
def __repr__(self) -> str:
return f"{self.__class__.__name__}('{self.id}')"
def __await__(self) -> Generator[Future[T], T, T]:
+ if self.cancelled():
+ raise self._make_cancelled_error()
result = yield self
if isinstance(result, BaseException):
+ self.set_exception(result)
raise result
self.set_result(result)
return result
- @property
- def finished(self):
- return self._finished
+ def _make_cancelled_error(self):
+ if self._cancel_message is None:
+ return CancelledError()
+ else:
+ return CancelledError(self._cancel_message)
+
+ def __schedule_callbacks(self):
+ callbacks = self._callbacks[:]
+ if not callbacks:
+ return
+
+ self._callbacks[:] = []
+ for callback in callbacks:
+ callback(self)
- @property
def result(self):
+ exc = self.exception()
+ if exc is not None:
+ raise exc
return self._result
def set_result(self, result: T):
+ if self.done():
+ raise InvalidStateError(f"{self._state}: {self!r}")
self._result = result
- self._finished = True
+ self._state = "FINISHED"
+ self.__schedule_callbacks()
+
+ def set_exception(self, exception: BaseException):
+ if self.done():
+ raise InvalidStateError(f"{self._state}: {self!r}")
+ if isinstance(exception, type):
+ exception = exception()
+ if type(exception) is StopIteration:
+ raise TypeError(
+ "StopIteration interacts badly with generators "
+ "and cannot be raised into a Future"
+ )
+ self._exception = exception
+ self._state = "FINISHED"
+ self.__schedule_callbacks()
+
+ def done(self):
+ return self._state != "PENDING"
+
+ def cancelled(self):
+ return self._state == "CANCELLED"
+
+ def add_done_callback(self, callback: Callable[[Self], object]) -> None:
+ if self.done():
+ callback(self)
+ else:
+ self._callbacks.append(callback)
+
+ def remove_done_callback(self, callback: Callable[[Self], object]) -> int:
+ filtered_callbacks = [cb for cb in self._callbacks if cb != callback]
+ removed_count = len(self._callbacks) - len(filtered_callbacks)
+ if removed_count:
+ self._callbacks[:] = filtered_callbacks
+ return removed_count
+
+ def cancel(self, msg: Optional[str] = None):
+ if self._state != "PENDING":
+ return False
+ self._state = "CANCELLED"
+ self._cancel_message = msg
+ self.__schedule_callbacks()
+ return True
+
+ def exception(self):
+ if self.cancelled():
+ raise self._make_cancelled_error()
+ elif not self.done():
+ raise InvalidStateError("Exception is not set.")
+ return self._exception
class FutureProcess(Future[Tuple[str, int, str, str]]):
@@ -75,22 +155,15 @@ class Task(Future[T]):
def __init__(self, coroutine: Coroutine[Future[Any], Any, T]):
super().__init__()
self.coroutine = coroutine
- self._cancelled = False
def __repr__(self) -> str:
return f"{self.__class__.__name__}('{self.id}', coroutine={self.coroutine.__qualname__})"
- def __await__(self) -> Generator[Future[T], T, T]:
- if self.cancelled():
- raise RuntimeError("cannot await a cancelled task")
- return super().__await__()
-
- def cancel(self):
- self._cancelled = True
+ def cancel(self, msg: Optional[str] = None):
+ if not super().cancel(msg):
+ return False
self.coroutine.close()
-
- def cancelled(self):
- return self._cancelled
+ return True
def weechat_task_cb(data: str, *args: object) -> int:
@@ -103,7 +176,10 @@ def weechat_task_cb(data: str, *args: object) -> int:
def process_ended_task(task: Task[Any], response: object):
- task.set_result(response)
+ if isinstance(response, BaseException):
+ task.set_exception(response)
+ else:
+ task.set_result(response)
if task.id in shared.active_tasks:
tasks = shared.active_tasks.pop(task.id)
for active_task in tasks:
@@ -147,8 +223,8 @@ def task_runner(task: Task[Any], response: object):
return
- if future.finished:
- response = future.result
+ if future.done():
+ response = future.result()
else:
shared.active_tasks[future.id].append(task)
shared.active_futures[future.id] = future