from __future__ import annotations

import json
import socket
import ssl
import time
from abc import ABC, abstractmethod
from typing import (
    TYPE_CHECKING,
    Dict,
    Generic,
    Iterable,
    List,
    Mapping,
    Optional,
    Set,
    Type,
    TypeVar,
)

import weechat
from websocket import (
    ABNF,
    WebSocket,
    WebSocketConnectionClosedException,
    create_connection,
)

from slack.error import (
    SlackApiError,
    SlackError,
    SlackRtmError,
    store_and_format_exception,
)
from slack.log import DebugMessageType, LogLevel, log, print_error
from slack.proxy import Proxy
from slack.shared import shared
from slack.slack_api import SlackApi
from slack.slack_buffer import SlackBuffer
from slack.slack_conversation import SlackConversation
from slack.slack_message import SlackMessage, SlackTs
from slack.slack_thread import SlackThread
from slack.slack_user import SlackBot, SlackUser, SlackUsergroup
from slack.task import Future, Task, create_task, gather, run_async
from slack.util import get_callback_name, get_cookies

if TYPE_CHECKING:
    from slack_api.slack_bots_info import SlackBotInfo
    from slack_api.slack_usergroups_info import SlackUsergroupInfo
    from slack_api.slack_users_conversations import SlackUsersConversations
    from slack_api.slack_users_info import SlackUserInfo
    from slack_rtm.slack_rtm_message import SlackRtmMessage
    from typing_extensions import Literal

    from slack.slack_conversation import SlackConversationsInfoInternal
else:
    # At runtime these are only needed as TypeVar constraints; the real
    # TypedDict definitions are imported for type checking only.
    SlackBotInfo = object
    SlackConversationsInfoInternal = object
    SlackUsergroupInfo = object
    SlackUserInfo = object

SlackItemClass = TypeVar(
    "SlackItemClass", SlackConversation, SlackUser, SlackBot, SlackUsergroup
)
SlackItemInfo = TypeVar(
    "SlackItemInfo",
    SlackConversationsInfoInternal,
    SlackUserInfo,
    SlackBotInfo,
    SlackUsergroupInfo,
)


class SlackItem(
    ABC, Generic[SlackItemClass, SlackItemInfo], Dict[str, Future[SlackItemClass]]
):
    """Lazily populated mapping from item id to a future resolving to the item.

    Looking up an id that is not present schedules ``_create_item`` for it (see
    ``__missing__``), so ``await items[item_id]`` always yields an item or
    raises. ``initialize_items`` can be used to create many items at once with
    a single batched info fetch.
    """

    def __init__(self, workspace: SlackWorkspace, item_class: Type[SlackItemClass]):
        super().__init__()
        self.workspace = workspace
        self._item_class = item_class

    def __missing__(self, key: str):
        self[key] = create_task(self._create_item(key))
        return self[key]

    def initialize_items(
        self,
        item_ids: Iterable[str],
        items_info_prefetched: Optional[Mapping[str, SlackItemInfo]] = None,
    ):
        """Schedule creation of every id in ``item_ids`` not already present.

        :param item_ids: ids to initialize; iterated exactly once, so a
            generator is fine. Ids already in the mapping are left untouched.
        :param items_info_prefetched: optional info already at hand, keyed by
            id. Items found here are built from it directly; the remaining ids
            are fetched with one shared ``_fetch_items_info`` call.
        """
        item_ids_to_init = {item_id for item_id in item_ids if item_id not in self}
        if not item_ids_to_init:
            return

        if items_info_prefetched:
            item_ids_to_fetch = {
                item_id
                for item_id in item_ids_to_init
                if item_id not in items_info_prefetched
            }
        else:
            item_ids_to_fetch = item_ids_to_init

        # Only issue the batch request when something actually needs fetching;
        # otherwise an API call with an empty id list would be made and its
        # task would never be awaited by anyone.
        items_info_task: Optional[Future[Dict[str, SlackItemInfo]]] = (
            create_task(self._fetch_items_info(item_ids_to_fetch))
            if item_ids_to_fetch
            else None
        )

        for item_id in item_ids_to_init:
            self[item_id] = create_task(
                self._create_item(item_id, items_info_task, items_info_prefetched)
            )

    async def _create_item(
        self,
        item_id: str,
        items_info_task: Optional[Future[Dict[str, SlackItemInfo]]] = None,
        items_info_prefetched: Optional[Mapping[str, SlackItemInfo]] = None,
    ) -> SlackItemClass:
        """Build one item, preferring prefetched info, then the shared batch
        fetch, and finally falling back to ``item_class.create``.

        :raises SlackError: ``"item_not_found"`` if the batch response does not
            contain ``item_id``.
        """
        if items_info_prefetched and item_id in items_info_prefetched:
            return await self._create_item_from_info(items_info_prefetched[item_id])
        elif items_info_task:
            items_info = await items_info_task
            item = items_info.get(item_id)
            if item is None:
                raise SlackError(self.workspace, "item_not_found")
            return await self._create_item_from_info(item)
        else:
            return await self._item_class.create(self.workspace, item_id)

    @abstractmethod
    async def _fetch_items_info(
        self, item_ids: Iterable[str]
    ) -> Dict[str, SlackItemInfo]:
        """Fetch info for ``item_ids``, returned as a dict keyed by item id."""
        raise NotImplementedError()

    @abstractmethod
    async def _create_item_from_info(self, item_info: SlackItemInfo) -> SlackItemClass:
        """Construct an item from its info."""
        raise NotImplementedError()


