diff options
-rw-r--r-- | slack/commands.py | 2 | ||||
-rw-r--r-- | slack/init.py | 2 | ||||
-rw-r--r-- | slack/shared.py | 3 | ||||
-rw-r--r-- | slack/task.py | 23 | ||||
-rw-r--r-- | tests/test_task_runner.py | 21 |
5 files changed, 19 insertions, 32 deletions
diff --git a/slack/commands.py b/slack/commands.py index c9c62ee..735ba35 100644 --- a/slack/commands.py +++ b/slack/commands.py @@ -115,7 +115,7 @@ def command_slack_connect( for workspace in shared.workspaces.values(): await workspace.connect() - create_task(connect(), final=True) + create_task(connect()) @weechat_command() diff --git a/slack/init.py b/slack/init.py index 68da138..b126a24 100644 --- a/slack/init.py +++ b/slack/init.py @@ -88,4 +88,4 @@ def main(): "", ) - create_task(init(), final=True) + create_task(init()) diff --git a/slack/shared.py b/slack/shared.py index f972c2b..eab2c92 100644 --- a/slack/shared.py +++ b/slack/shared.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Dict, List if TYPE_CHECKING: from slack.api import SlackWorkspace @@ -18,7 +18,6 @@ class Shared: self.weechat_callbacks: Dict[str, Any] self.active_tasks: Dict[str, List[Task[Any]]] = defaultdict(list) self.active_futures: Dict[str, Future[Any]] = {} - self.active_responses: Dict[str, Tuple[Any, ...]] = {} self.workspaces: Dict[str, SlackWorkspace] = {} self.config: SlackConfig diff --git a/slack/task.py b/slack/task.py index 0612985..8f4e58d 100644 --- a/slack/task.py +++ b/slack/task.py @@ -43,10 +43,9 @@ class FutureTimer(Future[Tuple[int]]): class Task(Future[T]): - def __init__(self, coroutine: Coroutine[Future[Any], Any, T], final: bool): + def __init__(self, coroutine: Coroutine[Future[Any], Any, T]): super().__init__() self.coroutine = coroutine - self.final = final def weechat_task_cb(data: str, *args: Any) -> int: @@ -62,9 +61,7 @@ def task_runner(task: Task[Any], response: Any): while True: try: future = task.coroutine.send(response) - if future.id in shared.active_responses: - response = shared.active_responses.pop(future.id) - elif future.result is not None: + if future.result is not None: response = future.result else: shared.active_tasks[future.id].append(task) @@ -76,21 +73,13 @@ def task_runner(task: Task[Any], response: Any): tasks = shared.active_tasks.pop(task.id) for active_task in tasks: task_runner(active_task, e.value) - break - - if task.id in shared.active_responses: - raise Exception( # pylint: disable=raise-missing-from - f"task.id in active_responses, {task.id}, {shared.active_responses}" - ) - if not task.final: - shared.active_responses[task.id] = e.value + if task.id in shared.active_futures: + del shared.active_futures[task.id] break -def create_task( - coroutine: Coroutine[Future[Any], Any, T], final: bool = False -) -> Task[T]: - task = Task(coroutine, final) +def create_task(coroutine: Coroutine[Future[Any], Any, T]) -> Task[T]: + task = Task(coroutine) task_runner(task, None) return task diff --git a/tests/test_task_runner.py b/tests/test_task_runner.py index ca4f78b..095b013 100644 --- a/tests/test_task_runner.py +++ b/tests/test_task_runner.py @@ -6,7 +6,7 @@ from slack.task import Future, create_task, weechat_task_cb def test_run_single_task(): shared.active_tasks = defaultdict(list) - shared.active_responses = {} + shared.active_futures = {} future = Future[str]() async def awaitable(): @@ -17,12 +17,13 @@ def test_run_single_task(): weechat_task_cb(future.id, "data") assert not shared.active_tasks - assert shared.active_responses == {task.id: ("awaitable", ("data",))} + assert not shared.active_futures + assert task.result == ("awaitable", ("data",)) def test_run_nested_task(): shared.active_tasks = defaultdict(list) - shared.active_responses = {} + shared.active_futures = {} future = Future[str]() async def awaitable1(): @@ -37,14 +38,13 @@ def test_run_nested_task(): weechat_task_cb(future.id, "data") assert not shared.active_tasks - assert shared.active_responses == { - task.id: ("awaitable2", ("awaitable1", ("data",))) - } + assert not shared.active_futures + assert task.result == ("awaitable2", ("awaitable1", ("data",))) def test_run_two_tasks_concurrently(): shared.active_tasks = defaultdict(list) - shared.active_responses = {} + shared.active_futures = {} future1 = Future[str]() future2 = Future[str]() @@ -58,7 +58,6 @@ def test_run_two_tasks_concurrently(): weechat_task_cb(future2.id, "data2") assert not shared.active_tasks - assert shared.active_responses == { - task1.id: ("awaitable", ("data1",)), - task2.id: ("awaitable", ("data2",)), - } + assert not shared.active_futures + assert task1.result == ("awaitable", ("data1",)) + assert task2.result == ("awaitable", ("data2",)) |