aboutsummaryrefslogblamecommitdiffstats
path: root/slack/task.py
blob: 3c214159d5f0078ed0840dbbff11d6463494b59d (plain) (tree)
1
2
3
4
5
6
7
8
9

                                  
                    
                  

              
             
              
         


              
             


            
             
 

                      
              
 
                                        
                                 

                                        
 
                 
                                               
 


                








                                                                                        
                           
                                                        
                                           
                                                                            
                       


                                                            
 


                                                        





                                                                 
 













                                                       
 
                     


                              


                                    

                                                               
                             

















































                                                                              










                                                       
                                                                  

                                  
 


                                                                                                 


                                                
                              
                   
 
 
                                                     
                                            
                           


                                         


                                
                                                          



                                           







                                                
                                                   
               

                            
            
                                                  
                                  

                                                                   

                  

                                      


                                                       

                 
 

                                                                      



                           











                                                                         


















                                                                                        
                                                     



























                                                                  

 

                                   

                                                                                     
                       
from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Coroutine,
    Dict,
    Generator,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    overload,
)
from uuid import uuid4

import weechat

from slack.error import format_exception
from slack.log import print_error
from slack.shared import shared
from slack.util import get_callback_name

if TYPE_CHECKING:
    from typing_extensions import Literal, Self

# Generic type variable for the value a Future/Task resolves to.
T = TypeVar("T")


class CancelledError(Exception):
    """Raised when the outcome of a cancelled Future (or Task) is requested."""


class InvalidStateError(Exception):
    """Raised when a Future is in a state that does not permit the operation.

    For example: setting a result on a future that is already done, or
    asking a pending future for its exception.
    """


# Heavily inspired by https://github.com/python/cpython/blob/3.11/Lib/asyncio/futures.py
class Future(Awaitable[T]):
    """A minimal awaitable placeholder for a value that becomes available later.

    Modelled on ``asyncio.Future`` but without an event loop: done callbacks
    are invoked synchronously, at the moment the future is resolved or
    cancelled, or immediately if the future is already done when the
    callback is added.

    A future is in one of three states: "PENDING", "CANCELLED" or
    "FINISHED" (resolved with either a result or an exception).
    """

    def __init__(self, future_id: Optional[str] = None):
        # A missing (or empty) id is replaced by a random UUID string.
        self.id = future_id or str(uuid4())
        self._state: Literal["PENDING", "CANCELLED", "FINISHED"] = "PENDING"
        self._result: T  # declared only; assigned by set_result()
        self._exception: Optional[BaseException] = None
        self._cancel_message: Optional[str] = None
        self._callbacks: List[Callable[[Self], object]] = []

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}('{self.id}')"

    def __await__(self) -> Generator[Future[T], None, T]:
        """Suspend the awaiting coroutine until this future is done.

        A pending future yields itself to whatever drives the coroutine; the
        coroutine is expected to be resumed only after the future is done.

        Raises:
            RuntimeError: if the coroutine is resumed while still pending.
            BaseException: whatever ``result()`` raises.
        """
        if not self.done():
            yield self  # This tells Task to wait for completion.
        if not self.done():
            raise RuntimeError("await wasn't used with future")
        return self.result()  # May raise too.

    def _make_cancelled_error(self) -> CancelledError:
        """Build the CancelledError to raise, carrying the cancel message if any."""
        if self._cancel_message is None:
            return CancelledError()
        return CancelledError(self._cancel_message)

    def __schedule_callbacks(self) -> None:
        """Invoke, in registration order, and then forget all done callbacks."""
        callbacks = self._callbacks[:]
        if not callbacks:
            return

        # Clear before invoking so that every callback runs at most once.
        self._callbacks[:] = []
        for callback in callbacks:
            callback(self)

    def result(self) -> T:
        """Return the future's result, or raise the exception it was given.

        Raises:
            CancelledError: if the future was cancelled.
            InvalidStateError: if the future is not done yet.
            BaseException: the exception stored with ``set_exception()``.
        """
        if self.cancelled():
            raise self._make_cancelled_error()
        if not self.done():
            # Checked here rather than via exception(), whose message
            # ("Exception is not set.") is misleading when a result is asked for.
            raise InvalidStateError("Result is not set.")
        if self._exception is not None:
            raise self._exception
        return self._result

    def set_result(self, result: T) -> None:
        """Resolve the future with ``result`` and run the done callbacks.

        Raises:
            InvalidStateError: if the future is already done or cancelled.
        """
        if self.done():
            raise InvalidStateError(f"{self._state}: {self!r}")
        self._result = result
        self._state = "FINISHED"
        self.__schedule_callbacks()

    def set_exception(self, exception: BaseException) -> None:
        """Resolve the future with an exception and run the done callbacks.

        An exception *class* is accepted too and is instantiated without
        arguments.

        Raises:
            InvalidStateError: if the future is already done or cancelled.
            TypeError: if the exception is exactly ``StopIteration``.
        """
        if self.done():
            raise InvalidStateError(f"{self._state}: {self!r}")
        if isinstance(exception, type):
            exception = exception()
        if type(exception) is StopIteration:
            raise TypeError(
                "StopIteration interacts badly with generators "
                "and cannot be raised into a Future"
            )
        self._exception = exception
        self._state = "FINISHED"
        self.__schedule_callbacks()

    def done(self) -> bool:
        """Return True once the future is finished or cancelled."""
        return self._state != "PENDING"

    def cancelled(self) -> bool:
        """Return True if the future was cancelled."""
        return self._state == "CANCELLED"

    def add_done_callback(self, callback: Callable[[Self], object]) -> None:
        """Register ``callback(future)``; call it right away if already done."""
        if self.done():
            callback(self)
        else:
            self._callbacks.append(callback)

    def remove_done_callback(self, callback: Callable[[Self], object]) -> int:
        """Unregister every occurrence of ``callback``; return how many were removed."""
        filtered_callbacks = [cb for cb in self._callbacks if cb != callback]
        removed_count = len(self._callbacks) - len(filtered_callbacks)
        if removed_count:
            self._callbacks[:] = filtered_callbacks
        return removed_count

    def cancel(self, msg: Optional[str] = None) -> bool:
        """Cancel a pending future and run its done callbacks.

        Args:
            msg: optional message carried by the CancelledError raised later.

        Returns:
            True if the future was pending and is now cancelled, False if it
            was already done.
        """
        if self._state != "PENDING":
            return False
        self._state = "CANCELLED"
        self._cancel_message = msg
        self.__schedule_callbacks()
        return True

    def exception(self) -> Optional[BaseException]:
        """Return the stored exception, or None if the future has a result.

        Raises:
            CancelledError: if the future was cancelled.
            InvalidStateError: if the future is not done yet.
        """
        if self.cancelled():
            raise self._make_cancelled_error()
        elif not self.done():
            raise InvalidStateError("Exception is not set.")
        return self._exception


