aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTrygve Aaberge <trygveaa@gmail.com>2023-01-15 20:53:49 +0100
committerTrygve Aaberge <trygveaa@gmail.com>2024-02-18 11:32:53 +0100
commit432d7ffdaf701df556a6d4e45b8e0b75bd887597 (patch)
treeb723799d688fae2ba2f958b3fdbaece9db1864d5
parent7a57687d9e6724e78f45755b0206f739c504606d (diff)
downloadwee-slack-432d7ffdaf701df556a6d4e45b8e0b75bd887597.tar.gz
Fix async functions returning None and without await never finishing
-rw-r--r--slack/task.py26
-rw-r--r--tests/test_task_runner.py16
2 files changed, 36 insertions, 6 deletions
diff --git a/slack/task.py b/slack/task.py
index 8f4e58d..9ebcad5 100644
--- a/slack/task.py
+++ b/slack/task.py
@@ -27,11 +27,25 @@ class Future(Awaitable[T]):
self.id = str(uuid4())
else:
self.id = future_id
- self.result: Optional[T] = None
+ self._finished = False
+ self._result: Optional[T] = None
def __await__(self) -> Generator[Future[T], T, T]:
- self.result = yield self
- return self.result
+ result = yield self
+ self.set_result(result)
+ return result
+
+ @property
+ def finished(self):
+ return self._finished
+
+ @property
+ def result(self):
+ return self._result
+
+ def set_result(self, result: T):
+ self._result = result
+ self._finished = True
class FutureProcess(Future[Tuple[str, int, str, str]]):
@@ -50,7 +64,7 @@ class Task(Future[T]):
def weechat_task_cb(data: str, *args: Any) -> int:
future = shared.active_futures.pop(data)
- future.result = args
+ future.set_result(args)
tasks = shared.active_tasks.pop(data)
for task in tasks:
task_runner(task, args)
@@ -61,14 +75,14 @@ def task_runner(task: Task[Any], response: Any):
while True:
try:
future = task.coroutine.send(response)
- if future.result is not None:
+ if future.finished:
response = future.result
else:
shared.active_tasks[future.id].append(task)
shared.active_futures[future.id] = future
break
except StopIteration as e:
- task.result = e.value
+ task.set_result(e.value)
if task.id in shared.active_tasks:
tasks = shared.active_tasks.pop(task.id)
for active_task in tasks:
diff --git a/tests/test_task_runner.py b/tests/test_task_runner.py
index 095b013..33b74a6 100644
--- a/tests/test_task_runner.py
+++ b/tests/test_task_runner.py
@@ -61,3 +61,19 @@ def test_run_two_tasks_concurrently():
assert not shared.active_futures
assert task1.result == ("awaitable", ("data1",))
assert task2.result == ("awaitable", ("data2",))
+
+
+def test_task_without_await():
+ shared.active_tasks = defaultdict(list)
+ shared.active_futures = {}
+
+ async def fun_without_await():
+ pass
+
+ async def run():
+ await create_task(fun_without_await())
+
+ create_task(run())
+
+ assert not shared.active_tasks
+ assert not shared.active_futures