class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfoInternal]):
    """SlackItem mapping for conversations (channels, groups, mpims, ims)."""

    def __init__(self, workspace: SlackWorkspace):
        super().__init__(workspace, SlackConversation)

    async def _fetch_items_info(
        self, item_ids: Iterable[str]
    ) -> Dict[str, SlackConversationsInfoInternal]:
        # One fetch_conversations_info request per id, run concurrently.
        responses = await gather(
            *(
                self.workspace.api.fetch_conversations_info(item_id)
                for item_id in item_ids
            )
        )
        return {
            response["channel"]["id"]: response["channel"] for response in responses
        }

    async def _create_item_from_info(
        self, item_info: SlackConversationsInfoInternal
    ) -> SlackConversation:
        # Unlike the other item types, the constructed conversation is awaited;
        # presumably SlackConversation is awaitable to finish async setup
        # (TODO confirm in slack_conversation).
        return await self._item_class(self.workspace, item_info)


class SlackUsers(SlackItem[SlackUser, SlackUserInfo]):
    """SlackItem mapping for users, fetched in batches via fetch_users_info."""

    def __init__(self, workspace: SlackWorkspace):
        super().__init__(workspace, SlackUser)

    async def _fetch_items_info(
        self, item_ids: Iterable[str]
    ) -> Dict[str, SlackUserInfo]:
        response = await self.workspace.api.fetch_users_info(item_ids)
        return {info["id"]: info for info in response["users"]}

    async def _create_item_from_info(self, item_info: SlackUserInfo) -> SlackUser:
        return self._item_class(self.workspace, item_info)


class SlackBots(SlackItem[SlackBot, SlackBotInfo]):
    """SlackItem mapping for bots, fetched in batches via fetch_bots_info."""

    def __init__(self, workspace: SlackWorkspace):
        super().__init__(workspace, SlackBot)

    async def _fetch_items_info(
        self, item_ids: Iterable[str]
    ) -> Dict[str, SlackBotInfo]:
        response = await self.workspace.api.fetch_bots_info(item_ids)
        return {info["id"]: info for info in response["bots"]}

    async def _create_item_from_info(self, item_info: SlackBotInfo) -> SlackBot:
        return self._item_class(self.workspace, item_info)


