diff options
-rw-r--r-- | slack/task.py | 26 | ||||
-rw-r--r-- | tests/test_task_runner.py | 16 |
2 files changed, 36 insertions, 6 deletions
diff --git a/slack/task.py b/slack/task.py index 8f4e58d..9ebcad5 100644 --- a/slack/task.py +++ b/slack/task.py @@ -27,11 +27,25 @@ class Future(Awaitable[T]): self.id = str(uuid4()) else: self.id = future_id - self.result: Optional[T] = None + self._finished = False + self._result: Optional[T] = None def __await__(self) -> Generator[Future[T], T, T]: - self.result = yield self - return self.result + result = yield self + self.set_result(result) + return result + + @property + def finished(self): + return self._finished + + @property + def result(self): + return self._result + + def set_result(self, result: T): + self._result = result + self._finished = True class FutureProcess(Future[Tuple[str, int, str, str]]): @@ -50,7 +64,7 @@ class Task(Future[T]): def weechat_task_cb(data: str, *args: Any) -> int: future = shared.active_futures.pop(data) - future.result = args + future.set_result(args) tasks = shared.active_tasks.pop(data) for task in tasks: task_runner(task, args) @@ -61,14 +75,14 @@ def task_runner(task: Task[Any], response: Any): while True: try: future = task.coroutine.send(response) - if future.result is not None: + if future.finished: response = future.result else: shared.active_tasks[future.id].append(task) shared.active_futures[future.id] = future break except StopIteration as e: - task.result = e.value + task.set_result(e.value) if task.id in shared.active_tasks: tasks = shared.active_tasks.pop(task.id) for active_task in tasks: diff --git a/tests/test_task_runner.py b/tests/test_task_runner.py index 095b013..33b74a6 100644 --- a/tests/test_task_runner.py +++ b/tests/test_task_runner.py @@ -61,3 +61,19 @@ def test_run_two_tasks_concurrently(): assert not shared.active_futures assert task1.result == ("awaitable", ("data1",)) assert task2.result == ("awaitable", ("data2",)) + + +def test_task_without_await(): + shared.active_tasks = defaultdict(list) + shared.active_futures = {} + + async def fun_without_await(): + pass + + async def run(): + await create_task(fun_without_await()) + + create_task(run()) + + assert not shared.active_tasks + assert not shared.active_futures |