aboutsummaryrefslogtreecommitdiffstats
path: root/slack/task.py
blob: 27abd489fdfd68319f92fcb7bda4c4c3b9e122a8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from __future__ import annotations

from typing import Any, Awaitable, Coroutine, Generator, List, Tuple, TypeVar
from uuid import uuid4

import weechat

from slack.shared import shared
from slack.util import get_callback_name

# Result type carried by a Future / Task (what ``await`` on it evaluates to).
T = TypeVar("T")


class Future(Awaitable[T]):
    """Placeholder for a value that becomes available later.

    Awaiting a Future suspends the awaiting coroutine by yielding the Future
    object itself to whatever is driving that coroutine; the value that driver
    later sends back in becomes the result of the ``await`` expression.
    """

    def __init__(self):
        # Random unique string id; task_runner uses it as the key into
        # shared.active_tasks / shared.active_responses.
        self.id = str(uuid4())

    def __await__(self) -> Generator[Future[T], T, T]:
        # Hand ourselves to the driver; whatever it sends back is the result.
        resolved = yield self
        return resolved


class FutureProcess(Future[Tuple[str, int, str, str]]):
    """Future whose result is typed as a ``(str, int, str, str)`` tuple.

    NOTE(review): presumably resolved by a WeeChat process hook (command,
    return code, stdout, stderr); that hook is not set up in this module,
    so confirm against the caller.
    """


class FutureTimer(Future[Tuple[int]]):
    """Future resolved by a WeeChat timer hook (as set up by ``sleep``).

    Its result is the tuple of arguments the hook callback receives after
    ``data``, typed here as a single int.
    """


class Task(Future[T]):
    """A Future whose result is the return value of ``coroutine``.

    ``final`` controls what happens when the coroutine finishes and no other
    task is waiting on it: a final task's return value is discarded, while a
    non-final one is kept in ``shared.active_responses`` for a later ``await``.
    """

    def __init__(self, coroutine: Coroutine[Future[Any], Any, T], final: bool):
        self.coroutine = coroutine
        self.final = final
        super().__init__()  # assigns the unique ``id``


def weechat_task_cb(data: str, *args: Any) -> int:
    """Generic WeeChat hook callback that resolves a pending future.

    ``data`` is the id of the future being resolved (``sleep`` passes
    ``future.id`` here). The task parked on that id is removed from
    ``shared.active_tasks`` and resumed with the remaining callback arguments,
    as a tuple, as the future's result.

    Returns:
        ``weechat.WEECHAT_RC_OK``.
    """
    task_runner(shared.active_tasks.pop(data), args)
    return weechat.WEECHAT_RC_OK


def task_runner(task: Task[Any], response: Any):
    """Drive ``task``, and any tasks waiting on it, as far as possible.

    Sends ``response`` into the task's coroutine and keeps stepping:

    * The coroutine yields a future whose result is already stored in
      ``shared.active_responses``: pop that result and send it straight back.
    * It yields a future that is still pending: park the task in
      ``shared.active_tasks`` under that future's id and return.
    * It finishes:
      - If another task is parked on this task's id, continue with that
        task, sending it the return value.
      - Otherwise store the value in ``shared.active_responses`` (skipped
        for ``final`` tasks) and return.

    Raises:
        Exception: if a second task tries to wait on an already-awaited
            future id, or if a finished task's id already has a stored
            response.

    Any other exception raised by a coroutine propagates to the caller
    unchanged (only ``StopIteration`` is intercepted).
    """
    while True:
        try:
            future = task.coroutine.send(response)
        except StopIteration as stop:
            if task.id in shared.active_tasks:
                # Another task awaits this one: resume it with our return value.
                task = shared.active_tasks.pop(task.id)
                response = stop.value
                continue
            if task.id in shared.active_responses:
                raise Exception(  # pylint: disable=raise-missing-from
                    f"task.id in active_responses, {task.id}, {shared.active_responses}"
                )
            if not task.final:
                # Nobody is waiting yet; keep the value for a later ``await``.
                shared.active_responses[task.id] = stop.value
            return

        if future.id in shared.active_responses:
            # The awaited future already resolved: feed its result back in.
            response = shared.active_responses.pop(future.id)
            continue
        if future.id in shared.active_tasks:
            raise Exception(
                f"future.id in active_tasks, {future.id}, {shared.active_tasks}"
            )
        # Park until whoever resolves ``future.id`` resumes this task.
        shared.active_tasks[future.id] = task
        return


def create_task(
    coroutine: Coroutine[Future[Any], Any, T], final: bool = False
) -> Task[T]:
    """Wrap ``coroutine`` in a Task and start running it immediately.

    The coroutine runs synchronously, inside this call, until it first waits
    on an unresolved future or finishes.

    Args:
        coroutine: the coroutine to drive.
        final: if true, the task's return value is discarded when no other
            task is waiting on it at completion time.

    Returns:
        The new Task, which can itself be awaited.
    """
    new_task = Task(coroutine, final=final)
    task_runner(new_task, response=None)
    return new_task


async def await_all_concurrent(requests: List[Coroutine[Any, Any, T]]) -> List[T]:
    """Run all ``requests`` concurrently and return their results in order.

    Every coroutine is started as a task before any of them is awaited, so
    they can all make progress while earlier ones are still pending.
    """
    started = list(map(create_task, requests))
    results: List[T] = []
    for pending in started:
        results.append(await pending)
    return results


async def sleep(milliseconds: int):
    """Suspend the calling task for roughly ``milliseconds`` ms.

    A one-shot WeeChat timer is hooked with ``weechat_task_cb`` as callback
    and the future's id as callback data. The task resumes when the timer
    fires.

    Args:
        milliseconds: delay before resuming. Non-positive values are
            treated as 1 ms.

    Returns:
        The tuple of arguments the timer callback received after ``data``.
    """
    future = FutureTimer()
    # WeeChat's hook_timer creates no hook for an interval <= 0. The future
    # would then never resolve and this task would stay parked in
    # shared.active_tasks forever, so clamp to the smallest valid delay.
    interval = max(milliseconds, 1)
    weechat.hook_timer(
        interval, 0, 1, get_callback_name(weechat_task_cb), future.id
    )
    return await future