diff --git a/rustplus/api/base_rust_api.py b/rustplus/api/base_rust_api.py index 57444c9..b48f90e 100644 --- a/rustplus/api/base_rust_api.py +++ b/rustplus/api/base_rust_api.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Callable, Union +from typing import List, Union, Coroutine, Callable, Dict, Tuple from PIL import Image from .remote.events.event_loop_manager import EventLoopManager @@ -113,8 +113,10 @@ async def connect( self, retries: int = float("inf"), delay: int = 20, - on_failure=None, - on_success=None, + on_failure: Union[Coroutine, Callable[[], None], None] = None, + on_success: Union[Coroutine, Callable[[], None], None] = None, + on_success_args_kwargs: Tuple[List, Dict] = ([], {}), + on_failure_args_kwargs: Tuple[List, Dict] = ([], {}), ) -> None: """ Attempts to open a connection to the rust game server specified in the constructor @@ -123,7 +125,10 @@ async def connect( :param delay: The delay (in seconds) between reconnection attempts. :param on_failure: Optional function to be called when connecting fails. :param on_success: Optional function to be called when connecting succeeds. - + :param on_success_args_kwargs: Optional tuple holding keyword and regular arguments + for on_success in this format (args, kwargs) + :param on_failure_args_kwargs: Optional tuple holding keyword and regular arguments + for on_failure in this format (args, kwargs) :return: None """ @@ -151,6 +156,8 @@ async def connect( delay=delay, on_failure=on_failure, on_success=on_success, + on_success_args_kwargs=on_success_args_kwargs, + on_failure_args_kwargs=on_failure_args_kwargs, ) await self.heartbeat.start_beat() except ConnectionRefusedError: diff --git a/rustplus/api/remote/rust_remote_interface.py b/rustplus/api/remote/rust_remote_interface.py index 80e3829..0df46ba 100644 --- a/rustplus/api/remote/rust_remote_interface.py +++ b/rustplus/api/remote/rust_remote_interface.py @@ -66,7 +66,15 @@ def __init__( self.pending_entity_subscriptions = [] self.camera_manager: Union[CameraManager, None] = None - async def connect(self, retries, delay, on_failure=None, on_success=None) -> None: + async def connect( + self, + retries, + delay, + on_failure, + on_success, + on_success_args_kwargs, + on_failure_args_kwargs, + ) -> None: self.ws = RustWebsocket( server_id=self.server_id, remote=self, @@ -76,6 +84,8 @@ async def connect(self, retries, delay, on_failure=None, on_success=None) -> Non on_failure=on_failure, on_success=on_success, delay=delay, + on_success_args_kwargs=on_success_args_kwargs, + on_failure_args_kwargs=on_failure_args_kwargs, ) await self.ws.connect(retries=retries) diff --git a/rustplus/api/remote/rustws.py b/rustplus/api/remote/rustws.py index 10e4bbf..703c967 100644 --- a/rustplus/api/remote/rustws.py +++ b/rustplus/api/remote/rustws.py @@ -34,6 +34,8 @@ def __init__( on_failure, on_success, delay, + on_success_args_kwargs, + on_failure_args_kwargs, ): self.connection: Union[WebSocketClientProtocol, None] = None self.task: Union[Task, None] = None @@ -49,6 +51,8 @@ def __init__( self.on_failure = on_failure self.on_success = on_success self.delay = delay + self.on_success_args_kwargs = on_success_args_kwargs + self.on_failure_args_kwargs = on_failure_args_kwargs async def connect( self, retries=float("inf"), ignore_open_value: bool = False @@ -90,9 +94,15 @@ async def connect( if self.on_success is not None: try: if asyncio.iscoroutinefunction(self.on_success): - await self.on_success() + await self.on_success( + *self.on_success_args_kwargs[0], + **self.on_success_args_kwargs[1], + ) else: - self.on_success() + self.on_success( + *self.on_success_args_kwargs[0], + **self.on_success_args_kwargs[1], + ) except Exception as e: self.logger.warning(e) break @@ -105,9 +115,15 @@ async def connect( if self.on_failure is not None: try: if asyncio.iscoroutinefunction(self.on_failure): - val = await self.on_failure() + val = await self.on_failure( + *self.on_failure_args_kwargs[0], + **self.on_failure_args_kwargs[1], + ) else: - val = self.on_failure() + val = self.on_failure( + *self.on_failure_args_kwargs[0], + **self.on_failure_args_kwargs[1], + ) if val is not None: print_error = val