diff options
Diffstat (limited to 'slack')
-rw-r--r-- | slack/slack_api.py | 21 | ||||
-rw-r--r-- | slack/slack_conversation.py | 58 | ||||
-rw-r--r-- | slack/slack_workspace.py | 145 |
3 files changed, 173 insertions, 51 deletions
diff --git a/slack/slack_api.py b/slack/slack_api.py index fb6952d..7ce919f 100644 --- a/slack/slack_api.py +++ b/slack/slack_api.py @@ -14,6 +14,8 @@ from slack.util import chunked if TYPE_CHECKING: from slack_api.slack_bots_info import SlackBotInfoResponse, SlackBotsInfoResponse + from slack_api.slack_client_counts import SlackClientCountsResponse + from slack_api.slack_client_userboot import SlackClientUserbootResponse from slack_api.slack_common import SlackGenericResponse from slack_api.slack_conversations_history import SlackConversationsHistoryResponse from slack_api.slack_conversations_info import SlackConversationsInfoResponse @@ -56,10 +58,7 @@ class SlackEdgeApi(SlackApiCommon): return self.workspace.token_type == "session" async def _fetch_edgeapi(self, method: str, params: EdgeParams = {}): - enterprise_id_part = ( - f"{self.workspace.enterprise_id}/" if self.workspace.enterprise_id else "" - ) - url = f"https://edgeapi.slack.com/cache/{enterprise_id_part}{self.workspace.id}/{method}" + url = f"https://edgeapi.slack.com/cache/{self.workspace.id}/{method}" options = self._get_request_options() options["postfields"] = json.dumps(params) options["httpheader"] += "\nContent-Type: application/json" @@ -284,6 +283,20 @@ class SlackApi(SlackApiCommon): raise SlackApiError(self.workspace, method, response) return response + async def fetch_client_userboot(self): + method = "client.userBoot" + response: SlackClientUserbootResponse = await self._fetch(method) + if response["ok"] is False: + raise SlackApiError(self.workspace, method, response) + return response + + async def fetch_client_counts(self): + method = "client.counts" + response: SlackClientCountsResponse = await self._fetch(method) + if response["ok"] is False: + raise SlackApiError(self.workspace, method, response) + return response + async def conversations_close(self, conversation: SlackConversation): method = "conversations.close" params: Params = {"channel": conversation.id} diff --git a/slack/slack_conversation.py b/slack/slack_conversation.py index 642ed82..e39428f 100644 --- a/slack/slack_conversation.py +++ b/slack/slack_conversation.py @@ -35,7 +35,9 @@ from slack.task import Task, gather, run_async from slack.util import unhtmlescape, with_color if TYPE_CHECKING: - from slack_api.slack_conversations_info import SlackConversationsInfo + from slack_api.slack_client_userboot import SlackClientUserbootIm + from slack_api.slack_conversations_info import SlackConversationsInfo, SlackTopic + from slack_api.slack_users_conversations import SlackUsersConversationsNotIm from slack_rtm.slack_rtm_message import ( SlackMessageChanged, SlackMessageDeleted, @@ -48,6 +50,10 @@ if TYPE_CHECKING: from slack.slack_workspace import SlackWorkspace + SlackConversationsInfoInternal = Union[ + SlackConversationsInfo, SlackUsersConversationsNotIm, SlackClientUserbootIm + ] + def update_buffer_props(): for workspace in shared.workspaces.values(): @@ -133,7 +139,7 @@ class SlackConversation(SlackBuffer): async def __new__( cls, workspace: SlackWorkspace, - info: SlackConversationsInfo, + info: SlackConversationsInfoInternal, ): conversation = super().__new__(cls) conversation.__init__(workspace, info) @@ -142,7 +148,7 @@ class SlackConversation(SlackBuffer): def __init__( self, workspace: SlackWorkspace, - info: SlackConversationsInfo, + info: SlackConversationsInfoInternal, ): super().__init__() self._workspace = workspace @@ -155,11 +161,27 @@ class SlackConversation(SlackBuffer): self.nicklist_needs_refresh = True self.message_hashes = SlackConversationMessageHashes(self) + self._last_read = ( + SlackTs(self._info["last_read"]) + if "last_read" in self._info + else SlackTs("0.0") + ) + + self._topic: SlackTopic = ( + self._info["topic"] + if "topic" in self._info + else {"value": "", "creator": "", "last_set": 0} + ) + async def __init_async(self): if self._info["is_im"] is True: self._im_user = await self._workspace.users[self._info["user"]] elif self.type == "mpim": - members = await self.load_members(load_all=True) + if "members" in self._info: + members = self._info["members"] + else: + members = await self.load_members(load_all=True) + self._mpim_users = await gather( *( self._workspace.users[user_id] @@ -225,11 +247,11 @@ class SlackConversation(SlackBuffer): @property def last_read(self) -> SlackTs: - return SlackTs(self._info["last_read"]) + return self._last_read @last_read.setter def last_read(self, value: SlackTs): - self._info["last_read"] = value + self._last_read = value self.set_unread_and_hotlist() @property @@ -295,16 +317,14 @@ class SlackConversation(SlackBuffer): def buffer_title(self) -> str: # TODO: unfurl and apply styles - topic = unhtmlescape(self._info.get("topic", {}).get("value", "")) + topic = unhtmlescape(self._topic["value"]) if self._im_user: status = f"{self._im_user.status_emoji} {self._im_user.status_text}".strip() return " | ".join(part for part in [status, topic] if part) return topic def set_topic(self, title: str): - if "topic" not in self._info: - self._info["topic"] = {"value": "", "creator": "", "last_set": 0} - self._info["topic"]["value"] = title + self._topic["value"] = title self.update_buffer_props() def get_name_and_buffer_props(self) -> Tuple[str, Dict[str, str]]: @@ -524,12 +544,18 @@ class SlackConversation(SlackBuffer): async def nicklist_update(self): if self.nicklist_needs_refresh and self.type != "im": self.nicklist_needs_refresh = False - members = await self.load_members() - users = await gather( - *(self.workspace.users[user_id] for user_id in members) - ) - for user in users: - self.nicklist_add_user(user) + try: + members = await self.load_members() + except SlackApiError as e: + if e.response["error"] == "enterprise_is_restricted": + return + raise e + else: + users = await gather( + *(self.workspace.users[user_id] for user_id in members) + ) + for user in users: + self.nicklist_add_user(user) def nicklist_add_user( self, user: Optional[Union[SlackUser, SlackBot]], nick: Optional[str] = None diff --git a/slack/slack_workspace.py b/slack/slack_workspace.py index 762009a..9348e1a 100644 --- a/slack/slack_workspace.py +++ b/slack/slack_workspace.py @@ -5,7 +5,18 @@ import socket import ssl import time from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, Generic, Iterable, Optional, Set, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Dict, + Generic, + Iterable, + List, + Mapping, + Optional, + Set, + Type, + TypeVar, +) import weechat from websocket import ( @@ -35,15 +46,16 @@ from slack.util import get_callback_name if TYPE_CHECKING: from slack_api.slack_bots_info import SlackBotInfo - from slack_api.slack_conversations_info import SlackConversationsInfo from slack_api.slack_usergroups_info import SlackUsergroupInfo from slack_api.slack_users_conversations import SlackUsersConversations from slack_api.slack_users_info import SlackUserInfo from slack_rtm.slack_rtm_message import SlackRtmMessage from typing_extensions import Literal + + from slack.slack_conversation import SlackConversationsInfoInternal else: SlackBotInfo = object - SlackConversationsInfo = object + SlackConversationsInfoInternal = object SlackUsergroupInfo = object SlackUserInfo = object @@ -52,7 +64,7 @@ SlackItemClass = TypeVar( ) SlackItemInfo = TypeVar( "SlackItemInfo", - SlackConversationsInfo, + SlackConversationsInfoInternal, SlackUserInfo, SlackBotInfo, SlackUsergroupInfo, @@ -71,19 +83,37 @@ class SlackItem( self[key] = create_task(self._create_item(key)) return self[key] - def initialize_items(self, item_ids: Iterable[str]): + def initialize_items( + self, + item_ids: Iterable[str], + items_info_prefetched: Optional[Mapping[str, SlackItemInfo]] = None, + ): item_ids_to_init = set(item_id for item_id in item_ids if item_id not in self) if item_ids_to_init: - items_info_task = create_task(self._fetch_items_info(item_ids_to_init)) + item_ids_to_fetch = ( + set( + item_id + for item_id in item_ids_to_init + if item_id not in items_info_prefetched + ) + if items_info_prefetched + else item_ids_to_init + ) + items_info_task = create_task(self._fetch_items_info(item_ids_to_fetch)) for item_id in item_ids_to_init: - self[item_id] = create_task(self._create_item(item_id, items_info_task)) + self[item_id] = create_task( + self._create_item(item_id, items_info_task, items_info_prefetched) + ) async def _create_item( self, item_id: str, items_info_task: Optional[Future[Dict[str, SlackItemInfo]]] = None, + items_info_prefetched: Optional[Mapping[str, SlackItemInfo]] = None, ) -> SlackItemClass: - if items_info_task: + if items_info_prefetched and item_id in items_info_prefetched: + return await self._create_item_from_info(items_info_prefetched[item_id]) + elif items_info_task: items_info = await items_info_task item = items_info.get(item_id) if item is None: @@ -103,13 +133,13 @@ class SlackItem( raise NotImplementedError() -class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfo]): +class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfoInternal]): def __init__(self, workspace: SlackWorkspace): super().__init__(workspace, SlackConversation) async def _fetch_items_info( self, item_ids: Iterable[str] - ) -> Dict[str, SlackConversationsInfo]: + ) -> Dict[str, SlackConversationsInfoInternal]: responses = await gather( *( self.workspace.api.fetch_conversations_info(item_id) @@ -121,7 +151,7 @@ class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfo]): } async def _create_item_from_info( - self, item_info: SlackConversationsInfo + self, item_info: SlackConversationsInfoInternal ) -> SlackConversation: return await self._item_class(self.workspace, item_info) @@ -224,17 +254,9 @@ class SlackWorkspace: weechat.prnt("", f"Connected to workspace {self.name}") self._connect_task = None - async def _connect(self) -> None: - try: - rtm_connect = await self.api.fetch_rtm_connect() - except SlackApiError as e: - print_error( - f'failed connecting to workspace "{self.name}": {e.response["error"]}' - ) - return - + async def _connect_oauth(self) -> List[SlackConversation]: + rtm_connect = await self.api.fetch_rtm_connect() self.id = rtm_connect["team"]["id"] - self.enterprise_id = rtm_connect["team"].get("enterprise_id") self.my_user = await self.users[rtm_connect["self"]["id"]] await self._connect_ws(rtm_connect["url"]) @@ -242,16 +264,6 @@ class SlackWorkspace: prefs = await self.api.fetch_users_get_prefs("muted_channels") self.muted_channels = set(prefs["prefs"]["muted_channels"].split(",")) - custom_emojis_response = await self.api.fetch_emoji_list() - self.custom_emojis = custom_emojis_response["emoji"] - - if not self.api.edgeapi.is_available: - usergroups = await self.api.fetch_usergroups_list() - for usergroup in usergroups["usergroups"]: - future = Future[SlackUsergroup]() - future.set_result(SlackUsergroup(self, usergroup)) - self.usergroups[usergroup["id"]] = future - users_conversations_response = await self.api.fetch_users_conversations( "public_channel,private_channel,mpim,im" ) @@ -264,6 +276,77 @@ class SlackWorkspace: conversations_to_open = [ c for c in conversations_if_should_open if c is not None ] + return conversations_to_open + + async def _connect_session(self) -> List[SlackConversation]: + user_boot_task = create_task(self.api.fetch_client_userboot()) + client_counts_task = create_task(self.api.fetch_client_counts()) + user_boot = await user_boot_task + client_counts = await client_counts_task + + self.id = user_boot["team"]["id"] + my_user_id = user_boot["self"]["id"] + # self.users.initialize_items(my_user_id, {my_user_id: user_boot["self"]}) + self.my_user = await self.users[my_user_id] + self.muted_channels = set(user_boot["prefs"]["muted_channels"].split(",")) + + await self._connect_ws( + f"wss://wss-primary.slack.com/?token={self.config.api_token.value}&batch_presence_aware=1" + ) + + conversation_counts = ( + client_counts["channels"] + client_counts["mpims"] + client_counts["ims"] + ) + + conversation_ids = set( + [ + channel["id"] + for channel in user_boot["channels"] + if not channel["is_mpim"] + ] + + user_boot["is_open"] + + [count["id"] for count in conversation_counts if count["has_unreads"]] + ) + + channel_infos: Dict[str, SlackConversationsInfoInternal] = { + channel["id"]: channel for channel in user_boot["channels"] + } + self.conversations.initialize_items(conversation_ids, channel_infos) + conversations = { + conversation_id: await self.conversations[conversation_id] + for conversation_id in conversation_ids + } + + for conversation_count in conversation_counts: + if conversation_count["id"] in conversations: + conversation = conversations[conversation_count["id"]] + # TODO: Update without moving unread marker to the bottom + if conversation.last_read == SlackTs("0.0"): + conversation.last_read = SlackTs(conversation_count["last_read"]) + + return list(conversations.values()) + + async def _connect(self) -> None: + try: + if self.token_type == "session": + conversations_to_open = await self._connect_session() + else: + conversations_to_open = await self._connect_oauth() + except SlackApiError as e: + print_error( + f'failed connecting to workspace "{self.name}": {e.response["error"]}' + ) + return + + custom_emojis_response = await self.api.fetch_emoji_list() + self.custom_emojis = custom_emojis_response["emoji"] + + if not self.api.edgeapi.is_available: + usergroups = await self.api.fetch_usergroups_list() + for usergroup in usergroups["usergroups"]: + future = Future[SlackUsergroup]() + future.set_result(SlackUsergroup(self, usergroup)) + self.usergroups[usergroup["id"]] = future for conversation in sorted( conversations_to_open, key=lambda conversation: conversation.sort_key() |