From adf20323ce99e8829a2ab53f8d487f5704d28320 Mon Sep 17 00:00:00 2001 From: Trygve Aaberge Date: Thu, 12 Jan 2023 22:47:51 +0100 Subject: Don't fetch the same user multiple times --- slack/api.py | 26 +++++++++++------- slack/http.py | 5 ++-- slack/shared.py | 8 +++--- slack/task.py | 67 +++++++++++++++++++++++++++++++---------------- tests/test_task_runner.py | 8 +++--- 5 files changed, 74 insertions(+), 40 deletions(-) diff --git a/slack/api.py b/slack/api.py index 66cd687..4fdef4d 100644 --- a/slack/api.py +++ b/slack/api.py @@ -4,14 +4,14 @@ import json import re import time from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from urllib.parse import urlencode import weechat from slack.http import http_request from slack.shared import shared -from slack.task import create_task, gather +from slack.task import Future, create_task, gather from slack.util import get_callback_name if TYPE_CHECKING: @@ -76,7 +76,7 @@ class SlackWorkspace: self.is_connected = False self.nick = "TODO" # Maybe make private, so you have to use get_user? Maybe make get_user a getter, though don't know if that's a problem since it's async - self.users: Dict[str, SlackUser] = {} + self.users: Dict[str, Future[SlackUser]] = {} self.conversations: Dict[str, SlackConversation] = {} async def connect(self): @@ -104,14 +104,17 @@ class SlackWorkspace: self.is_connected = True weechat.bar_item_update("input_text") - async def get_user(self, id: str) -> SlackUser: - if id in self.users: - return self.users[id] + async def create_user(self, id: str) -> SlackUser: user = SlackUser(self, id) await user.init() - self.users[id] = user return user + async def get_user(self, id: str) -> SlackUser: + if id in self.users: + return await self.users[id] + self.users[id] = create_task(self.create_user(id)) + return await self.users[id] + class SlackUser: def __init__(self, workspace: SlackWorkspace, id: str): @@ -223,10 +226,13 @@ class SlackMessage: async def unfurl_refs(self, message: str): re_user = re.compile("<@([^>]+)>") - user_ids = re_user.findall(message) - await gather(*(self.workspace.get_user(user_id) for user_id in user_ids)) + user_ids: List[str] = re_user.findall(message) + users_list = await gather( + *(self.workspace.get_user(user_id) for user_id in user_ids) + ) + users = dict(zip(user_ids, users_list)) def unfurl_user(user_id: str): - return "@" + self.workspace.users[user_id].name + return "@" + users[user_id].name return re_user.sub(lambda match: unfurl_user(match.group(1)), message) diff --git a/slack/http.py b/slack/http.py index 794164d..ca99526 100644 --- a/slack/http.py +++ b/slack/http.py @@ -42,10 +42,11 @@ async def hook_process_hashtable(command: str, options: Dict[str, str], timeout: return_code = -1 while return_code == -1: - _, return_code, out, err = await future + next_future = FutureProcess(future.id) + _, return_code, out, err = await next_future log( LogLevel.TRACE, - f"hook_process_hashtable intermediary response ({future.id}): command: {command}", + f"hook_process_hashtable intermediary response ({next_future.id}): command: {command}", ) stdout.write(out) stderr.write(err) diff --git a/slack/shared.py b/slack/shared.py index 68400a2..f972c2b 100644 --- a/slack/shared.py +++ b/slack/shared.py @@ -1,11 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Tuple +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Dict, List, Tuple if TYPE_CHECKING: from slack.api import SlackWorkspace from slack.config import SlackConfig - from slack.task import Task + from slack.task import Future, Task class Shared: @@ -15,7 +16,8 @@ class Shared: self.weechat_version: int self.weechat_callbacks: Dict[str, Any] - self.active_tasks: Dict[str, Task[Any]] = {} + self.active_tasks: Dict[str, List[Task[Any]]] = defaultdict(list) + self.active_futures: Dict[str, Future[Any]] = {} self.active_responses: Dict[str, Tuple[Any, ...]] = {} self.workspaces: Dict[str, SlackWorkspace] = {} self.config: SlackConfig diff --git a/slack/task.py b/slack/task.py index fefa21f..0612985 100644 --- a/slack/task.py +++ b/slack/task.py @@ -1,6 +1,16 @@ from __future__ import annotations -from typing import Any, Awaitable, Coroutine, Generator, List, Tuple, TypeVar +from typing import ( + Any, + Awaitable, + Coroutine, + Generator, + List, + Optional, + Tuple, + TypeVar, + Union, +) from uuid import uuid4 import weechat @@ -12,11 +22,16 @@ T = TypeVar("T") class Future(Awaitable[T]): - def __init__(self): - self.id = str(uuid4()) + def __init__(self, future_id: Optional[str] = None): + if future_id is None: + self.id = str(uuid4()) + else: + self.id = future_id + self.result: Optional[T] = None def __await__(self) -> Generator[Future[T], T, T]: - return (yield self) + self.result = yield self + return self.result class FutureProcess(Future[Tuple[str, int, str, str]]): @@ -35,8 +50,11 @@ class Task(Future[T]): def weechat_task_cb(data: str, *args: Any) -> int: - task = shared.active_tasks.pop(data) - task_runner(task, args) + future = shared.active_futures.pop(data) + future.result = args + tasks = shared.active_tasks.pop(data) + for task in tasks: + task_runner(task, args) return weechat.WEECHAT_RC_OK @@ -46,26 +64,28 @@ def task_runner(task: Task[Any], response: Any): future = task.coroutine.send(response) if future.id in shared.active_responses: response = shared.active_responses.pop(future.id) + elif future.result is not None: + response = future.result else: - if future.id in shared.active_tasks: - raise Exception( - f"future.id in active_tasks, {future.id}, {shared.active_tasks}" - ) - shared.active_tasks[future.id] = task + shared.active_tasks[future.id].append(task) + shared.active_futures[future.id] = future break except StopIteration as e: + task.result = e.value if task.id in shared.active_tasks: - task = shared.active_tasks.pop(task.id) - response = e.value - else: - if task.id in shared.active_responses: - raise Exception( # pylint: disable=raise-missing-from - f"task.id in active_responses, {task.id}, {shared.active_responses}" - ) - if not task.final: - shared.active_responses[task.id] = e.value + tasks = shared.active_tasks.pop(task.id) + for active_task in tasks: + task_runner(active_task, e.value) break + if task.id in shared.active_responses: + raise Exception( # pylint: disable=raise-missing-from + f"task.id in active_responses, {task.id}, {shared.active_responses}" + ) + if not task.final: + shared.active_responses[task.id] = e.value + break + def create_task( coroutine: Coroutine[Future[Any], Any, T], final: bool = False @@ -75,9 +95,12 @@ def create_task( return task -async def gather(*requests: Coroutine[Any, Any, T]) -> List[T]: +async def gather(*requests: Union[Future[T], Coroutine[Any, Any, T]]) -> List[T]: # TODO: Should probably propagate first exception - tasks = [create_task(request) for request in requests] + tasks = [ + create_task(request) if isinstance(request, Coroutine) else request + for request in requests + ] return [await task for task in tasks] diff --git a/tests/test_task_runner.py b/tests/test_task_runner.py index 17a3ec8..ca4f78b 100644 --- a/tests/test_task_runner.py +++ b/tests/test_task_runner.py @@ -1,9 +1,11 @@ +from collections import defaultdict + from slack.shared import shared from slack.task import Future, create_task, weechat_task_cb def test_run_single_task(): - shared.active_tasks = {} + shared.active_tasks = defaultdict(list) shared.active_responses = {} future = Future[str]() @@ -19,7 +21,7 @@ def test_run_single_task(): def test_run_nested_task(): - shared.active_tasks = {} + shared.active_tasks = defaultdict(list) shared.active_responses = {} future = Future[str]() @@ -41,7 +43,7 @@ def test_run_nested_task(): def test_run_two_tasks_concurrently(): - shared.active_tasks = {} + shared.active_tasks = defaultdict(list) shared.active_responses = {} future1 = Future[str]() future2 = Future[str]() -- cgit