aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--slack/api.py26
-rw-r--r--slack/http.py5
-rw-r--r--slack/shared.py8
-rw-r--r--slack/task.py67
-rw-r--r--tests/test_task_runner.py8
5 files changed, 74 insertions, 40 deletions
diff --git a/slack/api.py b/slack/api.py
index 66cd687..4fdef4d 100644
--- a/slack/api.py
+++ b/slack/api.py
@@ -4,14 +4,14 @@ import json
import re
import time
from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from urllib.parse import urlencode
import weechat
from slack.http import http_request
from slack.shared import shared
-from slack.task import create_task, gather
+from slack.task import Future, create_task, gather
from slack.util import get_callback_name
if TYPE_CHECKING:
@@ -76,7 +76,7 @@ class SlackWorkspace:
self.is_connected = False
self.nick = "TODO"
# Maybe make private, so you have to use get_user? Maybe make get_user a getter, though don't know if that's a problem since it's async
- self.users: Dict[str, SlackUser] = {}
+ self.users: Dict[str, Future[SlackUser]] = {}
self.conversations: Dict[str, SlackConversation] = {}
async def connect(self):
@@ -104,14 +104,17 @@ class SlackWorkspace:
self.is_connected = True
weechat.bar_item_update("input_text")
- async def get_user(self, id: str) -> SlackUser:
- if id in self.users:
- return self.users[id]
+ async def create_user(self, id: str) -> SlackUser:
user = SlackUser(self, id)
await user.init()
- self.users[id] = user
return user
+ async def get_user(self, id: str) -> SlackUser:
+ if id in self.users:
+ return await self.users[id]
+ self.users[id] = create_task(self.create_user(id))
+ return await self.users[id]
+
class SlackUser:
def __init__(self, workspace: SlackWorkspace, id: str):
@@ -223,10 +226,13 @@ class SlackMessage:
async def unfurl_refs(self, message: str):
re_user = re.compile("<@([^>]+)>")
- user_ids = re_user.findall(message)
- await gather(*(self.workspace.get_user(user_id) for user_id in user_ids))
+ user_ids: List[str] = re_user.findall(message)
+ users_list = await gather(
+ *(self.workspace.get_user(user_id) for user_id in user_ids)
+ )
+ users = dict(zip(user_ids, users_list))
def unfurl_user(user_id: str):
- return "@" + self.workspace.users[user_id].name
+ return "@" + users[user_id].name
return re_user.sub(lambda match: unfurl_user(match.group(1)), message)
diff --git a/slack/http.py b/slack/http.py
index 794164d..ca99526 100644
--- a/slack/http.py
+++ b/slack/http.py
@@ -42,10 +42,11 @@ async def hook_process_hashtable(command: str, options: Dict[str, str], timeout:
return_code = -1
while return_code == -1:
- _, return_code, out, err = await future
+ next_future = FutureProcess(future.id)
+ _, return_code, out, err = await next_future
log(
LogLevel.TRACE,
- f"hook_process_hashtable intermediary response ({future.id}): command: {command}",
+ f"hook_process_hashtable intermediary response ({next_future.id}): command: {command}",
)
stdout.write(out)
stderr.write(err)
diff --git a/slack/shared.py b/slack/shared.py
index 68400a2..f972c2b 100644
--- a/slack/shared.py
+++ b/slack/shared.py
@@ -1,11 +1,12 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Dict, Tuple
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple
if TYPE_CHECKING:
from slack.api import SlackWorkspace
from slack.config import SlackConfig
- from slack.task import Task
+ from slack.task import Future, Task
class Shared:
@@ -15,7 +16,8 @@ class Shared:
self.weechat_version: int
self.weechat_callbacks: Dict[str, Any]
- self.active_tasks: Dict[str, Task[Any]] = {}
+ self.active_tasks: Dict[str, List[Task[Any]]] = defaultdict(list)
+ self.active_futures: Dict[str, Future[Any]] = {}
self.active_responses: Dict[str, Tuple[Any, ...]] = {}
self.workspaces: Dict[str, SlackWorkspace] = {}
self.config: SlackConfig
diff --git a/slack/task.py b/slack/task.py
index fefa21f..0612985 100644
--- a/slack/task.py
+++ b/slack/task.py
@@ -1,6 +1,16 @@
from __future__ import annotations
-from typing import Any, Awaitable, Coroutine, Generator, List, Tuple, TypeVar
+from typing import (
+ Any,
+ Awaitable,
+ Coroutine,
+ Generator,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+)
from uuid import uuid4
import weechat
@@ -12,11 +22,16 @@ T = TypeVar("T")
class Future(Awaitable[T]):
- def __init__(self):
- self.id = str(uuid4())
+ def __init__(self, future_id: Optional[str] = None):
+ if future_id is None:
+ self.id = str(uuid4())
+ else:
+ self.id = future_id
+ self.result: Optional[T] = None
def __await__(self) -> Generator[Future[T], T, T]:
- return (yield self)
+ self.result = yield self
+ return self.result
class FutureProcess(Future[Tuple[str, int, str, str]]):
@@ -35,8 +50,11 @@ class Task(Future[T]):
def weechat_task_cb(data: str, *args: Any) -> int:
- task = shared.active_tasks.pop(data)
- task_runner(task, args)
+ future = shared.active_futures.pop(data)
+ future.result = args
+ tasks = shared.active_tasks.pop(data)
+ for task in tasks:
+ task_runner(task, args)
return weechat.WEECHAT_RC_OK
@@ -46,26 +64,28 @@ def task_runner(task: Task[Any], response: Any):
future = task.coroutine.send(response)
if future.id in shared.active_responses:
response = shared.active_responses.pop(future.id)
+ elif future.result is not None:
+ response = future.result
else:
- if future.id in shared.active_tasks:
- raise Exception(
- f"future.id in active_tasks, {future.id}, {shared.active_tasks}"
- )
- shared.active_tasks[future.id] = task
+ shared.active_tasks[future.id].append(task)
+ shared.active_futures[future.id] = future
break
except StopIteration as e:
+ task.result = e.value
if task.id in shared.active_tasks:
- task = shared.active_tasks.pop(task.id)
- response = e.value
- else:
- if task.id in shared.active_responses:
- raise Exception( # pylint: disable=raise-missing-from
- f"task.id in active_responses, {task.id}, {shared.active_responses}"
- )
- if not task.final:
- shared.active_responses[task.id] = e.value
+ tasks = shared.active_tasks.pop(task.id)
+ for active_task in tasks:
+ task_runner(active_task, e.value)
break
+ if task.id in shared.active_responses:
+ raise Exception( # pylint: disable=raise-missing-from
+ f"task.id in active_responses, {task.id}, {shared.active_responses}"
+ )
+ if not task.final:
+ shared.active_responses[task.id] = e.value
+ break
+
def create_task(
coroutine: Coroutine[Future[Any], Any, T], final: bool = False
@@ -75,9 +95,12 @@ def create_task(
return task
-async def gather(*requests: Coroutine[Any, Any, T]) -> List[T]:
+async def gather(*requests: Union[Future[T], Coroutine[Any, Any, T]]) -> List[T]:
# TODO: Should probably propagate first exception
- tasks = [create_task(request) for request in requests]
+ tasks = [
+ create_task(request) if isinstance(request, Coroutine) else request
+ for request in requests
+ ]
return [await task for task in tasks]
diff --git a/tests/test_task_runner.py b/tests/test_task_runner.py
index 17a3ec8..ca4f78b 100644
--- a/tests/test_task_runner.py
+++ b/tests/test_task_runner.py
@@ -1,9 +1,11 @@
+from collections import defaultdict
+
from slack.shared import shared
from slack.task import Future, create_task, weechat_task_cb
def test_run_single_task():
- shared.active_tasks = {}
+ shared.active_tasks = defaultdict(list)
shared.active_responses = {}
future = Future[str]()
@@ -19,7 +21,7 @@ def test_run_single_task():
def test_run_nested_task():
- shared.active_tasks = {}
+ shared.active_tasks = defaultdict(list)
shared.active_responses = {}
future = Future[str]()
@@ -41,7 +43,7 @@ def test_run_nested_task():
def test_run_two_tasks_concurrently():
- shared.active_tasks = {}
+ shared.active_tasks = defaultdict(list)
shared.active_responses = {}
future1 = Future[str]()
future2 = Future[str]()