From 3dbfeb9d9cff42fb2b9111a56770bf3daa0528a3 Mon Sep 17 00:00:00 2001 From: Abigail Emery Date: Wed, 3 Aug 2022 08:42:13 +0100 Subject: [PATCH] Add cast of message type in command interpreter (#75) * #74 Add cast of message type in command interpreter Co-authored-by: Garry O'Donnell * Add type hint patch to command interpreter tests Co-authored-by: Callum Forrester <29771545+callumforrester@users.noreply.github.com> Co-authored-by: Garry O'Donnell Co-authored-by: Callum Forrester <29771545+callumforrester@users.noreply.github.com> --- .../command/test_command_interpreter.py | 22 ++++++++++++++++++- .../command/command_interpreter.py | 6 ++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/adapters/interpreters/command/test_command_interpreter.py b/tests/adapters/interpreters/command/test_command_interpreter.py index d6f6e49bb..f622938ce 100644 --- a/tests/adapters/interpreters/command/test_command_interpreter.py +++ b/tests/adapters/interpreters/command/test_command_interpreter.py @@ -2,12 +2,16 @@ from typing import Callable, Optional, Sequence import pytest -from mock import AsyncMock, MagicMock +from mock import AsyncMock, MagicMock, patch from tickit.adapters.interpreters.command import CommandInterpreter from tickit.adapters.interpreters.command.command_interpreter import Command from tickit.core.adapter import Adapter +_GET_TYPE_HINTS = ( + "tickit.adapters.interpreters.command.command_interpreter.get_type_hints" +) + @pytest.fixture def command_interpreter(): @@ -33,6 +37,10 @@ def parse(self, data: bytes) -> Optional[Sequence[str]]: return TestCommand +@patch( + _GET_TYPE_HINTS, + lambda _: {"arg1": str, "arg2": str}, +) @pytest.mark.asyncio async def test_command_interpreter_handle_calls_func_with_args( command_interpreter: CommandInterpreter, @@ -51,6 +59,10 @@ async def test_command_interpreter_handle_calls_func_with_args( test_adapter.test_method.assert_awaited_once_with("arg1", "arg2") +@patch( + _GET_TYPE_HINTS, + lambda _: {"arg1": str, "arg2": str}, +) @pytest.mark.asyncio async def test_command_interpreter_handle_returns_iterable_reply( command_interpreter: CommandInterpreter, @@ -70,6 +82,10 @@ async def test_command_interpreter_handle_returns_iterable_reply( assert reply == (await command_interpreter.handle(test_adapter, b"\x01"))[0] +@patch( + _GET_TYPE_HINTS, + lambda _: {"arg1": str, "arg2": str}, +) @pytest.mark.asyncio async def test_command_interpreter_handle_wraps_non_iterable_reply( command_interpreter: CommandInterpreter, @@ -94,6 +110,10 @@ async def test_command_interpreter_handle_wraps_non_iterable_reply( ) +@patch( + _GET_TYPE_HINTS, + lambda _: {"arg1": str, "arg2": str}, +) @pytest.mark.asyncio @pytest.mark.parametrize("interrupt", [True, False]) async def test_command_interpreter_handle_returns_interupt( diff --git a/tickit/adapters/interpreters/command/command_interpreter.py b/tickit/adapters/interpreters/command/command_interpreter.py index f7548f951..935d9d749 100644 --- a/tickit/adapters/interpreters/command/command_interpreter.py +++ b/tickit/adapters/interpreters/command/command_interpreter.py @@ -1,6 +1,6 @@ from abc import abstractmethod from inspect import getmembers -from typing import AnyStr, AsyncIterable, Optional, Sequence, Tuple +from typing import AnyStr, AsyncIterable, Optional, Sequence, Tuple, get_type_hints from tickit.core.adapter import Adapter, Interpreter from tickit.utils.compat.typing_compat import Protocol, runtime_checkable @@ -90,6 +90,10 @@ async def handle( args = command.parse(message) if args is None: continue + args = ( + argtype(arg) + for arg, argtype in zip(args, get_type_hints(method).values()) + ) resp = await method(*args) if not isinstance(resp, AsyncIterable): resp = CommandInterpreter._wrap(resp)