class FutureProcess(Future[Tuple[str, int, str, str]]):
    """Future whose result is a 4-tuple of (str, int, str, str).

    NOTE(review): presumably the arguments of a ``weechat.hook_process``
    callback (command, return code, stdout, stderr) - confirm with callers.
    """


class FutureTimer(Future[Tuple[int]]):
    """Future whose result is a 1-tuple of int; awaited by ``sleep``.

    NOTE(review): presumably the ``remaining_calls`` argument of a
    ``weechat.hook_timer`` callback - confirm.
    """


class Task(Future[T]):
    """A Future whose outcome is produced by running a coroutine.

    The coroutine is stepped by ``task_runner``. Its return value becomes
    the task's result, and an exception it raises becomes the task's
    exception.
    """

    def __init__(self, coroutine: Coroutine[Future[Any], Any, T]):
        super().__init__()
        self.coroutine = coroutine

    def __repr__(self) -> str:
        kind = type(self).__name__
        return f"{kind}('{self.id}', coroutine={self.coroutine.__qualname__})"

    def cancel(self, msg: Optional[str] = None):
        """Cancel the task and close its coroutine so it is never resumed.

        Returns False (and leaves the coroutine alone) if already done.
        """
        was_cancelled = super().cancel(msg)
        if was_cancelled:
            # close() raises GeneratorExit inside the suspended coroutine.
            self.coroutine.close()
        return was_cancelled


def weechat_task_cb(data: str, *args: object) -> int:
    """WeeChat hook callback that resolves the future registered under ``data``.

    ``data`` is the future's id (given as the hook's callback data). The
    remaining callback arguments, as a tuple, become the future's result.
    The same tuple is then sent to every task that was waiting on it.
    """
    resolved = shared.active_futures.pop(data)
    resolved.set_result(args)
    for waiting_task in shared.active_tasks.pop(data):
        task_runner(waiting_task, args)
    return weechat.WEECHAT_RC_OK


def process_ended_task(task: Task[Any], response: object):
    """Record the outcome of a finished task and resume the tasks awaiting it.

    ``response`` is either the coroutine's return value or the exception it
    raised. An exception instance is stored with ``set_exception()``; any
    other value is stored with ``set_result()``. Waiting tasks receive
    ``response``, and the task is then removed from the active-future
    registry.
    """
    if not isinstance(response, BaseException):
        task.set_result(response)
    else:
        task.set_exception(response)

    for waiter in shared.active_tasks.pop(task.id, ()):
        task_runner(waiter, response)

    shared.active_futures.pop(task.id, None)


