diff options
Diffstat (limited to 'slack/slack_workspace.py')
-rw-r--r-- | slack/slack_workspace.py | 145 |
1 files changed, 114 insertions, 31 deletions
diff --git a/slack/slack_workspace.py b/slack/slack_workspace.py index 762009a..9348e1a 100644 --- a/slack/slack_workspace.py +++ b/slack/slack_workspace.py @@ -5,7 +5,18 @@ import socket import ssl import time from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, Generic, Iterable, Optional, Set, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Dict, + Generic, + Iterable, + List, + Mapping, + Optional, + Set, + Type, + TypeVar, +) import weechat from websocket import ( @@ -35,15 +46,16 @@ from slack.util import get_callback_name if TYPE_CHECKING: from slack_api.slack_bots_info import SlackBotInfo - from slack_api.slack_conversations_info import SlackConversationsInfo from slack_api.slack_usergroups_info import SlackUsergroupInfo from slack_api.slack_users_conversations import SlackUsersConversations from slack_api.slack_users_info import SlackUserInfo from slack_rtm.slack_rtm_message import SlackRtmMessage from typing_extensions import Literal + + from slack.slack_conversation import SlackConversationsInfoInternal else: SlackBotInfo = object - SlackConversationsInfo = object + SlackConversationsInfoInternal = object SlackUsergroupInfo = object SlackUserInfo = object @@ -52,7 +64,7 @@ SlackItemClass = TypeVar( ) SlackItemInfo = TypeVar( "SlackItemInfo", - SlackConversationsInfo, + SlackConversationsInfoInternal, SlackUserInfo, SlackBotInfo, SlackUsergroupInfo, @@ -71,19 +83,37 @@ class SlackItem( self[key] = create_task(self._create_item(key)) return self[key] - def initialize_items(self, item_ids: Iterable[str]): + def initialize_items( + self, + item_ids: Iterable[str], + items_info_prefetched: Optional[Mapping[str, SlackItemInfo]] = None, + ): item_ids_to_init = set(item_id for item_id in item_ids if item_id not in self) if item_ids_to_init: - items_info_task = create_task(self._fetch_items_info(item_ids_to_init)) + item_ids_to_fetch = ( + set( + item_id + for item_id in item_ids_to_init + if item_id not in items_info_prefetched + ) + if items_info_prefetched + else item_ids_to_init + ) + items_info_task = create_task(self._fetch_items_info(item_ids_to_fetch)) for item_id in item_ids_to_init: - self[item_id] = create_task(self._create_item(item_id, items_info_task)) + self[item_id] = create_task( + self._create_item(item_id, items_info_task, items_info_prefetched) + ) async def _create_item( self, item_id: str, items_info_task: Optional[Future[Dict[str, SlackItemInfo]]] = None, + items_info_prefetched: Optional[Mapping[str, SlackItemInfo]] = None, ) -> SlackItemClass: - if items_info_task: + if items_info_prefetched and item_id in items_info_prefetched: + return await self._create_item_from_info(items_info_prefetched[item_id]) + elif items_info_task: items_info = await items_info_task item = items_info.get(item_id) if item is None: @@ -103,13 +133,13 @@ class SlackItem( raise NotImplementedError() -class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfo]): +class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfoInternal]): def __init__(self, workspace: SlackWorkspace): super().__init__(workspace, SlackConversation) async def _fetch_items_info( self, item_ids: Iterable[str] - ) -> Dict[str, SlackConversationsInfo]: + ) -> Dict[str, SlackConversationsInfoInternal]: responses = await gather( *( self.workspace.api.fetch_conversations_info(item_id) @@ -121,7 +151,7 @@ class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfo]): } async def _create_item_from_info( - self, item_info: SlackConversationsInfo + self, item_info: SlackConversationsInfoInternal ) -> SlackConversation: return await self._item_class(self.workspace, item_info) @@ -224,17 +254,9 @@ class SlackWorkspace: weechat.prnt("", f"Connected to workspace {self.name}") self._connect_task = None - async def _connect(self) -> None: - try: - rtm_connect = await self.api.fetch_rtm_connect() - except SlackApiError as e: - print_error( - f'failed connecting to workspace "{self.name}": {e.response["error"]}' - ) - return - + async def _connect_oauth(self) -> List[SlackConversation]: + rtm_connect = await self.api.fetch_rtm_connect() self.id = rtm_connect["team"]["id"] - self.enterprise_id = rtm_connect["team"].get("enterprise_id") self.my_user = await self.users[rtm_connect["self"]["id"]] await self._connect_ws(rtm_connect["url"]) @@ -242,16 +264,6 @@ class SlackWorkspace: prefs = await self.api.fetch_users_get_prefs("muted_channels") self.muted_channels = set(prefs["prefs"]["muted_channels"].split(",")) - custom_emojis_response = await self.api.fetch_emoji_list() - self.custom_emojis = custom_emojis_response["emoji"] - - if not self.api.edgeapi.is_available: - usergroups = await self.api.fetch_usergroups_list() - for usergroup in usergroups["usergroups"]: - future = Future[SlackUsergroup]() - future.set_result(SlackUsergroup(self, usergroup)) - self.usergroups[usergroup["id"]] = future - users_conversations_response = await self.api.fetch_users_conversations( "public_channel,private_channel,mpim,im" ) @@ -264,6 +276,77 @@ class SlackWorkspace: conversations_to_open = [ c for c in conversations_if_should_open if c is not None ] + return conversations_to_open + + async def _connect_session(self) -> List[SlackConversation]: + user_boot_task = create_task(self.api.fetch_client_userboot()) + client_counts_task = create_task(self.api.fetch_client_counts()) + user_boot = await user_boot_task + client_counts = await client_counts_task + + self.id = user_boot["team"]["id"] + my_user_id = user_boot["self"]["id"] + # self.users.initialize_items(my_user_id, {my_user_id: user_boot["self"]}) + self.my_user = await self.users[my_user_id] + self.muted_channels = set(user_boot["prefs"]["muted_channels"].split(",")) + + await self._connect_ws( + f"wss://wss-primary.slack.com/?token={self.config.api_token.value}&batch_presence_aware=1" + ) + + conversation_counts = ( + client_counts["channels"] + client_counts["mpims"] + client_counts["ims"] + ) + + conversation_ids = set( + [ + channel["id"] + for channel in user_boot["channels"] + if not channel["is_mpim"] + ] + + user_boot["is_open"] + + [count["id"] for count in conversation_counts if count["has_unreads"]] + ) + + channel_infos: Dict[str, SlackConversationsInfoInternal] = { + channel["id"]: channel for channel in user_boot["channels"] + } + self.conversations.initialize_items(conversation_ids, channel_infos) + conversations = { + conversation_id: await self.conversations[conversation_id] + for conversation_id in conversation_ids + } + + for conversation_count in conversation_counts: + if conversation_count["id"] in conversations: + conversation = conversations[conversation_count["id"]] + # TODO: Update without moving unread marker to the bottom + if conversation.last_read == SlackTs("0.0"): + conversation.last_read = SlackTs(conversation_count["last_read"]) + + return list(conversations.values()) + + async def _connect(self) -> None: + try: + if self.token_type == "session": + conversations_to_open = await self._connect_session() + else: + conversations_to_open = await self._connect_oauth() + except SlackApiError as e: + print_error( + f'failed connecting to workspace "{self.name}": {e.response["error"]}' + ) + return + + custom_emojis_response = await self.api.fetch_emoji_list() + self.custom_emojis = custom_emojis_response["emoji"] + + if not self.api.edgeapi.is_available: + usergroups = await self.api.fetch_usergroups_list() + for usergroup in usergroups["usergroups"]: + future = Future[SlackUsergroup]() + future.set_result(SlackUsergroup(self, usergroup)) + self.usergroups[usergroup["id"]] = future for conversation in sorted( conversations_to_open, key=lambda conversation: conversation.sort_key() |