Skip to content

Commit

Permalink
Added support for python types in plugin annotations (PR #8578)
Browse files Browse the repository at this point in the history
Stage 2 for #8559
  • Loading branch information
wouterdb authored and inmantaci committed Jan 8, 2025
1 parent caf4da6 commit 7a53c98
Show file tree
Hide file tree
Showing 10 changed files with 316 additions and 20 deletions.
5 changes: 5 additions & 0 deletions changelogs/unreleased/native_plugin_types.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
description: Added support for python types in plugin annotations
change-type: minor
destination-branches: [iso7]
sections:
feature: "{{description}}"
13 changes: 7 additions & 6 deletions mypy-baseline.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,11 @@ src/inmanta/ast/__init__.py:0: error: Function is missing a return type annotati
src/inmanta/ast/type.py:0: error: List item 0 has incompatible type "type[float]"; expected "Callable[[object], Number]" [list-item]
src/inmanta/ast/type.py:0: error: List item 0 has incompatible type "type[float]"; expected "Callable[[object], object]" [list-item]
src/inmanta/ast/type.py:0: error: Incompatible types in assignment (expression has type "Sequence[Callable[[object | None], object]]", base class "Number" defined the type as "Sequence[Callable[[object], Number]]") [assignment]
src/inmanta/ast/type.py:0: error: Function is missing a return type annotation [no-untyped-def]
src/inmanta/ast/type.py:0: note: Use "-> None" if function does not return a value
src/inmanta/ast/type.py:0: error: Call to untyped function "__init__" in typed context [no-untyped-call]
src/inmanta/ast/type.py:0: error: Return type "str | None" of "type_string" incompatible with return type "str" in supertype "List" [override]
src/inmanta/ast/type.py:0: error: Function is missing a type annotation for one or more arguments [no-untyped-def]
src/inmanta/ast/type.py:0: error: Function is missing a return type annotation [no-untyped-def]
src/inmanta/ast/type.py:0: error: Function is missing a return type annotation [no-untyped-def]
src/inmanta/ast/type.py:0: error: Function is missing a type annotation [no-untyped-def]
src/inmanta/ast/type.py:0: error: Item "Locatable" of "Locatable | None" has no attribute "name" [union-attr]
src/inmanta/ast/type.py:0: error: Item "None" of "Locatable | None" has no attribute "name" [union-attr]
src/inmanta/ast/type.py:0: error: Call to untyped function "List" in typed context [no-untyped-call]
src/inmanta/ast/type.py:0: error: "List" expects no type arguments, but 1 given [type-arg]
src/inmanta/ast/type.py:0: error: Incompatible types in assignment (expression has type "list[Never]", variable has type "inmanta.ast.type.List") [assignment]
src/inmanta/ast/type.py:0: error: "inmanta.ast.type.List" has no attribute "append" [attr-defined]
Expand Down Expand Up @@ -460,6 +454,7 @@ src/inmanta/protocol/endpoints.py:0: error: Return type "Coroutine[Any, Any, Bas
src/inmanta/parser/cache.py:0: error: Argument 1 to "ASTUnpickler" has incompatible type "BufferedReader"; expected "BytesIO" [arg-type]
src/inmanta/loader.py:0: error: Missing type parameters for generic type Module [type-arg]
src/inmanta/loader.py:0: error: Argument 1 to "getsourcefile" has incompatible type "object"; expected Module | type[Any] | MethodType | FunctionType | TracebackType | FrameType | CodeType | Callable[..., Any] [arg-type]
src/inmanta/loader.py:0: error: Function is missing a type annotation [no-untyped-def]
src/inmanta/loader.py:0: error: Return type "bytes" of "get_source" incompatible with return type "str | None" in supertype "InspectLoader" [override]
src/inmanta/loader.py:0: error: No return value expected [return-value]
src/inmanta/loader.py:0: error: Incompatible types in assignment (expression has type "None", variable has type "PluginModuleFinder") [assignment]
Expand Down Expand Up @@ -756,7 +751,13 @@ src/inmanta/protocol/rest/client.py:0: error: Argument "result" to "Result" has
src/inmanta/plugins.py:0: error: Missing type parameters for generic type "ResultVariable" [type-arg]
src/inmanta/plugins.py:0: error: Argument 1 to "add_function" of "PluginMeta" has incompatible type "PluginMeta"; expected "type[Plugin]" [arg-type]
src/inmanta/plugins.py:0: error: "type[Plugin]" has no attribute "__fq_plugin_name__" [attr-defined]
src/inmanta/plugins.py:0: error: Non-overlapping identity check (left operand type: "type[object]", right operand type: "<typing special form>") [comparison-overlap]
src/inmanta/plugins.py:0: error: Argument 1 to "issubclass" has incompatible type "Any | None"; expected "type" [arg-type]
src/inmanta/plugins.py:0: error: Argument 1 to "issubclass" has incompatible type "Any | None"; expected "type" [arg-type]
src/inmanta/plugins.py:0: error: Incompatible types in assignment (expression has type "tuple[Any, ...]", variable has type "list[type[object]]") [assignment]
src/inmanta/plugins.py:0: error: Argument 1 to "issubclass" has incompatible type "Any | None"; expected "type" [arg-type]
src/inmanta/plugins.py:0: error: Invalid index type "object" for "dict[str | None, Type]"; expected type "str | None" [index]
src/inmanta/plugins.py:0: error: Argument 1 to "to_dsl_type" has incompatible type "object"; expected "type[object]" [arg-type]
src/inmanta/plugins.py:0: error: Incompatible return value type (got "Type | None", expected "Type") [return-value]
src/inmanta/plugins.py:0: error: "type[Plugin]" has no attribute "__function__" [attr-defined]
src/inmanta/plugins.py:0: error: "type[Plugin]" has no attribute "__function__" [attr-defined]
Expand Down
34 changes: 31 additions & 3 deletions src/inmanta/ast/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,15 @@ def with_base_type(self, base_type: "Type") -> "Type":
"""
return base_type

def __eq__(self, other: object) -> bool:
if type(self) != Type: # noqa: E721
# Not for children
return NotImplemented
return type(self) == type(other) # noqa: E721

def __hash__(self) -> int:
return hash(type(self))


class NamedType(Type, Named):
def get_double_defined_exception(self, other: "NamedType") -> "DuplicateException":
Expand All @@ -105,6 +114,14 @@ def get_double_defined_exception(self, other: "NamedType") -> "DuplicateExceptio
def type_string(self) -> str:
return self.get_full_name()

def __eq__(self, other: object) -> bool:
if not isinstance(other, NamedType):
return False
return self.get_full_name() == other.get_full_name()

def __hash__(self) -> int:
return hash(self.get_full_name())


@stable_api
class NullableType(Type):
Expand Down Expand Up @@ -370,6 +387,9 @@ def is_primitive(self) -> bool:
def get_location(self) -> None:
return None

def __eq__(self, other: object) -> bool:
return type(self) == type(other) # noqa: E721


@stable_api
class List(Type):
Expand All @@ -378,7 +398,7 @@ class List(Type):
This class refers to the list type used in plugin annotations. For the list type in the Inmanta DSL, see `LiteralList`.
"""

def __init__(self):
def __init__(self) -> None:
Type.__init__(self)

def validate(self, value: Optional[object]) -> bool:
Expand All @@ -403,6 +423,9 @@ def type_string_internal(self) -> str:
def get_location(self) -> None:
return None

def __eq__(self, other: object) -> bool:
return type(self) == type(other) # noqa: E721


@stable_api
class TypedList(List):
Expand Down Expand Up @@ -544,6 +567,11 @@ def type_string_internal(self) -> str:
def get_location(self) -> None:
return None

def __eq__(self, other: object) -> bool:
if not isinstance(other, TypedDict):
return NotImplemented
return self.element_type == other.element_type


@stable_api
class LiteralDict(TypedDict):
Expand Down Expand Up @@ -622,7 +650,7 @@ def normalize(self) -> None:
assert self.expression is not None
self.expression.normalize()

def set_constraint(self, expression) -> None:
def set_constraint(self, expression: "ExpressionStatement") -> None:
"""
Set the constraint for this type. This baseclass for constraint
types requires the constraint to be set as a regex that can be
Expand All @@ -631,7 +659,7 @@ def set_constraint(self, expression) -> None:
self.expression = expression
self._constraint = create_function(self, expression)

def get_constraint(self):
def get_constraint(self) -> "ExpressionStatement | None":
"""
Get the string representation of the constraint
"""
Expand Down
10 changes: 10 additions & 0 deletions src/inmanta/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,16 @@ class ModuleSource:
source: Optional[bytes] = None
_client: Optional["protocol.SyncClient"] = None

def __lt__(self, other):
if not isinstance(other, ModuleSource):
return NotImplemented
return (self.name, self.hash_value, self.is_byte_code) < (other.name, other.hash_value, other.is_byte_code)

def __eq__(self, other: object) -> bool:
if not isinstance(other, ModuleSource):
return False
return (self.name, self.hash_value, self.is_byte_code) == (other.name, other.hash_value, other.is_byte_code)

def get_source_code(self) -> bytes:
"""Load the source code"""
if self.source is not None:
Expand Down
145 changes: 134 additions & 11 deletions src/inmanta/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,29 @@
import asyncio
import collections.abc
import inspect
import numbers
import os
import subprocess
import typing
import warnings
from collections import abc
from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, Optional, Sequence, Type, TypeVar
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Type, TypeVar

import typing_inspect

import inmanta.ast.type as inmanta_type
from inmanta import const, protocol, util
from inmanta.ast import LocatableString, Location, Namespace, Range, RuntimeException, TypeNotFoundException, WithComment
from inmanta.ast import (
LocatableString,
Location,
Namespace,
Range,
RuntimeException,
TypeNotFoundException,
TypingException,
WithComment,
)
from inmanta.ast.type import NamedType
from inmanta.config import Config
from inmanta.execute.proxy import DynamicProxy
Expand Down Expand Up @@ -201,6 +215,9 @@ def type_string(self) -> str:
def type_string_internal(self) -> str:
return self.type_string()

def __eq__(self, other: object) -> bool:
return type(self) == type(other) # noqa: E721


# Define some types which are used in the context of plugins.
PLUGIN_TYPES = {
Expand All @@ -210,6 +227,109 @@ def type_string_internal(self) -> str:
None: Null(), # Only NoneValue will pass validation
}

python_to_model = {
str: inmanta_type.String(),
float: inmanta_type.Float(),
numbers.Number: inmanta_type.Number(),
int: inmanta_type.Integer(),
bool: inmanta_type.Bool(),
dict: inmanta_type.TypedDict(inmanta_type.Type()),
typing.Mapping: inmanta_type.TypedDict(inmanta_type.Type()),
Mapping: inmanta_type.TypedDict(inmanta_type.Type()),
list: inmanta_type.List(),
typing.Sequence: inmanta_type.List(),
Sequence: inmanta_type.List(),
object: inmanta_type.Type(),
}


def to_dsl_type(python_type: type[object]) -> inmanta_type.Type:
"""
Convert a python type annotation to an Inmanta DSL type annotation.
:param python_type: The evaluated python type as provided in the Python type annotation.
"""
# Any to any
if python_type is typing.Any:
return inmanta_type.Type()

# None to None
if python_type is type(None) or python_type is None:
return Null()

# Unions and optionals
if typing_inspect.is_union_type(python_type):
# Optional type
if typing_inspect.is_optional_type(python_type):
other_types = [tt for tt in typing.get_args(python_type) if not typing_inspect.is_optional_type(tt)]
if len(other_types) == 0:
# Probably not possible
return Null()
if len(other_types) == 1:
return inmanta_type.NullableType(to_dsl_type(other_types[0]))
# TODO: optional unions
return inmanta_type.Type()
else:
# TODO: unions
return inmanta_type.Type()
# bases: Sequence[inmanta.ast.type.Type] = [to_dsl_type(arg) for arg in typing.get_args(python_type)]
# return inmanta.ast.type.Union(bases)

# Lists and dicts
if typing_inspect.is_generic_type(python_type):
origin = typing.get_origin(python_type)

# dict
if issubclass(origin, Mapping):
if origin in [collections.abc.Mapping, dict, typing.Mapping]:
args = typing_inspect.get_args(python_type)
if not args:
return inmanta_type.TypedDict(inmanta_type.Type())

if not issubclass(args[0], str):
raise TypingException(
None, f"invalid type {python_type}, the keys of any dict should be 'str', got {args[0]} instead"
)

if len(args) == 1:
return inmanta_type.TypedDict(inmanta_type.Type())

return inmanta_type.TypedDict(to_dsl_type(args[1]))
else:
raise TypingException(None, f"invalid type {python_type}, dictionary types should be Mapping or dict")

# List
if issubclass(origin, Sequence):
if origin in [collections.abc.Sequence, list, typing.Sequence]:
args = typing.get_args(python_type)
if not args:
return inmanta_type.List()
return inmanta_type.TypedList(to_dsl_type(args[0]))
else:
raise TypingException(None, f"invalid type {python_type}, list types should be Sequence or list")

# Set
if issubclass(origin, collections.abc.Set):
raise TypingException(None, f"invalid type {python_type}, set is not supported on the plugin boundary")

# TODO annotated types
# if typing.get_origin(t) is typing.Annotated:
# args: Sequence[object] = typing.get_args(python_type)
# inmanta_types: Sequence[plugin_typing.InmantaType] =
# [arg if isinstance(arg, plugin_typing.InmantaType) for arg in args]
# if inmanta_types:
# if len(inmanta_types) > 1:
# # TODO
# raise Exception()
# # TODO
# return parse_dsl_type(inmanta_types[0].dsl_type)
# # the annotation doesn't concern us => use base type
# return to_dsl_type(args[0])
if python_type in python_to_model:
return python_to_model[python_type]

return inmanta_type.Type()


class PluginValue:
"""
Expand Down Expand Up @@ -258,15 +378,18 @@ def resolve_type(self, plugin: "Plugin", resolver: Namespace) -> inmanta_type.Ty
return self._resolved_type

if not isinstance(self.type_expression, str):
raise RuntimeException(
stmt=None,
msg="Bad annotation in plugin %s for %s, expected str but got %s (%s)"
% (plugin.get_full_name(), self.VALUE_NAME, type(self.type_expression).__name__, self.type_expression),
)

plugin_line: Range = Range(plugin.location.file, plugin.location.lnr, 1, plugin.location.lnr + 1, 1)
locatable_type: LocatableString = LocatableString(self.type_expression, plugin_line, 0, resolver)
self._resolved_type = inmanta_type.resolve_type(locatable_type, resolver)
if isinstance(self.type_expression, type) or typing.get_origin(self.type_expression) is not None:
self._resolved_type = to_dsl_type(self.type_expression)
else:
raise RuntimeException(
stmt=None,
msg="Bad annotation in plugin %s for %s, expected str or python type but got %s (%s)"
% (plugin.get_full_name(), self.VALUE_NAME, type(self.type_expression).__name__, self.type_expression),
)
else:
plugin_line: Range = Range(plugin.location.file, plugin.location.lnr, 1, plugin.location.lnr + 1, 1)
locatable_type: LocatableString = LocatableString(self.type_expression, plugin_line, 0, resolver)
self._resolved_type = inmanta_type.resolve_type(locatable_type, resolver)
return self._resolved_type

def validate(self, value: object) -> bool:
Expand Down
Loading

0 comments on commit 7a53c98

Please sign in to comment.