aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--slack/commands.py2
-rw-r--r--slack/init.py2
-rw-r--r--slack/shared.py3
-rw-r--r--slack/task.py23
-rw-r--r--tests/test_task_runner.py21
5 files changed, 19 insertions, 32 deletions
diff --git a/slack/commands.py b/slack/commands.py
index c9c62ee..735ba35 100644
--- a/slack/commands.py
+++ b/slack/commands.py
@@ -115,7 +115,7 @@ def command_slack_connect(
for workspace in shared.workspaces.values():
await workspace.connect()
- create_task(connect(), final=True)
+ create_task(connect())
@weechat_command()
diff --git a/slack/init.py b/slack/init.py
index 68da138..b126a24 100644
--- a/slack/init.py
+++ b/slack/init.py
@@ -88,4 +88,4 @@ def main():
"",
)
- create_task(init(), final=True)
+ create_task(init())
diff --git a/slack/shared.py b/slack/shared.py
index f972c2b..eab2c92 100644
--- a/slack/shared.py
+++ b/slack/shared.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from slack.api import SlackWorkspace
@@ -18,7 +18,6 @@ class Shared:
self.weechat_callbacks: Dict[str, Any]
self.active_tasks: Dict[str, List[Task[Any]]] = defaultdict(list)
self.active_futures: Dict[str, Future[Any]] = {}
- self.active_responses: Dict[str, Tuple[Any, ...]] = {}
self.workspaces: Dict[str, SlackWorkspace] = {}
self.config: SlackConfig
diff --git a/slack/task.py b/slack/task.py
index 0612985..8f4e58d 100644
--- a/slack/task.py
+++ b/slack/task.py
@@ -43,10 +43,9 @@ class FutureTimer(Future[Tuple[int]]):
class Task(Future[T]):
- def __init__(self, coroutine: Coroutine[Future[Any], Any, T], final: bool):
+ def __init__(self, coroutine: Coroutine[Future[Any], Any, T]):
super().__init__()
self.coroutine = coroutine
- self.final = final
def weechat_task_cb(data: str, *args: Any) -> int:
@@ -62,9 +61,7 @@ def task_runner(task: Task[Any], response: Any):
while True:
try:
future = task.coroutine.send(response)
- if future.id in shared.active_responses:
- response = shared.active_responses.pop(future.id)
- elif future.result is not None:
+ if future.result is not None:
response = future.result
else:
shared.active_tasks[future.id].append(task)
@@ -76,21 +73,13 @@ def task_runner(task: Task[Any], response: Any):
tasks = shared.active_tasks.pop(task.id)
for active_task in tasks:
task_runner(active_task, e.value)
- break
-
- if task.id in shared.active_responses:
- raise Exception( # pylint: disable=raise-missing-from
- f"task.id in active_responses, {task.id}, {shared.active_responses}"
- )
- if not task.final:
- shared.active_responses[task.id] = e.value
+ if task.id in shared.active_futures:
+ del shared.active_futures[task.id]
break
-def create_task(
- coroutine: Coroutine[Future[Any], Any, T], final: bool = False
-) -> Task[T]:
- task = Task(coroutine, final)
+def create_task(coroutine: Coroutine[Future[Any], Any, T]) -> Task[T]:
+ task = Task(coroutine)
task_runner(task, None)
return task
diff --git a/tests/test_task_runner.py b/tests/test_task_runner.py
index ca4f78b..095b013 100644
--- a/tests/test_task_runner.py
+++ b/tests/test_task_runner.py
@@ -6,7 +6,7 @@ from slack.task import Future, create_task, weechat_task_cb
def test_run_single_task():
shared.active_tasks = defaultdict(list)
- shared.active_responses = {}
+ shared.active_futures = {}
future = Future[str]()
async def awaitable():
@@ -17,12 +17,13 @@ def test_run_single_task():
weechat_task_cb(future.id, "data")
assert not shared.active_tasks
- assert shared.active_responses == {task.id: ("awaitable", ("data",))}
+ assert not shared.active_futures
+ assert task.result == ("awaitable", ("data",))
def test_run_nested_task():
shared.active_tasks = defaultdict(list)
- shared.active_responses = {}
+ shared.active_futures = {}
future = Future[str]()
async def awaitable1():
@@ -37,14 +38,13 @@ def test_run_nested_task():
weechat_task_cb(future.id, "data")
assert not shared.active_tasks
- assert shared.active_responses == {
- task.id: ("awaitable2", ("awaitable1", ("data",)))
- }
+ assert not shared.active_futures
+ assert task.result == ("awaitable2", ("awaitable1", ("data",)))
def test_run_two_tasks_concurrently():
shared.active_tasks = defaultdict(list)
- shared.active_responses = {}
+ shared.active_futures = {}
future1 = Future[str]()
future2 = Future[str]()
@@ -58,7 +58,6 @@ def test_run_two_tasks_concurrently():
weechat_task_cb(future2.id, "data2")
assert not shared.active_tasks
- assert shared.active_responses == {
- task1.id: ("awaitable", ("data1",)),
- task2.id: ("awaitable", ("data2",)),
- }
+ assert not shared.active_futures
+ assert task1.result == ("awaitable", ("data1",))
+ assert task2.result == ("awaitable", ("data2",))