def task_runner(task: Task[Any], response: object):
    """Step ``task``'s coroutine until it blocks on a pending future or ends.

    ``response`` is sent into the coroutine.

    - If the coroutine yields a future that is already done, its result is
      sent back in and stepping continues.
    - If the future is still pending, the task is registered as its waiter
      in ``shared.active_tasks`` / ``shared.active_futures``, and this
      function returns. The code that resolves that future calls
      ``task_runner`` again.
    - When the coroutine returns or raises, the outcome is passed to
      ``process_ended_task``.

    A cancelled task is never stepped.
    """
    coroutine = task.coroutine
    while not task.cancelled():
        try:
            awaited = coroutine.send(response)
        except StopIteration as finished:
            process_ended_task(task, finished.value)
            return
        except BaseException as error:
            process_ended_task(task, error)
            return

        if not awaited.done():
            shared.active_tasks[awaited.id].append(task)
            shared.active_futures[awaited.id] = awaited
            return

        response = awaited.result()


def create_task(coroutine: Coroutine[Future[Any], Any, T]) -> Task[T]:
    """Wrap ``coroutine`` in a Task and start running it right away.

    The coroutine runs synchronously until it awaits a future that is not
    done yet, or until it finishes. The Task is returned in either case.
    """
    new_task = Task(coroutine)
    # A fresh coroutine must be started by sending None.
    task_runner(new_task, None)
    return new_task


def _async_task_done(task: Task[object]):
    """Done callback for ``run_async`` tasks: report an unhandled exception.

    Cancelled tasks are ignored. ``Future.exception()`` raises CancelledError
    for a cancelled future, and done callbacks run synchronously inside
    ``cancel()``. Letting that error escape would abort the cancellation
    midway, so ``Task.cancel`` would never close the coroutine.
    """
    if task.cancelled():
        return
    exception = task.exception()
    if exception is not None:
        print_error(f"{task} failed with: {format_exception(exception)}")


def run_async(coroutine: Coroutine[Future[Any], Any, Any]) -> None:
    """Start ``coroutine`` as a fire-and-forget task.

    Nothing is returned. An exception raised by the coroutine is reported
    by the ``_async_task_done`` done callback. The callback is attached
    before the coroutine is first stepped.
    """
    background = Task(coroutine)
    background.add_done_callback(_async_task_done)
    task_runner(background, None)


@overload
async def gather(
    *requests: Union[Future[T], Coroutine[Any, Any, T]],
    return_exceptions: Literal[False] = False,
) -> List[T]:
    ...


@overload
async def gather(
    *requests: Union[Future[T], Coroutine[Any, Any, T]],
    return_exceptions: Literal[True],
) -> List[Union[T, BaseException]]:
    ...


async def gather(
    *requests: Union[Future[T], Coroutine[Any, Any, T]], return_exceptions: bool = False
) -> Sequence[Union[T, BaseException]]:
    """Await every request and return their outcomes in argument order.

    Coroutines are first wrapped with ``create_task`` so they all start
    before any of them is awaited. Other requests are awaited as given.

    With ``return_exceptions=True``, exceptions take the place of results
    in the returned list. Otherwise every request is still awaited to
    completion, and then the first exception (in argument order) is raised.
    """
    # TODO: Should probably propagate first exception

    pending: Dict[int, Future[T]] = {}
    outcomes: Dict[int, Union[T, BaseException]] = {}

    for index, request in enumerate(requests):
        if not isinstance(request, Coroutine):
            pending[index] = request
            continue
        try:
            pending[index] = create_task(request)
        except BaseException as error:
            outcomes[index] = error

    for index, awaitable in pending.items():
        try:
            outcomes[index] = await awaitable
        except BaseException as error:
            outcomes[index] = error

    ordered = [outcomes[index] for index in sorted(outcomes)]

    if not return_exceptions:
        first_error = next(
            (outcome for outcome in ordered if isinstance(outcome, BaseException)),
            None,
        )
        if first_error is not None:
            raise first_error

    return ordered


async def sleep(milliseconds: int):
    """Suspend the calling task for ``milliseconds`` ms via a WeeChat timer.

    Durations that are not positive are replaced by 1 ms. The call
    registers ``weechat.hook_timer`` with interval, align_second=0 and
    max_calls=1; its callback resolves a FutureTimer. The awaited value
    (the timer callback's arguments) is returned.
    """
    timer_future = FutureTimer()
    interval = 1
    if milliseconds > 0:
        interval = milliseconds
    weechat.hook_timer(
        interval, 0, 1, get_callback_name(weechat_task_cb), timer_future.id
    )
    return await timer_future