aboutsummaryrefslogtreecommitdiffstats
path: root/slack
diff options
context:
space:
mode:
Diffstat (limited to 'slack')
-rw-r--r--slack/slack_api.py21
-rw-r--r--slack/slack_conversation.py58
-rw-r--r--slack/slack_workspace.py145
3 files changed, 173 insertions, 51 deletions
diff --git a/slack/slack_api.py b/slack/slack_api.py
index fb6952d..7ce919f 100644
--- a/slack/slack_api.py
+++ b/slack/slack_api.py
@@ -14,6 +14,8 @@ from slack.util import chunked
if TYPE_CHECKING:
from slack_api.slack_bots_info import SlackBotInfoResponse, SlackBotsInfoResponse
+ from slack_api.slack_client_counts import SlackClientCountsResponse
+ from slack_api.slack_client_userboot import SlackClientUserbootResponse
from slack_api.slack_common import SlackGenericResponse
from slack_api.slack_conversations_history import SlackConversationsHistoryResponse
from slack_api.slack_conversations_info import SlackConversationsInfoResponse
@@ -56,10 +58,7 @@ class SlackEdgeApi(SlackApiCommon):
return self.workspace.token_type == "session"
async def _fetch_edgeapi(self, method: str, params: EdgeParams = {}):
- enterprise_id_part = (
- f"{self.workspace.enterprise_id}/" if self.workspace.enterprise_id else ""
- )
- url = f"https://edgeapi.slack.com/cache/{enterprise_id_part}{self.workspace.id}/{method}"
+ url = f"https://edgeapi.slack.com/cache/{self.workspace.id}/{method}"
options = self._get_request_options()
options["postfields"] = json.dumps(params)
options["httpheader"] += "\nContent-Type: application/json"
@@ -284,6 +283,20 @@ class SlackApi(SlackApiCommon):
raise SlackApiError(self.workspace, method, response)
return response
+ async def fetch_client_userboot(self):
+ method = "client.userBoot"
+ response: SlackClientUserbootResponse = await self._fetch(method)
+ if response["ok"] is False:
+ raise SlackApiError(self.workspace, method, response)
+ return response
+
+ async def fetch_client_counts(self):
+ method = "client.counts"
+ response: SlackClientCountsResponse = await self._fetch(method)
+ if response["ok"] is False:
+ raise SlackApiError(self.workspace, method, response)
+ return response
+
async def conversations_close(self, conversation: SlackConversation):
method = "conversations.close"
params: Params = {"channel": conversation.id}
diff --git a/slack/slack_conversation.py b/slack/slack_conversation.py
index 642ed82..e39428f 100644
--- a/slack/slack_conversation.py
+++ b/slack/slack_conversation.py
@@ -35,7 +35,9 @@ from slack.task import Task, gather, run_async
from slack.util import unhtmlescape, with_color
if TYPE_CHECKING:
- from slack_api.slack_conversations_info import SlackConversationsInfo
+ from slack_api.slack_client_userboot import SlackClientUserbootIm
+ from slack_api.slack_conversations_info import SlackConversationsInfo, SlackTopic
+ from slack_api.slack_users_conversations import SlackUsersConversationsNotIm
from slack_rtm.slack_rtm_message import (
SlackMessageChanged,
SlackMessageDeleted,
@@ -48,6 +50,10 @@ if TYPE_CHECKING:
from slack.slack_workspace import SlackWorkspace
+ SlackConversationsInfoInternal = Union[
+ SlackConversationsInfo, SlackUsersConversationsNotIm, SlackClientUserbootIm
+ ]
+
def update_buffer_props():
for workspace in shared.workspaces.values():
@@ -133,7 +139,7 @@ class SlackConversation(SlackBuffer):
async def __new__(
cls,
workspace: SlackWorkspace,
- info: SlackConversationsInfo,
+ info: SlackConversationsInfoInternal,
):
conversation = super().__new__(cls)
conversation.__init__(workspace, info)
@@ -142,7 +148,7 @@ class SlackConversation(SlackBuffer):
def __init__(
self,
workspace: SlackWorkspace,
- info: SlackConversationsInfo,
+ info: SlackConversationsInfoInternal,
):
super().__init__()
self._workspace = workspace
@@ -155,11 +161,27 @@ class SlackConversation(SlackBuffer):
self.nicklist_needs_refresh = True
self.message_hashes = SlackConversationMessageHashes(self)
+ self._last_read = (
+ SlackTs(self._info["last_read"])
+ if "last_read" in self._info
+ else SlackTs("0.0")
+ )
+
+ self._topic: SlackTopic = (
+ self._info["topic"]
+ if "topic" in self._info
+ else {"value": "", "creator": "", "last_set": 0}
+ )
+
async def __init_async(self):
if self._info["is_im"] is True:
self._im_user = await self._workspace.users[self._info["user"]]
elif self.type == "mpim":
- members = await self.load_members(load_all=True)
+ if "members" in self._info:
+ members = self._info["members"]
+ else:
+ members = await self.load_members(load_all=True)
+
self._mpim_users = await gather(
*(
self._workspace.users[user_id]
@@ -225,11 +247,11 @@ class SlackConversation(SlackBuffer):
@property
def last_read(self) -> SlackTs:
- return SlackTs(self._info["last_read"])
+ return self._last_read
@last_read.setter
def last_read(self, value: SlackTs):
- self._info["last_read"] = value
+ self._last_read = value
self.set_unread_and_hotlist()
@property
@@ -295,16 +317,14 @@ class SlackConversation(SlackBuffer):
def buffer_title(self) -> str:
# TODO: unfurl and apply styles
- topic = unhtmlescape(self._info.get("topic", {}).get("value", ""))
+ topic = unhtmlescape(self._topic["value"])
if self._im_user:
status = f"{self._im_user.status_emoji} {self._im_user.status_text}".strip()
return " | ".join(part for part in [status, topic] if part)
return topic
def set_topic(self, title: str):
- if "topic" not in self._info:
- self._info["topic"] = {"value": "", "creator": "", "last_set": 0}
- self._info["topic"]["value"] = title
+ self._topic["value"] = title
self.update_buffer_props()
def get_name_and_buffer_props(self) -> Tuple[str, Dict[str, str]]:
@@ -524,12 +544,18 @@ class SlackConversation(SlackBuffer):
async def nicklist_update(self):
if self.nicklist_needs_refresh and self.type != "im":
self.nicklist_needs_refresh = False
- members = await self.load_members()
- users = await gather(
- *(self.workspace.users[user_id] for user_id in members)
- )
- for user in users:
- self.nicklist_add_user(user)
+ try:
+ members = await self.load_members()
+ except SlackApiError as e:
+ if e.response["error"] == "enterprise_is_restricted":
+ return
+ raise e
+ else:
+ users = await gather(
+ *(self.workspace.users[user_id] for user_id in members)
+ )
+ for user in users:
+ self.nicklist_add_user(user)
def nicklist_add_user(
self, user: Optional[Union[SlackUser, SlackBot]], nick: Optional[str] = None
diff --git a/slack/slack_workspace.py b/slack/slack_workspace.py
index 762009a..9348e1a 100644
--- a/slack/slack_workspace.py
+++ b/slack/slack_workspace.py
@@ -5,7 +5,18 @@ import socket
import ssl
import time
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Dict, Generic, Iterable, Optional, Set, Type, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ Generic,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Type,
+ TypeVar,
+)
import weechat
from websocket import (
@@ -35,15 +46,16 @@ from slack.util import get_callback_name
if TYPE_CHECKING:
from slack_api.slack_bots_info import SlackBotInfo
- from slack_api.slack_conversations_info import SlackConversationsInfo
from slack_api.slack_usergroups_info import SlackUsergroupInfo
from slack_api.slack_users_conversations import SlackUsersConversations
from slack_api.slack_users_info import SlackUserInfo
from slack_rtm.slack_rtm_message import SlackRtmMessage
from typing_extensions import Literal
+
+ from slack.slack_conversation import SlackConversationsInfoInternal
else:
SlackBotInfo = object
- SlackConversationsInfo = object
+ SlackConversationsInfoInternal = object
SlackUsergroupInfo = object
SlackUserInfo = object
@@ -52,7 +64,7 @@ SlackItemClass = TypeVar(
)
SlackItemInfo = TypeVar(
"SlackItemInfo",
- SlackConversationsInfo,
+ SlackConversationsInfoInternal,
SlackUserInfo,
SlackBotInfo,
SlackUsergroupInfo,
@@ -71,19 +83,37 @@ class SlackItem(
self[key] = create_task(self._create_item(key))
return self[key]
- def initialize_items(self, item_ids: Iterable[str]):
+ def initialize_items(
+ self,
+ item_ids: Iterable[str],
+ items_info_prefetched: Optional[Mapping[str, SlackItemInfo]] = None,
+ ):
item_ids_to_init = set(item_id for item_id in item_ids if item_id not in self)
if item_ids_to_init:
- items_info_task = create_task(self._fetch_items_info(item_ids_to_init))
+ item_ids_to_fetch = (
+ set(
+ item_id
+ for item_id in item_ids_to_init
+ if item_id not in items_info_prefetched
+ )
+ if items_info_prefetched
+ else item_ids_to_init
+ )
+ items_info_task = create_task(self._fetch_items_info(item_ids_to_fetch))
for item_id in item_ids_to_init:
- self[item_id] = create_task(self._create_item(item_id, items_info_task))
+ self[item_id] = create_task(
+ self._create_item(item_id, items_info_task, items_info_prefetched)
+ )
async def _create_item(
self,
item_id: str,
items_info_task: Optional[Future[Dict[str, SlackItemInfo]]] = None,
+ items_info_prefetched: Optional[Mapping[str, SlackItemInfo]] = None,
) -> SlackItemClass:
- if items_info_task:
+ if items_info_prefetched and item_id in items_info_prefetched:
+ return await self._create_item_from_info(items_info_prefetched[item_id])
+ elif items_info_task:
items_info = await items_info_task
item = items_info.get(item_id)
if item is None:
@@ -103,13 +133,13 @@ class SlackItem(
raise NotImplementedError()
-class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfo]):
+class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfoInternal]):
def __init__(self, workspace: SlackWorkspace):
super().__init__(workspace, SlackConversation)
async def _fetch_items_info(
self, item_ids: Iterable[str]
- ) -> Dict[str, SlackConversationsInfo]:
+ ) -> Dict[str, SlackConversationsInfoInternal]:
responses = await gather(
*(
self.workspace.api.fetch_conversations_info(item_id)
@@ -121,7 +151,7 @@ class SlackConversations(SlackItem[SlackConversation, SlackConversationsInfo]):
}
async def _create_item_from_info(
- self, item_info: SlackConversationsInfo
+ self, item_info: SlackConversationsInfoInternal
) -> SlackConversation:
return await self._item_class(self.workspace, item_info)
@@ -224,17 +254,9 @@ class SlackWorkspace:
weechat.prnt("", f"Connected to workspace {self.name}")
self._connect_task = None
- async def _connect(self) -> None:
- try:
- rtm_connect = await self.api.fetch_rtm_connect()
- except SlackApiError as e:
- print_error(
- f'failed connecting to workspace "{self.name}": {e.response["error"]}'
- )
- return
-
+ async def _connect_oauth(self) -> List[SlackConversation]:
+ rtm_connect = await self.api.fetch_rtm_connect()
self.id = rtm_connect["team"]["id"]
- self.enterprise_id = rtm_connect["team"].get("enterprise_id")
self.my_user = await self.users[rtm_connect["self"]["id"]]
await self._connect_ws(rtm_connect["url"])
@@ -242,16 +264,6 @@ class SlackWorkspace:
prefs = await self.api.fetch_users_get_prefs("muted_channels")
self.muted_channels = set(prefs["prefs"]["muted_channels"].split(","))
- custom_emojis_response = await self.api.fetch_emoji_list()
- self.custom_emojis = custom_emojis_response["emoji"]
-
- if not self.api.edgeapi.is_available:
- usergroups = await self.api.fetch_usergroups_list()
- for usergroup in usergroups["usergroups"]:
- future = Future[SlackUsergroup]()
- future.set_result(SlackUsergroup(self, usergroup))
- self.usergroups[usergroup["id"]] = future
-
users_conversations_response = await self.api.fetch_users_conversations(
"public_channel,private_channel,mpim,im"
)
@@ -264,6 +276,77 @@ class SlackWorkspace:
conversations_to_open = [
c for c in conversations_if_should_open if c is not None
]
+ return conversations_to_open
+
+ async def _connect_session(self) -> List[SlackConversation]:
+ user_boot_task = create_task(self.api.fetch_client_userboot())
+ client_counts_task = create_task(self.api.fetch_client_counts())
+ user_boot = await user_boot_task
+ client_counts = await client_counts_task
+
+ self.id = user_boot["team"]["id"]
+ my_user_id = user_boot["self"]["id"]
+ # self.users.initialize_items(my_user_id, {my_user_id: user_boot["self"]})
+ self.my_user = await self.users[my_user_id]
+ self.muted_channels = set(user_boot["prefs"]["muted_channels"].split(","))
+
+ await self._connect_ws(
+ f"wss://wss-primary.slack.com/?token={self.config.api_token.value}&batch_presence_aware=1"
+ )
+
+ conversation_counts = (
+ client_counts["channels"] + client_counts["mpims"] + client_counts["ims"]
+ )
+
+ conversation_ids = set(
+ [
+ channel["id"]
+ for channel in user_boot["channels"]
+ if not channel["is_mpim"]
+ ]
+ + user_boot["is_open"]
+ + [count["id"] for count in conversation_counts if count["has_unreads"]]
+ )
+
+ channel_infos: Dict[str, SlackConversationsInfoInternal] = {
+ channel["id"]: channel for channel in user_boot["channels"]
+ }
+ self.conversations.initialize_items(conversation_ids, channel_infos)
+ conversations = {
+ conversation_id: await self.conversations[conversation_id]
+ for conversation_id in conversation_ids
+ }
+
+ for conversation_count in conversation_counts:
+ if conversation_count["id"] in conversations:
+ conversation = conversations[conversation_count["id"]]
+ # TODO: Update without moving unread marker to the bottom
+ if conversation.last_read == SlackTs("0.0"):
+ conversation.last_read = SlackTs(conversation_count["last_read"])
+
+ return list(conversations.values())
+
+ async def _connect(self) -> None:
+ try:
+ if self.token_type == "session":
+ conversations_to_open = await self._connect_session()
+ else:
+ conversations_to_open = await self._connect_oauth()
+ except SlackApiError as e:
+ print_error(
+ f'failed connecting to workspace "{self.name}": {e.response["error"]}'
+ )
+ return
+
+ custom_emojis_response = await self.api.fetch_emoji_list()
+ self.custom_emojis = custom_emojis_response["emoji"]
+
+ if not self.api.edgeapi.is_available:
+ usergroups = await self.api.fetch_usergroups_list()
+ for usergroup in usergroups["usergroups"]:
+ future = Future[SlackUsergroup]()
+ future.set_result(SlackUsergroup(self, usergroup))
+ self.usergroups[usergroup["id"]] = future
for conversation in sorted(
conversations_to_open, key=lambda conversation: conversation.sort_key()