From 3b98fbeff044299104f7d00a3c9b56d9a53b0df0 Mon Sep 17 00:00:00 2001 From: Trygve Aaberge Date: Mon, 11 Dec 2023 21:15:51 +0100 Subject: Support API restricted workspaces Some enterprise workspaces are restricted in which API methods they can use, so we have to use some of the APIs the official web client uses (which can't be used by OAuth tokens) instead (mainly to initialize the workspace with client.userBoot and client.counts, and to connect to the websocket). This also has the benefit of being more performant, as the API methods the web client uses are more suited for creating a client than the official API methods. I think which API methods are restricted may be configured by the workspace admins, so it may not be the same for different workspaces, but for me it seems to be at least rtm.connect, users.conversations, conversations.list and conversations.members, so these are the ones I've changed to be conditional on the token type. --- slack/slack_api.py | 21 +++++-- slack/slack_conversation.py | 58 +++++++++++++----- slack/slack_workspace.py | 145 ++++++++++++++++++++++++++++++++++---------- 3 files changed, 173 insertions(+), 51 deletions(-) (limited to 'slack') diff --git a/slack/slack_api.py b/slack/slack_api.py index fb6952d..7ce919f 100644 --- a/slack/slack_api.py +++ b/slack/slack_api.py @@ -14,6 +14,8 @@ from slack.util import chunked if TYPE_CHECKING: from slack_api.slack_bots_info import SlackBotInfoResponse, SlackBotsInfoResponse + from slack_api.slack_client_counts import SlackClientCountsResponse + from slack_api.slack_client_userboot import SlackClientUserbootResponse from slack_api.slack_common import SlackGenericResponse from slack_api.slack_conversations_history import SlackConversationsHistoryResponse from slack_api.slack_conversations_info import SlackConversationsInfoResponse @@ -56,10 +58,7 @@ class SlackEdgeApi(SlackApiCommon): return self.workspace.token_type == "session" async def _fetch_edgeapi(self, method: str, params: EdgeParams = {}): - enterprise_id_part = ( - f"{self.workspace.enterprise_id}/" if self.workspace.enterprise_id else "" - ) - url = f"https://edgeapi.slack.com/cache/{enterprise_id_part}{self.workspace.id}/{method}" + url = f"https://edgeapi.slack.com/cache/{self.workspace.id}/{method}" options = self._get_request_options() options["postfields"] = json.dumps(params) options["httpheader"] += "\nContent-Type: application/json" @@ -284,6 +283,20 @@ class SlackApi(SlackApiCommon): raise SlackApiError(self.workspace, method, response) return response + async def fetch_client_userboot(self): + method = "client.userBoot" + response: SlackClientUserbootResponse = await self._fetch(method) + if response["ok"] is False: + raise SlackApiError(self.workspace, method, response) + return response + + async def fetch_client_counts(self): + method = "client.counts" + response: SlackClientCountsResponse = await self._fetch(method) + if response["ok"] is False: + raise SlackApiError(self.workspace, method, response) + return response + async def conversations_close(self, conversation: SlackConversation): method = "conversations.close" params: Params = {"channel": conversation.id} diff --git a/slack/slack_conversation.py b/slack/slack_conversation.py index 642ed82..e39428f 100644 --- a/slack/slack_conversation.py +++ b/slack/slack_conversation.py @@ -35,7 +35,9 @@ from slack.task import Task, gather, run_async from slack.util import unhtmlescape, with_color if TYPE_CHECKING: - from slack_api.slack_conversations_info import SlackConversationsInfo + from slack_api.slack_client_userboot import SlackClientUserbootIm + from slack_api.slack_conversations_info import SlackConversationsInfo, SlackTopic + from slack_api.slack_users_conversations import SlackUsersConversationsNotIm from slack_rtm.slack_rtm_message import ( SlackMessageChanged, SlackMessageDeleted, @@ -48,6 +50,10 @@ if TYPE_CHECKING: from slack.slack_workspace import SlackWorkspace + SlackConversationsInfoInternal = Union[ + SlackConversationsInfo, SlackUsersConversationsNotIm, SlackClientUserbootIm + ] + def update_buffer_props(): for workspace in shared.workspaces.values(): @@ -133,7 +139,7 @@ class SlackConversation(SlackBuffer): async def __new__( cls, workspace: SlackWorkspace, - info: SlackConversationsInfo, + info: SlackConversationsInfoInternal, ): conversation = super().__new__(cls) conversation.__init__(workspace, info) @@ -142,7 +148,7 @@ class SlackConversation(SlackBuffer): def __init__( self, workspace: SlackWorkspace, - info: SlackConversationsInfo, + info: SlackConversationsInfoInternal, ): super().__init__() self._workspace = workspace @@ -155,11 +161,27 @@ class SlackConversation(SlackBuffer): self.nicklist_needs_refresh = True self.message_hashes = SlackConversationMessageHashes(self) + self._last_read = ( + SlackTs(self._info["last_read"]) + if "last_read" in self._info + else SlackTs("0.0") + ) + + self._topic: SlackTopic = ( + self._info["topic"] + if "topic" in self._info + else {"value": "", "creator": "", "last_set": 0} + ) + async def __init_async(self): if self._info["is_im"] is True: self._im_user = await self._workspace.users[self._info["user"]] elif self.type == "mpim": - members = await self.load_members(load_all=True) + if "members" in self._info: + members = self._info["members"] + else: + members = await self.load_members(load_all=True) + self._mpim_users = await gather( *( self._workspace.users[user_id] @@ -225,11 +247,11 @@ class SlackConversation(SlackBuffer): @property def last_read(self) -> SlackTs: - return SlackTs(self._info["last_read"]) + return self._last_read @last_read.setter def last_read(self, value: SlackTs): - self._info["last_read"] = value + self._last_read = value self.set_unread_and_hotlist() @property @@ -295,16 +317,14 @@ class SlackConversation(SlackBuffer): def buffer_title(self) -> str: # TODO: unfurl and apply styles - topic = unhtmlescape(self._info.get("topic", {}).get("value", "")) + topic = unhtmlescape(self._topic["value"]) if self._im_user: status = f"{self._im_user.status_emoji} {self._im_user.status_text}".strip() return " | ".join(part for part in [status, topic] if part) return topic def set_topic(self, title: str): - if "topic" not in self._info: - self._info["topic"] = {"value": "", "creator": "", "last_set": 0} - self._info["topic"]["value"] = title + self._topic["value"] = title self.update_buffer_props() def get_name_and_buffer_props(self) -> Tuple[str, Dict[str, str]]: @@ -524,12 +544,18 @@ class SlackConversation(SlackBuffer): async def nicklist_update(self): if self.nicklist_needs_refresh and self.type != "im": self.nicklist_needs_refresh = False - members = await self.load_members() - users = await gather( - *(self.workspace.users[user_id] for user_id in members) - ) - for user in users: - self.nicklist_add_user(user) + try: + members = await self.load_members() + except SlackApiError as e: + if e.response["error"] == "enterprise_is_restricted": + return + raise e + else: + users = await gather( + *(self.workspace.users[user_id] for user_id in members) + ) + for user in users: + self.nicklist_add_user(user) def nicklist_add_user( self, user: Optional[Union[SlackUser, SlackBot]], nick: Optional[str] = None diff --git a/slack/slack_workspace.py b/slack/slack_workspace.py index 762009a..9348e1a 100644 --- a/slack/slack_workspace.py +++ b/slack/slack_workspace.py @@ -5,7 +5,18 @@ import socket import ssl import time from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, Generic, Iterable, Optional, Set, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Dict, + Generic, + Iterable, + List, + Mapping, + Optional, + Set, + Type, + TypeVar, +) import weechat from websocket import ( @@ -35,15 +46,16 @@ from slack.util import get_callback_name if TYPE_CHECKING: from slack_api.slack_bots_info import SlackBotInfo - from slack_api.slack_conversations_info import SlackConversationsInfo from slack_api.slack_usergroups_info import SlackUsergroupInfo from slack_api.slack_users_conversations import SlackUsersConversations from slack_api.slack_users_info import SlackUserInfo from slack_rtm.slack_rtm_message import SlackRtmMessage from typing_extensions import Literal + + from slack.slack_conversation import SlackConversationsInfoInternal else: SlackBotInfo = object - SlackConversationsInfo = object + SlackConversationsInfoInternal = object SlackUsergroupInfo = object SlackUserInfo = object @@ -52,7 +64,7 @@ SlackItemClass = TypeVar( ) SlackItemInfo = TypeVar( "SlackItemInfo", - SlackConversationsInfo, + SlackConversationsInfoInternal, SlackUserInfo, SlackBotInfo, SlackUsergroupInfo, @@ -71,19 +83,37 @@ class SlackItem( self[key] = create_task(self._create_item(key)) return self[key] - def initialize_items(self, item_ids: Iterable[str]): + def initialize_items( + self, + item_ids: Iterable[str], + items_info_prefetched: Optional[Mapping[str, SlackItemInfo]] = None, + ): item_ids_to_init = set(item_id for item_id in item_ids if item_id not in self) if item_ids_to_init: - items_info_task = create_task(self._fetch_items_info(item_ids_to_init)) + item_ids_to_fetch = ( + set( + item_id + for item_id in item_ids_to_init + if item_id not in items_info_prefetched + ) + if items_info_prefetched + else item_ids_to_init + ) + items_info_task = create_task(self._fetch_items_info(item_ids_to_fetch)) for item_id in item_ids_to_init: - self[item_id] = create_task(self._create_item(item_id, items_info_task)) + self[item_id] = create_task( + self._create_item(item_id, items_info_task, items_info_prefetched) + ) async def _create_item( self, item_id: str, items_info_task: Optional[Future[Dict[str, SlackItemInfo]]] = None, + items_info_prefetched: Optional[Mapping[str, SlackItemInfo]] = None, ) -> SlackItemClass: - if items_info_task: + if items_info_prefetched and item_id in items_info_prefetched: + return await self._create_item_from_info(items_info_prefetched[item_id]) + elif items_info_task: items_info = await items_info_task item = items_info.get(item_id) if item is None: @@ -103,13 +133,13 @@ class SlackItem( raise NotImplementedError() -class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfo]): +class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfoInternal]): def __init__(self, workspace: SlackWorkspace): super().__init__(workspace, SlackConversation) async def _fetch_items_info( self, item_ids: Iterable[str] - ) -> Dict[str, SlackConversationsInfo]: + ) -> Dict[str, SlackConversationsInfoInternal]: responses = await gather( *( self.workspace.api.fetch_conversations_info(item_id) @@ -121,7 +151,7 @@ class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfo]): } async def _create_item_from_info( - self, item_info: SlackConversationsInfo + self, item_info: SlackConversationsInfoInternal ) -> SlackConversation: return await self._item_class(self.workspace, item_info) @@ -224,17 +254,9 @@ class SlackWorkspace: weechat.prnt("", f"Connected to workspace {self.name}") self._connect_task = None - async def _connect(self) -> None: - try: - rtm_connect = await self.api.fetch_rtm_connect() - except SlackApiError as e: - print_error( - f'failed connecting to workspace "{self.name}": {e.response["error"]}' - ) - return - + async def _connect_oauth(self) -> List[SlackConversation]: + rtm_connect = await self.api.fetch_rtm_connect() self.id = rtm_connect["team"]["id"] - self.enterprise_id = rtm_connect["team"].get("enterprise_id") self.my_user = await self.users[rtm_connect["self"]["id"]] await self._connect_ws(rtm_connect["url"]) @@ -242,16 +264,6 @@ class SlackWorkspace: prefs = await self.api.fetch_users_get_prefs("muted_channels") self.muted_channels = set(prefs["prefs"]["muted_channels"].split(",")) - custom_emojis_response = await self.api.fetch_emoji_list() - self.custom_emojis = custom_emojis_response["emoji"] - - if not self.api.edgeapi.is_available: - usergroups = await self.api.fetch_usergroups_list() - for usergroup in usergroups["usergroups"]: - future = Future[SlackUsergroup]() - future.set_result(SlackUsergroup(self, usergroup)) - self.usergroups[usergroup["id"]] = future - users_conversations_response = await self.api.fetch_users_conversations( "public_channel,private_channel,mpim,im" ) @@ -264,6 +276,77 @@ class SlackWorkspace: conversations_to_open = [ c for c in conversations_if_should_open if c is not None ] + return conversations_to_open + + async def _connect_session(self) -> List[SlackConversation]: + user_boot_task = create_task(self.api.fetch_client_userboot()) + client_counts_task = create_task(self.api.fetch_client_counts()) + user_boot = await user_boot_task + client_counts = await client_counts_task + + self.id = user_boot["team"]["id"] + my_user_id = user_boot["self"]["id"] + # self.users.initialize_items(my_user_id, {my_user_id: user_boot["self"]}) + self.my_user = await self.users[my_user_id] + self.muted_channels = set(user_boot["prefs"]["muted_channels"].split(",")) + + await self._connect_ws( + f"wss://wss-primary.slack.com/?token={self.config.api_token.value}&batch_presence_aware=1" + ) + + conversation_counts = ( + client_counts["channels"] + client_counts["mpims"] + client_counts["ims"] + ) + + conversation_ids = set( + [ + channel["id"] + for channel in user_boot["channels"] + if not channel["is_mpim"] + ] + + user_boot["is_open"] + + [count["id"] for count in conversation_counts if count["has_unreads"]] + ) + + channel_infos: Dict[str, SlackConversationsInfoInternal] = { + channel["id"]: channel for channel in user_boot["channels"] + } + self.conversations.initialize_items(conversation_ids, channel_infos) + conversations = { + conversation_id: await self.conversations[conversation_id] + for conversation_id in conversation_ids + } + + for conversation_count in conversation_counts: + if conversation_count["id"] in conversations: + conversation = conversations[conversation_count["id"]] + # TODO: Update without moving unread marker to the bottom + if conversation.last_read == SlackTs("0.0"): + conversation.last_read = SlackTs(conversation_count["last_read"]) + + return list(conversations.values()) + + async def _connect(self) -> None: + try: + if self.token_type == "session": + conversations_to_open = await self._connect_session() + else: + conversations_to_open = await self._connect_oauth() + except SlackApiError as e: + print_error( + f'failed connecting to workspace "{self.name}": {e.response["error"]}' + ) + return + + custom_emojis_response = await self.api.fetch_emoji_list() + self.custom_emojis = custom_emojis_response["emoji"] + + if not self.api.edgeapi.is_available: + usergroups = await self.api.fetch_usergroups_list() + for usergroup in usergroups["usergroups"]: + future = Future[SlackUsergroup]() + future.set_result(SlackUsergroup(self, usergroup)) + self.usergroups[usergroup["id"]] = future for conversation in sorted( conversations_to_open, key=lambda conversation: conversation.sort_key() -- cgit