aboutsummaryrefslogtreecommitdiffstats
path: root/slack/slack_workspace.py
diff options
context:
space:
mode:
Diffstat (limited to 'slack/slack_workspace.py')
-rw-r--r--slack/slack_workspace.py145
1 files changed, 114 insertions, 31 deletions
diff --git a/slack/slack_workspace.py b/slack/slack_workspace.py
index 762009a..9348e1a 100644
--- a/slack/slack_workspace.py
+++ b/slack/slack_workspace.py
@@ -5,7 +5,18 @@ import socket
import ssl
import time
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Dict, Generic, Iterable, Optional, Set, Type, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ Generic,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Type,
+ TypeVar,
+)
import weechat
from websocket import (
@@ -35,15 +46,16 @@ from slack.util import get_callback_name
if TYPE_CHECKING:
from slack_api.slack_bots_info import SlackBotInfo
- from slack_api.slack_conversations_info import SlackConversationsInfo
from slack_api.slack_usergroups_info import SlackUsergroupInfo
from slack_api.slack_users_conversations import SlackUsersConversations
from slack_api.slack_users_info import SlackUserInfo
from slack_rtm.slack_rtm_message import SlackRtmMessage
from typing_extensions import Literal
+
+ from slack.slack_conversation import SlackConversationsInfoInternal
else:
SlackBotInfo = object
- SlackConversationsInfo = object
+ SlackConversationsInfoInternal = object
SlackUsergroupInfo = object
SlackUserInfo = object
@@ -52,7 +64,7 @@ SlackItemClass = TypeVar(
)
SlackItemInfo = TypeVar(
"SlackItemInfo",
- SlackConversationsInfo,
+ SlackConversationsInfoInternal,
SlackUserInfo,
SlackBotInfo,
SlackUsergroupInfo,
@@ -71,19 +83,37 @@ class SlackItem(
self[key] = create_task(self._create_item(key))
return self[key]
- def initialize_items(self, item_ids: Iterable[str]):
+ def initialize_items(
+ self,
+ item_ids: Iterable[str],
+ items_info_prefetched: Optional[Mapping[str, SlackItemInfo]] = None,
+ ):
item_ids_to_init = set(item_id for item_id in item_ids if item_id not in self)
if item_ids_to_init:
- items_info_task = create_task(self._fetch_items_info(item_ids_to_init))
+ item_ids_to_fetch = (
+ set(
+ item_id
+ for item_id in item_ids_to_init
+ if item_id not in items_info_prefetched
+ )
+ if items_info_prefetched
+ else item_ids_to_init
+ )
+ items_info_task = create_task(self._fetch_items_info(item_ids_to_fetch))
for item_id in item_ids_to_init:
- self[item_id] = create_task(self._create_item(item_id, items_info_task))
+ self[item_id] = create_task(
+ self._create_item(item_id, items_info_task, items_info_prefetched)
+ )
async def _create_item(
self,
item_id: str,
items_info_task: Optional[Future[Dict[str, SlackItemInfo]]] = None,
+ items_info_prefetched: Optional[Mapping[str, SlackItemInfo]] = None,
) -> SlackItemClass:
- if items_info_task:
+ if items_info_prefetched and item_id in items_info_prefetched:
+ return await self._create_item_from_info(items_info_prefetched[item_id])
+ elif items_info_task:
items_info = await items_info_task
item = items_info.get(item_id)
if item is None:
@@ -103,13 +133,13 @@ class SlackItem(
raise NotImplementedError()
-class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfo]):
+class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfoInternal]):
def __init__(self, workspace: SlackWorkspace):
super().__init__(workspace, SlackConversation)
async def _fetch_items_info(
self, item_ids: Iterable[str]
- ) -> Dict[str, SlackConversationsInfo]:
+ ) -> Dict[str, SlackConversationsInfoInternal]:
responses = await gather(
*(
self.workspace.api.fetch_conversations_info(item_id)
@@ -121,7 +151,7 @@ class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfo]):
}
async def _create_item_from_info(
- self, item_info: SlackConversationsInfo
+ self, item_info: SlackConversationsInfoInternal
) -> SlackConversation:
return await self._item_class(self.workspace, item_info)
@@ -224,17 +254,9 @@ class SlackWorkspace:
weechat.prnt("", f"Connected to workspace {self.name}")
self._connect_task = None
- async def _connect(self) -> None:
- try:
- rtm_connect = await self.api.fetch_rtm_connect()
- except SlackApiError as e:
- print_error(
- f'failed connecting to workspace "{self.name}": {e.response["error"]}'
- )
- return
-
+ async def _connect_oauth(self) -> List[SlackConversation]:
+ rtm_connect = await self.api.fetch_rtm_connect()
self.id = rtm_connect["team"]["id"]
- self.enterprise_id = rtm_connect["team"].get("enterprise_id")
self.my_user = await self.users[rtm_connect["self"]["id"]]
await self._connect_ws(rtm_connect["url"])
@@ -242,16 +264,6 @@ class SlackWorkspace:
prefs = await self.api.fetch_users_get_prefs("muted_channels")
self.muted_channels = set(prefs["prefs"]["muted_channels"].split(","))
- custom_emojis_response = await self.api.fetch_emoji_list()
- self.custom_emojis = custom_emojis_response["emoji"]
-
- if not self.api.edgeapi.is_available:
- usergroups = await self.api.fetch_usergroups_list()
- for usergroup in usergroups["usergroups"]:
- future = Future[SlackUsergroup]()
- future.set_result(SlackUsergroup(self, usergroup))
- self.usergroups[usergroup["id"]] = future
-
users_conversations_response = await self.api.fetch_users_conversations(
"public_channel,private_channel,mpim,im"
)
@@ -264,6 +276,77 @@ class SlackWorkspace:
conversations_to_open = [
c for c in conversations_if_should_open if c is not None
]
+ return conversations_to_open
+
+ async def _connect_session(self) -> List[SlackConversation]:
+ user_boot_task = create_task(self.api.fetch_client_userboot())
+ client_counts_task = create_task(self.api.fetch_client_counts())
+ user_boot = await user_boot_task
+ client_counts = await client_counts_task
+
+ self.id = user_boot["team"]["id"]
+ my_user_id = user_boot["self"]["id"]
+ # self.users.initialize_items(my_user_id, {my_user_id: user_boot["self"]})
+ self.my_user = await self.users[my_user_id]
+ self.muted_channels = set(user_boot["prefs"]["muted_channels"].split(","))
+
+ await self._connect_ws(
+ f"wss://wss-primary.slack.com/?token={self.config.api_token.value}&batch_presence_aware=1"
+ )
+
+ conversation_counts = (
+ client_counts["channels"] + client_counts["mpims"] + client_counts["ims"]
+ )
+
+ conversation_ids = set(
+ [
+ channel["id"]
+ for channel in user_boot["channels"]
+ if not channel["is_mpim"]
+ ]
+ + user_boot["is_open"]
+ + [count["id"] for count in conversation_counts if count["has_unreads"]]
+ )
+
+ channel_infos: Dict[str, SlackConversationsInfoInternal] = {
+ channel["id"]: channel for channel in user_boot["channels"]
+ }
+ self.conversations.initialize_items(conversation_ids, channel_infos)
+ conversations = {
+ conversation_id: await self.conversations[conversation_id]
+ for conversation_id in conversation_ids
+ }
+
+ for conversation_count in conversation_counts:
+ if conversation_count["id"] in conversations:
+ conversation = conversations[conversation_count["id"]]
+ # TODO: Update without moving unread marker to the bottom
+ if conversation.last_read == SlackTs("0.0"):
+ conversation.last_read = SlackTs(conversation_count["last_read"])
+
+ return list(conversations.values())
+
+ async def _connect(self) -> None:
+ try:
+ if self.token_type == "session":
+ conversations_to_open = await self._connect_session()
+ else:
+ conversations_to_open = await self._connect_oauth()
+ except SlackApiError as e:
+ print_error(
+ f'failed connecting to workspace "{self.name}": {e.response["error"]}'
+ )
+ return
+
+ custom_emojis_response = await self.api.fetch_emoji_list()
+ self.custom_emojis = custom_emojis_response["emoji"]
+
+ if not self.api.edgeapi.is_available:
+ usergroups = await self.api.fetch_usergroups_list()
+ for usergroup in usergroups["usergroups"]:
+ future = Future[SlackUsergroup]()
+ future.set_result(SlackUsergroup(self, usergroup))
+ self.usergroups[usergroup["id"]] = future
for conversation in sorted(
conversations_to_open, key=lambda conversation: conversation.sort_key()