Skip to content

Commit

Permalink
Type Hinting and other work
Browse files Browse the repository at this point in the history
  • Loading branch information
olijeffers0n committed Apr 14, 2023
1 parent 0aba780 commit 22cca9a
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 94 deletions.
63 changes: 26 additions & 37 deletions rustplus/api/remote/camera/camera_manager.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,40 @@
import time
import traceback
from typing import Iterable, Union, List, Coroutine
from typing import Iterable, Union, List, Coroutine, TypeVar, Set

from PIL import Image

from .camera_parser import Parser
from ..events import EventLoopManager, EventHandler
from ..rustplus_proto import AppCameraInput, Vector2, AppEmpty
from ...structures import Vector
from .structures import CameraInfo, LimitedQueue, Entity
from .structures import CameraInfo, Entity, RayPacket

RS = TypeVar("RS", bound="RustSocket")


class CameraManager:
def __init__(self, rust_socket, cam_id, cam_info_message) -> None:
self.rust_socket = rust_socket
self._cam_id = cam_id
self._last_packets: LimitedQueue = LimitedQueue(6)
def __init__(self, rust_socket: RS, cam_id: str, cam_info_message: CameraInfo) -> None:
self.rust_socket: RS = rust_socket
self._cam_id: str = cam_id
self._last_packet: RayPacket = None
self._cam_info_message: CameraInfo = CameraInfo(cam_info_message)
self._open = True
self.parser = Parser(
self._open: bool = True
self.parser: Parser = Parser(
self._cam_info_message.width, self._cam_info_message.height
)
self.time_since_last_subscribe = time.time()
self.frame_callbacks = set()
self.time_since_last_subscribe: float = time.time()
self.frame_callbacks: Set[Coroutine] = set()

def add_packet(self, packet) -> None:
self._last_packets.add(packet)
def add_packet(self, packet: RayPacket) -> None:
self._last_packet = packet

self.parser.handle_camera_ray_data(packet)
self.parser.step()

if len(self.frame_callbacks) == 0:
return

try:
frame = self._create_frame()
except Exception:
traceback.print_exc()
return
frame = self._create_frame()

for callback in self.frame_callbacks:
EventHandler.schedule_event(
Expand All @@ -46,29 +43,27 @@ def add_packet(self, packet) -> None:
frame,
)

def on_frame_received(self, coro) -> Coroutine:
def on_frame_received(self, coro: Coroutine) -> Coroutine:
self.frame_callbacks.add(coro)
return coro

def has_frame_data(self) -> bool:
return len(self._last_packets) > 0
return self._last_packet is not None

def _create_frame(self, render_entities: bool = True, entity_render_distance: float = float("inf"), max_entity_amount: int = float("inf")) -> Union[Image.Image, None]:
if self._last_packets is None:
if self._last_packet is None:
return None

if not self._open:
raise Exception("Camera is closed")

last_packet = self._last_packets.get_last()

return self.parser.render(
render_entities,
last_packet.entities,
last_packet.vertical_fov,
self._last_packet.entities,
self._last_packet.vertical_fov,
self._cam_info_message.far_plane,
entity_render_distance,
max_entity_amount if max_entity_amount is not None else len(last_packet.entities),
max_entity_amount if max_entity_amount is not None else len(self._last_packet.entities),
)

async def get_frame(self, render_entities: bool = True, entity_render_distance: float = float("inf"), max_entity_amount: int = float("inf")) -> Union[Image.Image, None]:
Expand Down Expand Up @@ -123,7 +118,7 @@ async def exit_camera(self) -> None:
self.rust_socket.remote.ignored_responses.append(app_request.seq)

self._open = False
self._last_packets.clear()
self._last_packet = None

async def resubscribe(self) -> None:
await self.rust_socket.remote.subscribe_to_camera(self._cam_id, True)
Expand All @@ -132,22 +127,16 @@ async def resubscribe(self) -> None:
self.rust_socket.remote.camera_manager = self

async def get_entities_in_frame(self) -> List[Entity]:
if self._last_packets is None:
if self._last_packet is None:
return []

if len(self._last_packets) == 0:
return []

return self._last_packets.get_last().entities
return self._last_packet.entities

async def get_distance_from_player(self) -> float:
if self._last_packets is None:
return float("inf")

if len(self._last_packets) == 0:
if self._last_packet is None:
return float("inf")

return self._last_packets.get_last().distance
return self._last_packet.distance

async def get_max_distance(self) -> float:
return self._cam_info_message.far_plane
36 changes: 18 additions & 18 deletions rustplus/api/remote/camera/camera_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from importlib import resources
from math import radians, tan
import random
from typing import Union, Tuple, List, Any
from typing import Union, Tuple, List, Any, Dict
import numpy as np
from scipy.spatial import ConvexHull
from PIL import Image, ImageDraw, ImageFont

from .camera_constants import LOOKUP_CONSTANTS
from .structures import Entity
from .structures import Entity, Vector3

SCIENTIST_COLOUR = "#3098f2"
PLAYER_COLOUR = "#fa2828"
Expand All @@ -17,7 +17,7 @@


class Parser:
def __init__(self, width, height) -> None:
def __init__(self, width: int, height: int) -> None:
self.width = width
self.height = height
self.data_pointer = 0
Expand Down Expand Up @@ -306,7 +306,7 @@ def handle_entity(
cam_fov,
aspect_ratio,
text,
):
) -> None:

entity.size.x = min(entity.size.x, 5)
entity.size.y = min(entity.size.y, 5)
Expand Down Expand Up @@ -461,24 +461,24 @@ def render(


class MathUtils:
VERTEX_CACHE = {}
COLOUR_CACHE = {}
VERTEX_CACHE: Dict[Vector3, np.ndarray] = {}
COLOUR_CACHE: Dict[Tuple[float, float, float, float], Tuple[int, int, int]] = {}

@staticmethod
def camera_matrix(position, rotation):
def camera_matrix(position, rotation) -> np.ndarray:
matrix = np.matmul(
MathUtils.rotation_matrix(rotation), MathUtils.translation_matrix(-position)
)
return np.linalg.inv(matrix)

@staticmethod
def scale_matrix(size):
def scale_matrix(size) -> np.ndarray:
return np.array(
[[size[0], 0, 0, 0], [0, size[1], 0, 0], [0, 0, size[2], 0], [0, 0, 0, 1]]
)

@staticmethod
def rotation_matrix(rotation):
def rotation_matrix(rotation) -> np.ndarray:
pitch = rotation[0]
yaw = rotation[1]
roll = rotation[2]
Expand Down Expand Up @@ -513,7 +513,7 @@ def rotation_matrix(rotation):
return np.matmul(np.matmul(rotation_x, rotation_y), rotation_z)

@staticmethod
def translation_matrix(position):
def translation_matrix(position) -> np.ndarray:
return np.array(
[
[1, 0, 0, position[0]],
Expand All @@ -524,7 +524,7 @@ def translation_matrix(position):
)

@staticmethod
def perspective_matrix(fov, aspect_ratio, near, far):
def perspective_matrix(fov, aspect_ratio, near, far) -> np.ndarray:
f = 1 / tan(radians(fov) / 2)
return np.array(
[
Expand All @@ -536,7 +536,7 @@ def perspective_matrix(fov, aspect_ratio, near, far):
)

@staticmethod
def gift_wrap_algorithm(vertices):
def gift_wrap_algorithm(vertices) -> List[Tuple[float, float]]:
data = np.array(vertices)

# Check that the min and max are not the same
Expand Down Expand Up @@ -588,7 +588,7 @@ def solve_quadratic(a: float, b: float, c: float, larger: bool) -> float:
return (-b - math.sqrt(discriminant)) / (2 * a)

@classmethod
def get_tree_vertices(cls, size):
def get_tree_vertices(cls, size) -> np.ndarray:

if size in cls.VERTEX_CACHE:
return cls.VERTEX_CACHE[size]
Expand All @@ -615,7 +615,7 @@ def get_tree_vertices(cls, size):
return vertices

@classmethod
def get_player_vertices(cls, size):
def get_player_vertices(cls, size) -> np.ndarray:

if size in cls.VERTEX_CACHE:
return cls.VERTEX_CACHE[size]
Expand Down Expand Up @@ -687,7 +687,7 @@ def get_slightly_random_colour(colour: str, entity_id: int) -> str:
return f"#{r}{g}{b}"

@staticmethod
def get_font_size(distance, font_size, near, far, aspect_ratio, fov):
def get_font_size(distance, font_size, near, far, aspect_ratio, fov) -> int:
"""
Given the distance from the screen, uses perspective projection matrix to return a font size for some text.
Expand Down Expand Up @@ -737,7 +737,7 @@ def get_font_size(distance, font_size, near, far, aspect_ratio, fov):
@staticmethod
def set_polygon_with_depth(
vertices, image_data, depth_data, depth, colour, width, height, far_plane
):
) -> None:
if len(vertices) <= 1:
return

Expand All @@ -764,7 +764,7 @@ def set_polygon_with_depth(
image_data[pixels[:, 0], pixels[:, 1]] = colour

@staticmethod
def convert_colour_to_tuple(colour):
def convert_colour_to_tuple(colour) -> Tuple[int, int, int]:
"""
Converts a colour in the form
#RRGGBB
Expand All @@ -774,7 +774,7 @@ def convert_colour_to_tuple(colour):
return int(colour[1:3], 16), int(colour[3:5], 16), int(colour[5:7], 16)

@staticmethod
def get_vertices_in_polygon(vertices, width, height):
def get_vertices_in_polygon(vertices, width, height) -> np.ndarray:
"""
Takes a list of vertices and returns the vertices that are inside the polygon defined by the vertices
vertices is a list of vertices in the form [x, y]
Expand Down
29 changes: 0 additions & 29 deletions rustplus/api/remote/camera/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,32 +70,3 @@ def __str__(self) -> str:
f"RayPacket(vertical_fov={self.vertical_fov}, sample_offset={self.sample_offset}, "
f"ray_data={self.ray_data}, distance={self.distance}, entities={self.entities})"
)


class LimitedQueue:
def __init__(self, length) -> None:
self._length = length
self._queue = []

def add(self, item) -> None:
self._queue.append(item)
if len(self._queue) > self._length:
self._queue.pop(0)

def get(self, index=0) -> Any:
if index >= len(self._queue) or index < 0:
return None

return self._queue[index]

def get_last(self) -> Any:
return self._queue[-1]

def pop(self) -> Any:
return self._queue.pop(0)

def clear(self) -> None:
self._queue.clear()

def __len__(self) -> int:
return len(self._queue)
14 changes: 8 additions & 6 deletions rustplus/api/remote/events/event_handler.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
import asyncio
import logging
from asyncio.futures import Future
from typing import Set
from typing import Set, Coroutine, Any

from ....utils import ServerID
from .events import EntityEvent, TeamEvent, ChatEvent, ProtobufEvent
from .registered_listener import RegisteredListener
from .event_loop_manager import EventLoopManager
from ..rustplus_proto import AppMessage


class EventHandler:
@staticmethod
def schedule_event(loop, coro, arg) -> None:
def schedule_event(loop: asyncio.AbstractEventLoop, coro: Coroutine, arg: Any) -> None:
def callback(inner_future: Future):
if inner_future.exception() is not None:
logging.getLogger("rustplus.py").exception(inner_future.exception())

future: Future = asyncio.run_coroutine_threadsafe(coro(arg), loop)
future.add_done_callback(callback)

def run_entity_event(self, name, app_message, server_id) -> None:
def run_entity_event(self, name: str, app_message: AppMessage, server_id: ServerID) -> None:

handlers: Set[RegisteredListener] = EntityEvent.handlers.get_handlers(
server_id
Expand All @@ -36,7 +38,7 @@ def run_entity_event(self, name, app_message, server_id) -> None:
EntityEvent(app_message, event_type),
)

def run_team_event(self, app_message, server_id) -> None:
def run_team_event(self, app_message: AppMessage, server_id: ServerID) -> None:

handlers: Set[RegisteredListener] = TeamEvent.handlers.get_handlers(server_id)
for handler in handlers.copy():
Expand All @@ -46,7 +48,7 @@ def run_team_event(self, app_message, server_id) -> None:
EventLoopManager.get_loop(server_id), coro, TeamEvent(app_message)
)

def run_chat_event(self, app_message, server_id) -> None:
def run_chat_event(self, app_message: AppMessage, server_id: ServerID) -> None:

handlers: Set[RegisteredListener] = ChatEvent.handlers.get_handlers(server_id)
for handler in handlers.copy():
Expand All @@ -56,7 +58,7 @@ def run_chat_event(self, app_message, server_id) -> None:
EventLoopManager.get_loop(server_id), coro, ChatEvent(app_message)
)

def run_proto_event(self, byte_data: bytes, server_id) -> None:
def run_proto_event(self, byte_data: bytes, server_id: ServerID) -> None:

handlers: Set[RegisteredListener] = ProtobufEvent.handlers.get_handlers(
server_id
Expand Down
4 changes: 3 additions & 1 deletion rustplus/api/remote/events/event_loop_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
from typing import Dict

from ....utils import ServerID


class EventLoopManager:

_loop = {}
_loop: Dict[ServerID, asyncio.AbstractEventLoop] = {}

@staticmethod
def get_loop(server_id: ServerID) -> asyncio.AbstractEventLoop:
Expand Down
Loading

0 comments on commit 22cca9a

Please sign in to comment.