class SlackUsergroups(SlackItem[SlackUsergroup, SlackUsergroupInfo]):
    """SlackItem mapping for usergroups, fetched through the edge API."""

    def __init__(self, workspace: SlackWorkspace):
        super().__init__(workspace, SlackUsergroup)

    async def _fetch_items_info(
        self, item_ids: Iterable[str]
    ) -> Dict[str, SlackUsergroupInfo]:
        response = await self.workspace.api.edgeapi.fetch_usergroups_info(
            list(item_ids)
        )
        return {info["id"]: info for info in response["results"]}

    async def _create_item_from_info(
        self, item_info: SlackUsergroupInfo
    ) -> SlackUsergroup:
        return self._item_class(self.workspace, item_info)


class SlackWorkspace:
    """State and websocket connection for a single configured Slack workspace.

    ``id``, ``enterprise_id``, ``domain`` and ``my_user`` are only assigned
    while connecting (in ``_connect_oauth`` / ``_connect_session``).
    """

    def __init__(self, name: str):
        self.name = name
        self.config = shared.config.create_workspace_config(self.name)
        self.api = SlackApi(self)
        self._is_connected = False
        self._connect_task: Optional[Task[None]] = None
        self._ws: Optional[WebSocket] = None
        self._hook_ws_fd: Optional[str] = None
        self._debug_ws_buffer_pointer: Optional[str] = None
        self.conversations = SlackConversations(self)
        # Conversations that currently have a buffer, keyed by conversation id.
        self.open_conversations: Dict[str, SlackConversation] = {}
        self.users = SlackUsers(self)
        self.bots = SlackBots(self)
        self.usergroups = SlackUsergroups(self)
        self.muted_channels: Set[str] = set()
        self.custom_emojis: Dict[str, str] = {}
        self.max_users_per_fetch_request = 512

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name})"

    @property
    def token_type(self) -> Literal["oauth", "session", "unknown"]:
        """Classify the configured API token by prefix (xoxp- / xoxc-)."""
        if self.config.api_token.value.startswith("xoxp-"):
            return "oauth"
        elif self.config.api_token.value.startswith("xoxc-"):
            return "session"
        else:
            return "unknown"

    @property
    def team_is_org_level(self) -> bool:
        """Whether the team id starts with "E"; requires ``self.id`` to be set."""
        return self.id.startswith("E")

    @property
    def is_connected(self):
        return self._is_connected

    @property
    def is_connecting(self):
        return self._connect_task is not None

    @is_connected.setter
    def is_connected(self, value: bool):
        self._is_connected = value
        weechat.bar_item_update("input_text")

    async def connect(self) -> None:
        """Connect the workspace unless it is already connected.

        ``_connect`` reports API failures itself and returns without marking
        the workspace connected, so the success message is only printed when
        ``is_connected`` actually became True.
        """
        if self.is_connected:
            return
        weechat.prnt("", f"Connecting to workspace {self.name}")
        connect_task = create_task(self._connect())
        self._connect_task = connect_task
        try:
            await connect_task
        finally:
            # Always clear the task (also when _connect raises) so
            # is_connecting does not stay True forever. Only clear it if it is
            # still ours: disconnect()/reconnect() may have replaced it.
            if self._connect_task is connect_task:
                self._connect_task = None
        if self.is_connected:
            weechat.prnt("", f"Connected to workspace {self.name}")

    async def _connect_oauth(self) -> List[SlackConversation]:
        """Connect using an OAuth token via ``fetch_rtm_connect``.

        :returns: the conversations that should get a buffer opened.
        """
        rtm_connect = await self.api.fetch_rtm_connect()
        self.id = rtm_connect["team"]["id"]
        self.enterprise_id = rtm_connect["team"].get("enterprise_id")
        self.domain = rtm_connect["team"]["domain"]
        self.my_user = await self.users[rtm_connect["self"]["id"]]

        await self._connect_ws(rtm_connect["url"])

        prefs = await self.api.fetch_users_get_prefs("muted_channels")
        self.muted_channels = set(prefs["prefs"]["muted_channels"].split(","))

        users_conversations_response = await self.api.fetch_users_conversations(
            "public_channel,private_channel,mpim,im"
        )
        channels = users_conversations_response["channels"]
        self.conversations.initialize_items(channel["id"] for channel in channels)

        conversations_if_should_open = await gather(
            *(self._conversation_if_should_open(channel) for channel in channels)
        )
        conversations_to_open = [
            c for c in conversations_if_should_open if c is not None
        ]
        return conversations_to_open

    async def _connect_session(self) -> List[SlackConversation]:
        """Connect using a session token (team info, client userboot/counts).

        Conversations to open are: userboot channels (excluding mpims, and
        excluding channels of other teams unless the team is org level), ids
        listed in ``is_open``, and any conversation with unreads in the counts.

        :returns: the conversations that should get a buffer opened.
        :raises SlackError: if a selected conversation is missing from the
            client counts response.
        """
        team_info_task = create_task(self.api.fetch_team_info())
        user_boot_task = create_task(self.api.fetch_client_userboot())
        client_counts_task = create_task(self.api.fetch_client_counts())
        team_info = await team_info_task
        user_boot = await user_boot_task
        client_counts = await client_counts_task

        self.id = team_info["team"]["id"]
        self.enterprise_id = (
            self.id
            if self.team_is_org_level
            else team_info["team"]["enterprise_id"]
            if "enterprise_id" in team_info["team"]
            else None
        )
        self.domain = team_info["team"]["domain"]
        my_user_id = user_boot["self"]["id"]
        # self.users.initialize_items(my_user_id, {my_user_id: user_boot["self"]})
        self.my_user = await self.users[my_user_id]
        self.muted_channels = set(user_boot["prefs"]["muted_channels"].split(","))

        await self._connect_ws(
            f"wss://wss-primary.slack.com/?token={self.config.api_token.value}&slack_client=desktop&batch_presence_aware=1"
        )

        conversation_counts = (
            client_counts["channels"] + client_counts["mpims"] + client_counts["ims"]
        )

        conversation_ids = set(
            [
                channel["id"]
                for channel in user_boot["channels"]
                if not channel["is_mpim"]
                and (
                    self.team_is_org_level
                    or "internal_team_ids" not in channel
                    or self.id in channel["internal_team_ids"]
                )
            ]
            + user_boot["is_open"]
            + [count["id"] for count in conversation_counts if count["has_unreads"]]
        )

        conversation_counts_ids = set(count["id"] for count in conversation_counts)
        if not conversation_ids.issubset(conversation_counts_ids):
            raise SlackError(
                self,
                "Unexpectedly missing some conversations in client.counts",
                {
                    "conversation_ids": list(conversation_ids),
                    "conversation_counts_ids": list(conversation_counts_ids),
                },
            )

        channel_infos: Dict[str, SlackConversationsInfoInternal] = {
            channel["id"]: channel for channel in user_boot["channels"]
        }
        self.conversations.initialize_items(conversation_ids, channel_infos)
        conversations = {
            conversation_id: await self.conversations[conversation_id]
            for conversation_id in conversation_ids
        }

        for conversation_count in conversation_counts:
            if conversation_count["id"] in conversations:
                conversation = conversations[conversation_count["id"]]
                # TODO: Update without moving unread marker to the bottom
                if conversation.last_read == SlackTs("0.0"):
                    conversation.last_read = SlackTs(conversation_count["last_read"])

        return list(conversations.values())

    async def _connect(self) -> None:
        """Run the token-specific connect, then load emojis/usergroups, open
        buffers and mark the workspace connected.

        A SlackApiError during the token-specific connect is printed and the
        method returns without setting ``is_connected``.
        """
        try:
            if self.token_type == "session":
                conversations_to_open = await self._connect_session()
            else:
                conversations_to_open = await self._connect_oauth()
        except SlackApiError as e:
            print_error(
                f'failed connecting to workspace "{self.name}": {e.response["error"]}'
            )
            return

        custom_emojis_response = await self.api.fetch_emoji_list()
        self.custom_emojis = custom_emojis_response["emoji"]

        # Without the edge API usergroups cannot be fetched by id, so load them
        # all up front and store already-resolved futures.
        if not self.api.edgeapi.is_available:
            usergroups = await self.api.fetch_usergroups_list()
            for usergroup in usergroups["usergroups"]:
                future = Future[SlackUsergroup]()
                future.set_result(SlackUsergroup(self, usergroup))
                self.usergroups[usergroup["id"]] = future

        for conversation in sorted(
            conversations_to_open, key=lambda conversation: conversation.sort_key()
        ):
            await conversation.open_buffer()

        await gather(
            *(slack_buffer.set_hotlist() for slack_buffer in shared.buffers.values())
        )

        self.is_connected = True

    async def _conversation_if_should_open(
        self, info: SlackUsersConversations
    ) -> Optional[SlackConversation]:
        """Return the conversation for ``info`` if a buffer should be opened.

        Conversations whose ``should_open()`` is false are still opened if they
        are an im/mpim with at least one message (all history when never read,
        otherwise history after ``last_read``). Otherwise returns None.
        """
        conversation = await self.conversations[info["id"]]
        if not conversation.should_open():
            if conversation.type != "im" and conversation.type != "mpim":
                return None

            if conversation.last_read == SlackTs("0.0"):
                history = await self.api.fetch_conversations_history(conversation)
            else:
                history = await self.api.fetch_conversations_history_after(
                    conversation, conversation.last_read
                )
            if not history["messages"]:
                return None

        return conversation

    async def _connect_ws(self, url: str):
        """Open the websocket to ``url`` and register a WeeChat fd read hook."""
        proxy = Proxy()
        # TODO: Handle errors
        self._ws = create_connection(
            url,
            self.config.network_timeout.value,
            cookie=get_cookies(self.config.api_cookies.value),
            proxy_type=proxy.type,
            http_proxy_host=proxy.address,
            http_proxy_port=proxy.port,
            http_proxy_auth=(proxy.username, proxy.password),
            http_proxy_timeout=self.config.network_timeout.value,
        )

        # Watch the socket for reads only (flag_read=1, flag_write=0,
        # flag_exception=0).
        self._hook_ws_fd = weechat.hook_fd(
            self._ws.sock.fileno(),
            1,
            0,
            0,
            get_callback_name(self._ws_read_cb),
            "",
        )
        # Non-blocking, so _ws_read_cb can drain frames until SSLWantReadError.
        self._ws.sock.setblocking(False)

    def _ws_read_cb(self, data: str, fd: int) -> int:
        """WeeChat fd callback: read and dispatch all available frames.

        Keeps reading until the SSL layer reports no more data. Frames already
        pulled off the socket into the SSL buffer would otherwise not trigger
        another fd event, so control frames must not end the loop early. On a
        closed connection or socket error a reconnect is scheduled.
        """
        if self._ws is None:
            raise SlackError(self, "ws_read_cb called while _ws is None")
        while True:
            try:
                opcode, recv_data = self._ws.recv_data(control_frame=True)
            except ssl.SSLWantReadError:
                # No more data to read at this time.
                return weechat.WEECHAT_RC_OK
            except (WebSocketConnectionClosedException, socket.error) as e:
                print("lost connection on receive, reconnecting", e)
                run_async(self.reconnect())
                return weechat.WEECHAT_RC_OK

            if opcode == ABNF.OPCODE_PONG:
                # TODO: Maybe record last time anything was received instead
                self.last_pong_time = time.time()
                continue
            if opcode != ABNF.OPCODE_TEXT:
                continue

            run_async(self._ws_recv(json.loads(recv_data.decode())))

    async def _ws_recv(self, data: SlackRtmMessage):
        """Dispatch one decoded RTM message to the affected conversation/user.

        Exceptions are caught, wrapped in SlackRtmError and printed, so one bad
        message does not propagate out of the websocket handler.
        """
        # TODO: Remove old messages
        log(LogLevel.DEBUG, DebugMessageType.WEBSOCKET_RECV, json.dumps(data))

        try:
            # --- Messages handled without a conversation, or channel lookup ---
            if data["type"] in [
                "hello",
                "file_public",
                "file_shared",
                "file_deleted",
                "dnd_updated_user",
            ]:
                return
            elif data["type"] == "pref_change":
                if data["name"] == "muted_channels":
                    new_muted_channels = set(data["value"].split(","))
                    # Only channels whose muted state flipped need refreshing.
                    changed_channels = self.muted_channels ^ new_muted_channels
                    self.muted_channels = new_muted_channels
                    for channel_id in changed_channels:
                        channel = self.open_conversations.get(channel_id)
                        if channel:
                            channel.update_buffer_props()
                return
            elif data["type"] == "user_status_changed":
                user_id = data["user"]["id"]
                if user_id in self.users:
                    user = await self.users[user_id]
                    user.update_info_json(data["user"])
                return
            elif data["type"] == "user_invalidated":
                user_id = data["user"]["id"]
                if user_id in self.users:
                    has_dm_conversation = any(
                        conversation.im_user_id == user_id
                        for conversation in self.open_conversations.values()
                    )
                    if has_dm_conversation:
                        user = await self.users[user_id]
                        user_info = await self.api.fetch_user_info(user_id)
                        user.update_info_json(user_info["user"])
                return
            elif data["type"] == "channel_joined" or data["type"] == "group_joined":
                channel_id = data["channel"]["id"]
            elif data["type"] == "reaction_added" or data["type"] == "reaction_removed":
                channel_id = data["item"]["channel"]
            elif (
                data["type"] == "thread_marked"
                or data["type"] == "thread_subscribed"
                or data["type"] == "thread_unsubscribed"
            ) and data["subscription"]["type"] == "thread":
                channel_id = data["subscription"]["channel"]
            elif data["type"] == "sh_room_join" or data["type"] == "sh_room_update":
                channel_id = data["huddle"]["channel_id"]
            elif "channel" in data and isinstance(data["channel"], str):
                channel_id = data["channel"]
            else:
                log(
                    LogLevel.DEBUG,
                    DebugMessageType.LOG,
                    f"unknown websocket message type (without channel): {data.get('type')}",
                )
                return

            # --- Resolve the conversation, opening it for some message types ---
            channel = self.open_conversations.get(channel_id)
            if channel is None:
                if (
                    data["type"] == "message"
                    or data["type"] == "im_open"
                    or data["type"] == "mpim_open"
                    or data["type"] == "group_open"
                    or data["type"] == "channel_joined"
                    or data["type"] == "group_joined"
                ):
                    channel = await self.conversations[channel_id]
                    if channel.type in ["im", "mpim"] or data["type"] in [
                        "channel_joined",
                        "group_joined",
                    ]:
                        await channel.open_buffer()
                        await channel.set_hotlist()
                else:
                    log(
                        LogLevel.DEBUG,
                        DebugMessageType.LOG,
                        "received websocket message for not open conversation, discarding",
                    )
                    return

            # --- Messages that act on a conversation ---
            if data["type"] == "message":
                if "subtype" in data and data["subtype"] == "message_changed":
                    await channel.change_message(data)
                elif "subtype" in data and data["subtype"] == "message_deleted":
                    await channel.delete_message(data)
                elif "subtype" in data and data["subtype"] == "message_replied":
                    await channel.change_message(data)
                else:
                    if "subtype" in data and data["subtype"] == "channel_topic":
                        channel.set_topic(data["topic"])

                    message = SlackMessage(channel, data)
                    await channel.add_new_message(message)
            elif (
                data["type"] == "im_close"
                or data["type"] == "mpim_close"
                or data["type"] == "group_close"
                or data["type"] == "channel_left"
                or data["type"] == "group_left"
            ):
                if channel.buffer_pointer is not None and channel.is_joined:
                    await channel.close_buffer()
            elif data["type"] == "reaction_added" and data["item"]["type"] == "message":
                await channel.reaction_add(
                    SlackTs(data["item"]["ts"]), data["reaction"], data["user"]
                )
            elif (
                data["type"] == "reaction_removed" and data["item"]["type"] == "message"
            ):
                await channel.reaction_remove(
                    SlackTs(data["item"]["ts"]), data["reaction"], data["user"]
                )
            elif (
                data["type"] == "channel_marked"
                or data["type"] == "group_marked"
                or data["type"] == "mpim_marked"
                or data["type"] == "im_marked"
            ):
                channel.last_read = SlackTs(data["ts"])
            elif (
                data["type"] == "thread_marked"
                and data["subscription"]["type"] == "thread"
            ):
                message = channel.messages.get(
                    SlackTs(data["subscription"]["thread_ts"])
                )
                if message:
                    message.last_read = SlackTs(data["subscription"]["last_read"])
            elif (
                data["type"] == "thread_subscribed"
                or data["type"] == "thread_unsubscribed"
            ) and data["subscription"]["type"] == "thread":
                message = channel.messages.get(
                    SlackTs(data["subscription"]["thread_ts"])
                )
                if message:
                    subscribed = data["type"] == "thread_subscribed"
                    await message.update_subscribed(subscribed, data["subscription"])
            elif data["type"] == "sh_room_join" or data["type"] == "sh_room_update":
                await channel.update_message_room(data)
            elif data["type"] == "user_typing":
                await channel.typing_add_user(data)
            else:
                log(
                    LogLevel.DEBUG,
                    DebugMessageType.LOG,
                    f"unknown websocket message type (with channel): {data.get('type')}",
                )
        except Exception as e:
            slack_error = SlackRtmError(self, e, data)
            print_error(store_and_format_exception(slack_error))

    def ping(self):
        """Send a websocket ping; schedule a reconnect if the socket is dead.

        :raises SlackError: if not connected, or connected without a websocket.
        """
        if not self.is_connected:
            raise SlackError(self, "Can't ping when not connected")
        if self._ws is None:
            raise SlackError(self, "is_connected is True while _ws is None")
        try:
            self._ws.ping()
            # workspace.last_ping_time = time.time()
        except (WebSocketConnectionClosedException, socket.error):
            print("lost connection on ping, reconnecting")
            run_async(self.reconnect())

    def send_typing(self, buffer: SlackBuffer):
        """Send a ``user_typing`` event for ``buffer``'s conversation.

        For thread buffers the parent message ts is included as ``thread_ts``.

        :raises SlackError: if not connected, or connected without a websocket.
        """
        if not self.is_connected:
            raise SlackError(self, "Can't send typing when not connected")
        if self._ws is None:
            raise SlackError(self, "is_connected is True while _ws is None")
        msg = {
            "type": "user_typing",
            "channel": buffer.conversation.id,
        }
        if isinstance(buffer, SlackThread):
            msg["thread_ts"] = buffer.parent.ts
        self._ws.send(json.dumps(msg))

    async def reconnect(self):
        """Tear down the current connection and connect again."""
        self.disconnect()
        await self.connect()

    def disconnect(self):
        """Mark disconnected, cancel any pending connect, and close the socket."""
        self.is_connected = False
        weechat.prnt("", f"Disconnected from workspace {self.name}")

        if self._connect_task:
            self._connect_task.cancel()
            self._connect_task = None

        if self._hook_ws_fd:
            weechat.unhook(self._hook_ws_fd)
            self._hook_ws_fd = None

        if self._ws:
            self._ws.close()
            self._ws = None