From 73381844070149f54ed53fc75fe7ae940e4ed569 Mon Sep 17 00:00:00 2001 From: wouter Date: Tue, 7 Jan 2025 16:14:06 +0100 Subject: [PATCH] Added support for python types in plugin annotations (PR #8559) Basic native python types For the documentation, I would add a separate ticket Strike through any lines that are not applicable (`~~line~~`) then check the box - [x] Attached issue to pull request - [x] Changelog entry - [x] Type annotations are present - [x] Code is clear and sufficiently documented - [x] No (preventable) type errors (check using make mypy or make mypy-diff) - [x] Sufficient test cases (reproduces the bug/tests the requested feature) - [x] Correct, in line with design - [ ] End user documentation is included or an issue is created for end-user documentation (add ref to issue here: ) - [ ] If this PR fixes a race condition in the test suite, also push the fix to the relevant stable branche(s) (see [test-fixes](https://internal.inmanta.com/development/core/tasks/build-master.html#test-fixes) for more info) --- changelogs/unreleased/native_plugin_types.yml | 5 + mypy-baseline.txt | 13 +- src/inmanta/ast/type.py | 34 +++- src/inmanta/loader.py | 10 ++ src/inmanta/plugins.py | 146 ++++++++++++++++-- tests/compiler/test_plugin_types.py | 66 ++++++++ tests/compiler/test_plugins.py | 22 +++ .../plugin_native_types/model/_init.cf | 0 .../modules/plugin_native_types/module.yml | 3 + .../plugin_native_types/plugins/__init__.py | 34 ++++ 10 files changed, 313 insertions(+), 20 deletions(-) create mode 100644 changelogs/unreleased/native_plugin_types.yml create mode 100644 tests/compiler/test_plugin_types.py create mode 100644 tests/data/modules/plugin_native_types/model/_init.cf create mode 100644 tests/data/modules/plugin_native_types/module.yml create mode 100644 tests/data/modules/plugin_native_types/plugins/__init__.py diff --git a/changelogs/unreleased/native_plugin_types.yml b/changelogs/unreleased/native_plugin_types.yml new file mode 100644 index 0000000000..4665ea777b --- /dev/null +++ b/changelogs/unreleased/native_plugin_types.yml @@ -0,0 +1,5 @@ +description: Added support for python types in plugin annotations +change-type: minor +destination-branches: [master, iso7] +sections: + feature: "{{description}}" diff --git a/mypy-baseline.txt b/mypy-baseline.txt index b4f21d2538..9a7b372afe 100644 --- a/mypy-baseline.txt +++ b/mypy-baseline.txt @@ -124,17 +124,11 @@ src/inmanta/ast/__init__.py:0: error: Function is missing a return type annotati src/inmanta/ast/type.py:0: error: List item 0 has incompatible type "type[float]"; expected "Callable[[object], Number]" [list-item] src/inmanta/ast/type.py:0: error: List item 0 has incompatible type "type[float]"; expected "Callable[[object], object]" [list-item] src/inmanta/ast/type.py:0: error: Incompatible types in assignment (expression has type "Sequence[Callable[[object | None], object]]", base class "Number" defined the type as "Sequence[Callable[[object], Number]]") [assignment] -src/inmanta/ast/type.py:0: error: Function is missing a return type annotation [no-untyped-def] -src/inmanta/ast/type.py:0: note: Use "-> None" if function does not return a value -src/inmanta/ast/type.py:0: error: Call to untyped function "__init__" in typed context [no-untyped-call] src/inmanta/ast/type.py:0: error: Return type "str | None" of "type_string" incompatible with return type "str" in supertype "List" [override] -src/inmanta/ast/type.py:0: error: Function is missing a type annotation for one or more arguments [no-untyped-def] -src/inmanta/ast/type.py:0: error: Function is missing a return type annotation [no-untyped-def] src/inmanta/ast/type.py:0: error: Function is missing a return type annotation [no-untyped-def] src/inmanta/ast/type.py:0: error: Function is missing a type annotation [no-untyped-def] src/inmanta/ast/type.py:0: error: Item "Locatable" of "Locatable | None" has no attribute "name" [union-attr] src/inmanta/ast/type.py:0: error: Item "None" of "Locatable | None" has no attribute "name" [union-attr] -src/inmanta/ast/type.py:0: error: Call to untyped function "List" in typed context [no-untyped-call] src/inmanta/ast/type.py:0: error: "List" expects no type arguments, but 1 given [type-arg] src/inmanta/ast/type.py:0: error: Incompatible types in assignment (expression has type "list[Never]", variable has type "inmanta.ast.type.List") [assignment] src/inmanta/ast/type.py:0: error: "inmanta.ast.type.List" has no attribute "append" [attr-defined] @@ -460,6 +454,7 @@ src/inmanta/protocol/endpoints.py:0: error: Return type "Coroutine[Any, Any, Bas src/inmanta/parser/cache.py:0: error: Argument 1 to "ASTUnpickler" has incompatible type "BufferedReader"; expected "BytesIO" [arg-type] src/inmanta/loader.py:0: error: Missing type parameters for generic type Module [type-arg] src/inmanta/loader.py:0: error: Argument 1 to "getsourcefile" has incompatible type "object"; expected Module | type[Any] | MethodType | FunctionType | TracebackType | FrameType | CodeType | Callable[..., Any] [arg-type] +src/inmanta/loader.py:0: error: Function is missing a type annotation [no-untyped-def] src/inmanta/loader.py:0: error: Return type "bytes" of "get_source" incompatible with return type "str | None" in supertype "InspectLoader" [override] src/inmanta/loader.py:0: error: No return value expected [return-value] src/inmanta/loader.py:0: error: Incompatible types in assignment (expression has type "None", variable has type "PluginModuleFinder") [assignment] @@ -756,7 +751,13 @@ src/inmanta/protocol/rest/client.py:0: error: Argument "result" to "Result" has src/inmanta/plugins.py:0: error: Missing type parameters for generic type "ResultVariable" [type-arg] src/inmanta/plugins.py:0: error: Argument 1 to "add_function" of "PluginMeta" has incompatible type "PluginMeta"; expected "type[Plugin]" [arg-type] src/inmanta/plugins.py:0: error: "type[Plugin]" has no attribute "__fq_plugin_name__" [attr-defined] +src/inmanta/plugins.py:0: error: Non-overlapping identity check (left operand type: "type[object]", right operand type: "") [comparison-overlap] +src/inmanta/plugins.py:0: error: Argument 1 to "issubclass" has incompatible type "Any | None"; expected "type" [arg-type] +src/inmanta/plugins.py:0: error: Argument 1 to "issubclass" has incompatible type "Any | None"; expected "type" [arg-type] +src/inmanta/plugins.py:0: error: Incompatible types in assignment (expression has type "tuple[Any, ...]", variable has type "list[type[object]]") [assignment] +src/inmanta/plugins.py:0: error: Argument 1 to "issubclass" has incompatible type "Any | None"; expected "type" [arg-type] src/inmanta/plugins.py:0: error: Invalid index type "object" for "dict[str | None, Type]"; expected type "str | None" [index] +src/inmanta/plugins.py:0: error: Argument 1 to "to_dsl_type" has incompatible type "object"; expected "type[object]" [arg-type] src/inmanta/plugins.py:0: error: Incompatible return value type (got "Type | None", expected "Type") [return-value] src/inmanta/plugins.py:0: error: "type[Plugin]" has no attribute "__function__" [attr-defined] src/inmanta/plugins.py:0: error: "type[Plugin]" has no attribute "__function__" [attr-defined] diff --git a/src/inmanta/ast/type.py b/src/inmanta/ast/type.py index 1659f71349..30b6d06abf 100644 --- a/src/inmanta/ast/type.py +++ b/src/inmanta/ast/type.py @@ -96,6 +96,15 @@ def with_base_type(self, base_type: "Type") -> "Type": """ return base_type + def __eq__(self, other: object) -> bool: + if type(self) != Type: # noqa: E721 + # Not for children + return NotImplemented + return type(self) == type(other) # noqa: E721 + + def __hash__(self) -> int: + return hash(type(self)) + class NamedType(Type, Named): def get_double_defined_exception(self, other: "NamedType") -> "DuplicateException": @@ -105,6 +114,14 @@ def get_double_defined_exception(self, other: "NamedType") -> "DuplicateExceptio def type_string(self) -> str: return self.get_full_name() + def __eq__(self, other: object) -> bool: + if not isinstance(other, NamedType): + return False + return self.get_full_name() == other.get_full_name() + + def __hash__(self) -> int: + return hash(self.get_full_name()) + @stable_api class NullableType(Type): @@ -370,6 +387,9 @@ def is_primitive(self) -> bool: def get_location(self) -> None: return None + def __eq__(self, other: object) -> bool: + return type(self) == type(other) # noqa: E721 + @stable_api class List(Type): @@ -378,7 +398,7 @@ class List(Type): This class refers to the list type used in plugin annotations. For the list type in the Inmanta DSL, see `LiteralList`. """ - def __init__(self): + def __init__(self) -> None: Type.__init__(self) def validate(self, value: Optional[object]) -> bool: @@ -403,6 +423,9 @@ def type_string_internal(self) -> str: def get_location(self) -> None: return None + def __eq__(self, other: object) -> bool: + return type(self) == type(other) # noqa: E721 + @stable_api class TypedList(List): @@ -544,6 +567,11 @@ def type_string_internal(self) -> str: def get_location(self) -> None: return None + def __eq__(self, other: object) -> bool: + if not isinstance(other, TypedDict): + return NotImplemented + return self.element_type == other.element_type + @stable_api class LiteralDict(TypedDict): @@ -622,7 +650,7 @@ def normalize(self) -> None: assert self.expression is not None self.expression.normalize() - def set_constraint(self, expression) -> None: + def set_constraint(self, expression: "ExpressionStatement") -> None: """ Set the constraint for this type. This baseclass for constraint types requires the constraint to be set as a regex that can be @@ -631,7 +659,7 @@ def set_constraint(self, expression) -> None: self.expression = expression self._constraint = create_function(self, expression) - def get_constraint(self): + def get_constraint(self) -> "ExpressionStatement | None": """ Get the string representation of the constraint """ diff --git a/src/inmanta/loader.py b/src/inmanta/loader.py index 708843c4b6..523ea602b1 100644 --- a/src/inmanta/loader.py +++ b/src/inmanta/loader.py @@ -214,6 +214,16 @@ class ModuleSource: source: Optional[bytes] = None _client: Optional["protocol.SyncClient"] = None + def __lt__(self, other): + if not isinstance(other, ModuleSource): + return NotImplemented + return (self.name, self.hash_value, self.is_byte_code) < (other.name, other.hash_value, other.is_byte_code) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ModuleSource): + return False + return (self.name, self.hash_value, self.is_byte_code) == (other.name, other.hash_value, other.is_byte_code) + def get_source_code(self) -> bytes: """Load the source code""" if self.source is not None: diff --git a/src/inmanta/plugins.py b/src/inmanta/plugins.py index 58d986a887..47560067e0 100644 --- a/src/inmanta/plugins.py +++ b/src/inmanta/plugins.py @@ -19,15 +19,30 @@ import asyncio import collections.abc import inspect +import numbers import os import subprocess +import typing import warnings from collections import abc -from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, Optional, Sequence, Type, TypeVar +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Type, TypeVar + +import typing_inspect import inmanta.ast.type as inmanta_type from inmanta import const, protocol, util -from inmanta.ast import LocatableString, Location, Namespace, Range, RuntimeException, TypeNotFoundException, WithComment +from inmanta.ast import ( # noqa: F401 Plugin exception is part of the stable api + LocatableString, + Location, + Namespace, + PluginException, + Range, + RuntimeException, + TypeNotFoundException, + TypingException, + WithComment, +) from inmanta.ast.type import NamedType from inmanta.config import Config from inmanta.execute.proxy import DynamicProxy @@ -201,6 +216,9 @@ def type_string(self) -> str: def type_string_internal(self) -> str: return self.type_string() + def __eq__(self, other: object) -> bool: + return type(self) == type(other) # noqa: E721 + # Define some types which are used in the context of plugins. PLUGIN_TYPES = { @@ -210,6 +228,109 @@ def type_string_internal(self) -> str: None: Null(), # Only NoneValue will pass validation } +python_to_model = { + str: inmanta_type.String(), + float: inmanta_type.Float(), + numbers.Number: inmanta_type.Number(), + int: inmanta_type.Integer(), + bool: inmanta_type.Bool(), + dict: inmanta_type.TypedDict(inmanta_type.Type()), + typing.Mapping: inmanta_type.TypedDict(inmanta_type.Type()), + Mapping: inmanta_type.TypedDict(inmanta_type.Type()), + list: inmanta_type.List(), + typing.Sequence: inmanta_type.List(), + Sequence: inmanta_type.List(), + object: inmanta_type.Type(), +} + + +def to_dsl_type(python_type: type[object]) -> inmanta_type.Type: + """ + Convert a python type annotation to an Inmanta DSL type annotation. + + :param python_type: The evaluated python type as provided in the Python type annotation. + """ + # Any to any + if python_type is typing.Any: + return inmanta_type.Type() + + # None to None + if python_type is type(None) or python_type is None: + return Null() + + # Unions and optionals + if typing_inspect.is_union_type(python_type): + # Optional type + if typing_inspect.is_optional_type(python_type): + other_types = [tt for tt in typing.get_args(python_type) if not typing_inspect.is_optional_type(tt)] + if len(other_types) == 0: + # Probably not possible + return Null() + if len(other_types) == 1: + return inmanta_type.NullableType(to_dsl_type(other_types[0])) + # TODO: optional unions + return inmanta_type.Type() + else: + # TODO: unions + return inmanta_type.Type() + # bases: Sequence[inmanta.ast.type.Type] = [to_dsl_type(arg) for arg in typing.get_args(python_type)] + # return inmanta.ast.type.Union(bases) + + # Lists and dicts + if typing_inspect.is_generic_type(python_type): + origin = typing.get_origin(python_type) + + # dict + if issubclass(origin, Mapping): + if origin in [collections.abc.Mapping, dict, typing.Mapping]: + args = typing_inspect.get_args(python_type) + if not args: + return inmanta_type.TypedDict(inmanta_type.Type()) + + if not issubclass(args[0], str): + raise TypingException( + None, f"invalid type {python_type}, the keys of any dict should be 'str', got {args[0]} instead" + ) + + if len(args) == 1: + return inmanta_type.TypedDict(inmanta_type.Type()) + + return inmanta_type.TypedDict(to_dsl_type(args[1])) + else: + raise TypingException(None, f"invalid type {python_type}, dictionary types should be Mapping or dict") + + # List + if issubclass(origin, Sequence): + if origin in [collections.abc.Sequence, list, typing.Sequence]: + args = typing.get_args(python_type) + if not args: + return inmanta_type.List() + return inmanta_type.TypedList(to_dsl_type(args[0])) + else: + raise TypingException(None, f"invalid type {python_type}, list types should be Sequence or list") + + # Set + if issubclass(origin, collections.abc.Set): + raise TypingException(None, f"invalid type {python_type}, set is not supported on the plugin boundary") + + # TODO annotated types + # if typing.get_origin(t) is typing.Annotated: + # args: Sequence[object] = typing.get_args(python_type) + # inmanta_types: Sequence[plugin_typing.InmantaType] = + # [arg if isinstance(arg, plugin_typing.InmantaType) for arg in args] + # if inmanta_types: + # if len(inmanta_types) > 1: + # # TODO + # raise Exception() + # # TODO + # return parse_dsl_type(inmanta_types[0].dsl_type) + # # the annotation doesn't concern us => use base type + # return to_dsl_type(args[0]) + if python_type in python_to_model: + return python_to_model[python_type] + + return inmanta_type.Type() + class PluginValue: """ @@ -258,15 +379,18 @@ def resolve_type(self, plugin: "Plugin", resolver: Namespace) -> inmanta_type.Ty return self._resolved_type if not isinstance(self.type_expression, str): - raise RuntimeException( - stmt=None, - msg="Bad annotation in plugin %s for %s, expected str but got %s (%s)" - % (plugin.get_full_name(), self.VALUE_NAME, type(self.type_expression).__name__, self.type_expression), - ) - - plugin_line: Range = Range(plugin.location.file, plugin.location.lnr, 1, plugin.location.lnr + 1, 1) - locatable_type: LocatableString = LocatableString(self.type_expression, plugin_line, 0, resolver) - self._resolved_type = inmanta_type.resolve_type(locatable_type, resolver) + if isinstance(self.type_expression, type) or typing.get_origin(self.type_expression) is not None: + self._resolved_type = to_dsl_type(self.type_expression) + else: + raise RuntimeException( + stmt=None, + msg="Bad annotation in plugin %s for %s, expected str or python type but got %s (%s)" + % (plugin.get_full_name(), self.VALUE_NAME, type(self.type_expression).__name__, self.type_expression), + ) + else: + plugin_line: Range = Range(plugin.location.file, plugin.location.lnr, 1, plugin.location.lnr + 1, 1) + locatable_type: LocatableString = LocatableString(self.type_expression, plugin_line, 0, resolver) + self._resolved_type = inmanta_type.resolve_type(locatable_type, resolver) return self._resolved_type def validate(self, value: object) -> bool: diff --git a/tests/compiler/test_plugin_types.py b/tests/compiler/test_plugin_types.py new file mode 100644 index 0000000000..8f12a1da5d --- /dev/null +++ b/tests/compiler/test_plugin_types.py @@ -0,0 +1,66 @@ +""" + Copyright 2025 Inmanta + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Contact: code@inmanta.com +""" + +import collections.abc +from typing import Any, Mapping, Sequence, Union + +import pytest + +import inmanta.ast.type as inmanta_type +from inmanta.ast import RuntimeException +from inmanta.plugins import Null, to_dsl_type + + +def test_conversion(): + assert inmanta_type.Integer() == to_dsl_type(int) + assert inmanta_type.Float() == to_dsl_type(float) + assert inmanta_type.NullableType(inmanta_type.Float()) == to_dsl_type(float | None) + assert inmanta_type.List() == to_dsl_type(list) + assert inmanta_type.TypedList(inmanta_type.String()) == to_dsl_type(list[str]) + assert inmanta_type.TypedList(inmanta_type.String()) == to_dsl_type(Sequence[str]) + assert inmanta_type.List() == to_dsl_type(Sequence) + assert inmanta_type.List() == to_dsl_type(collections.abc.Sequence) + assert inmanta_type.TypedList(inmanta_type.String()) == to_dsl_type(collections.abc.Sequence[str]) + assert inmanta_type.TypedDict(inmanta_type.Type()) == to_dsl_type(dict) + assert inmanta_type.TypedDict(inmanta_type.Type()) == to_dsl_type(Mapping) + assert inmanta_type.TypedDict(inmanta_type.String()) == to_dsl_type(dict[str, str]) + assert inmanta_type.TypedDict(inmanta_type.String()) == to_dsl_type(Mapping[str, str]) + + assert inmanta_type.TypedDict(inmanta_type.String()) == to_dsl_type(collections.abc.Mapping[str, str]) + + assert Null() == to_dsl_type(Union[None]) + + assert isinstance(to_dsl_type(Any), inmanta_type.Type) + + with pytest.raises(RuntimeException): + to_dsl_type(dict[int, int]) + + with pytest.raises(RuntimeException): + to_dsl_type(set[str]) + + class CustomList[T](list[T]): + pass + + class CustomDict[K, V](Mapping[K, V]): + pass + + with pytest.raises(RuntimeException): + to_dsl_type(CustomList[str]) + + with pytest.raises(RuntimeException): + to_dsl_type(CustomDict[str, str]) diff --git a/tests/compiler/test_plugins.py b/tests/compiler/test_plugins.py index 102b8624bb..23c5667103 100644 --- a/tests/compiler/test_plugins.py +++ b/tests/compiler/test_plugins.py @@ -416,3 +416,25 @@ def test_context_and_defaults(snippetcompiler: "SnippetCompilationTest") -> None """ ) compiler.do_compile() + + +def test_native_types(snippetcompiler: "SnippetCompilationTest") -> None: + """ + test the use of python types + """ + snippetcompiler.setup_for_snippet( + """ +import plugin_native_types + +a = "b" +a = plugin_native_types::get_from_dict({"a":"b"}, "a") + +none = null +none = plugin_native_types::get_from_dict({"a":"b"}, "B") + +a = plugin_native_types::many_arguments(["a","c","b"], 1) + +none = plugin_native_types::as_none("a") + """ + ) + compiler.do_compile() diff --git a/tests/data/modules/plugin_native_types/model/_init.cf b/tests/data/modules/plugin_native_types/model/_init.cf new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/data/modules/plugin_native_types/module.yml b/tests/data/modules/plugin_native_types/module.yml new file mode 100644 index 0000000000..09d4dcce59 --- /dev/null +++ b/tests/data/modules/plugin_native_types/module.yml @@ -0,0 +1,3 @@ +name: plugin_native_types +license: Apache 2.0 +version: 1.0.0 diff --git a/tests/data/modules/plugin_native_types/plugins/__init__.py b/tests/data/modules/plugin_native_types/plugins/__init__.py new file mode 100644 index 0000000000..a668fcd6d4 --- /dev/null +++ b/tests/data/modules/plugin_native_types/plugins/__init__.py @@ -0,0 +1,34 @@ +""" + Copyright 2023 Inmanta + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + Contact: code@inmanta.com +""" + +from inmanta.plugins import plugin + + +@plugin +def get_from_dict(value: dict[str, str], key: str) -> str | None: + return value.get(key) + + +@plugin +def many_arguments(il: list[str], idx: int) -> str: + return sorted(il)[idx] + + +@plugin +def as_none(value: str) -> None: + pass