Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Codegen: parametrized pragmas #17532

Merged
merged 4 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion misc/codegen/generators/qlgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _get_doc(cls: schema.Class, prop: schema.Property, plural=None):
return format.format(**{noun: transform(noun) for noun in nouns})

prop_name = _humanize(prop.name)
class_name = cls.default_doc_name or _humanize(inflection.underscore(cls.name))
class_name = cls.pragmas.get("ql_default_doc_name", _humanize(inflection.underscore(cls.name)))
if prop.is_predicate:
return f"this {class_name} {prop_name}"
if plural is not None:
Expand Down
2 changes: 1 addition & 1 deletion misc/codegen/generators/rusttestgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def generate(opts, renderer):
continue
assert not adding_code, "Unterminated code block in docstring: " + "\n".join(cls.doc)
test_name = inflection.underscore(cls.name)
signature = cls.rust_doc_test_function
signature = cls.pragmas.get("rust_doc_test_signature", "() -> ()")
fn = signature and Function(f"test_{test_name}", signature)
if fn:
indent = 4 * " "
Expand Down
32 changes: 22 additions & 10 deletions misc/codegen/lib/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,15 @@ class Kind(Enum):
name: Optional[str] = None
type: Optional[str] = None
is_child: bool = False
pragmas: List[str] = field(default_factory=list)
pragmas: List[str] | Dict[str, object] = field(default_factory=dict)
doc: Optional[str] = None
description: List[str] = field(default_factory=list)
synth: bool = False

def __post_init__(self):
if not isinstance(self.pragmas, dict):
self.pragmas = dict.fromkeys(self.pragmas, None)

@property
def is_single(self) -> bool:
return self.kind == self.Kind.SINGLE
Expand Down Expand Up @@ -88,14 +92,14 @@ class Class:
derived: Set[str] = field(default_factory=set)
properties: List[Property] = field(default_factory=list)
group: str = ""
pragmas: List[str] = field(default_factory=list)
synth: Optional[Union[SynthInfo, bool]] = None
"""^^^ filled with `True` for non-final classes with only synthesized final descendants """
pragmas: List[str] | Dict[str, object] = field(default_factory=dict)
doc: List[str] = field(default_factory=list)
default_doc_name: Optional[str] = None
hideable: bool = False
test_with: Optional[str] = None
rust_doc_test_function: Optional["FunctionInfo"] = "() -> ()" # TODO: parametrized pragmas

def __post_init__(self):
if not isinstance(self.pragmas, dict):
self.pragmas = dict.fromkeys(self.pragmas, None)

@property
def final(self):
Expand All @@ -108,13 +112,21 @@ def check_types(self, known: typing.Iterable[str]):
_check_type(d, known)
for p in self.properties:
_check_type(p.type, known)
if self.synth is not None:
_check_type(self.synth.from_class, known)
if self.synth.on_arguments is not None:
for t in self.synth.on_arguments.values():
if "synth" in self.pragmas:
synth = self.pragmas["synth"]
_check_type(synth.from_class, known)
if synth.on_arguments is not None:
for t in synth.on_arguments.values():
_check_type(t, known)
_check_type(self.test_with, known)

@property
def synth(self) -> SynthInfo | bool | None:
return self.pragmas.get("synth")

def mark_synth(self):
self.pragmas.setdefault("synth", True)


@dataclass
class Schema:
Expand Down
128 changes: 83 additions & 45 deletions misc/codegen/lib/schemadefs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable as _Callable, List as _List
from typing import Callable as _Callable, Dict as _Dict, ClassVar as _ClassVar
from misc.codegen.lib import schema as _schema
import inspect as _inspect
from dataclasses import dataclass as _dataclass
Expand Down Expand Up @@ -62,11 +62,14 @@ def include(source: str):
_inspect.currentframe().f_back.f_locals.setdefault("includes", []).append(source)


@_dataclass
class _Namespace:
""" simple namespacing mechanism """
name: str

def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def add(self, pragma: "_PragmaBase", key: str | None = None):
self.__dict__[pragma.pragma] = pragma
pragma.pragma = key or f"{self.name}_{pragma.pragma}"


@_dataclass
Expand All @@ -77,51 +80,86 @@ def modify(self, prop: _schema.Property):
prop.synth = self.synth

def negate(self) -> "PropertyModifier":
return _SynthModifier(False)
return _SynthModifier(self.name, False)


qltest = _Namespace("qltest")
ql = _Namespace("ql")
cpp = _Namespace("cpp")
rust = _Namespace("rust")
synth = _SynthModifier("synth")


@_dataclass
class _PragmaBase:
pragma: str


@_dataclass
class _ClassPragma(_PragmaBase):
""" A class pragma.
For schema classes it acts as a python decorator with `@`.
"""
value: object = None

def __call__(self, cls: type) -> type:
""" use this pragma as a decorator on classes """
# not using hasattr as we don't want to land on inherited pragmas
if "_pragmas" not in cls.__dict__:
cls._pragmas = {}
self._apply(cls._pragmas)
return cls

qltest = _Namespace()
ql = _Namespace()
cpp = _Namespace()
rust = _Namespace()
synth = _SynthModifier()
def _apply(self, pragmas: _Dict[str, object]) -> None:
pragmas[self.pragma] = self.value


@_dataclass
class _Pragma(_schema.PropertyModifier):
class _Pragma(_ClassPragma, _schema.PropertyModifier):
""" A class or property pragma.
For properties, it functions similarly to a `_PropertyModifier` with `|`, adding the pragma.
For schema classes it acts as a python decorator with `@`.
"""
pragma: str
remove: bool = False

def __post_init__(self):
namespace, _, name = self.pragma.partition('_')
setattr(globals()[namespace], name, self)

def modify(self, prop: _schema.Property):
self._apply(prop.pragmas)

def negate(self) -> "PropertyModifier":
return _Pragma(self.pragma, remove=True)

def __call__(self, cls: type) -> type:
""" use this pragma as a decorator on classes """
if "_pragmas" in cls.__dict__: # not using hasattr as we don't want to land on inherited pragmas
self._apply(cls._pragmas)
elif not self.remove:
cls._pragmas = [self.pragma]
return cls

def _apply(self, pragmas: _List[str]) -> None:
def _apply(self, pragmas: _Dict[str, object]) -> None:
if self.remove:
try:
pragmas.remove(self.pragma)
except ValueError:
pass
pragmas.pop(self.pragma, None)
else:
pragmas.append(self.pragma)
super()._apply(pragmas)


@_dataclass
class _ParametrizedClassPragma(_PragmaBase):
""" A class parametrized pragma.
Needs to be applied to a parameter to give a class pragma.
"""
_pragma_class: _ClassVar[type] = _ClassPragma

function: _Callable[..., object] = None

def __post_init__(self):
self.__signature__ = _inspect.signature(self.function).replace(return_annotation=self._pragma_class)

def __call__(self, *args, **kwargs) -> _pragma_class:
return self._pragma_class(self.pragma, value=self.function(*args, **kwargs))


@_dataclass
class _ParametrizedPragma(_ParametrizedClassPragma):
""" A class or property parametrized pragma.
Needs to be applied to a parameter to give a pragma.
"""
_pragma_class: _ClassVar[type] = _Pragma

def __invert__(self) -> _Pragma:
return _Pragma(self.pragma, remove=True)


class _Optionalizer(_schema.PropertyModifier):
Expand Down Expand Up @@ -190,30 +228,30 @@ def f(cls: type) -> type:

use_for_null = _annotate(null=True)

_Pragma("qltest_skip")
_Pragma("qltest_collapse_hierarchy")
_Pragma("qltest_uncollapse_hierarchy")
qltest.test_with = lambda cls: _annotate(test_with=cls)
qltest.add(_Pragma("skip"))
qltest.add(_ClassPragma("collapse_hierarchy"))
qltest.add(_ClassPragma("uncollapse_hierarchy"))
qltest.test_with = lambda cls: _annotate(test_with=cls) # inheritable

ql.default_doc_name = lambda doc: _annotate(doc_name=doc)
ql.hideable = _annotate(hideable=True)
_Pragma("ql_internal")
ql.add(_ParametrizedClassPragma("default_doc_name", lambda doc: doc))
ql.hideable = _annotate(hideable=True) # inheritable
ql.add(_Pragma("internal"))

_Pragma("cpp_skip")
cpp.add(_Pragma("skip"))

_Pragma("rust_skip_doc_test")
rust.add(_Pragma("skip_doc_test"))

rust.doc_test_signature = lambda signature: _annotate(rust_doc_test_function=signature)
rust.add(_ParametrizedClassPragma("doc_test_signature", lambda signature: signature))


def group(name: str = "") -> _ClassDecorator:
return _annotate(group=name)


synth.from_class = lambda ref: _annotate(synth=_schema.SynthInfo(
from_class=_schema.get_type_name(ref)))
synth.on_arguments = lambda **kwargs: _annotate(
synth=_schema.SynthInfo(on_arguments={k: _schema.get_type_name(t) for k, t in kwargs.items()}))
synth.add(_ParametrizedClassPragma("from_class", lambda ref: _schema.SynthInfo(
from_class=_schema.get_type_name(ref))), key="synth")
synth.add(_ParametrizedClassPragma("on_arguments", lambda **kwargs:
_schema.SynthInfo(on_arguments={k: _schema.get_type_name(t) for k, t in kwargs.items()})), key="synth")


class _PropertyModifierList(_schema.PropertyModifier):
Expand Down Expand Up @@ -251,9 +289,9 @@ def decorator(cls: type) -> _PropertyAnnotation:
if cls.__doc__ is not None:
annotated_cls.__doc__ = cls.__doc__
old_pragmas = getattr(annotated_cls, "_pragmas", None)
new_pragmas = getattr(cls, "_pragmas", [])
new_pragmas = getattr(cls, "_pragmas", {})
if old_pragmas:
old_pragmas.extend(new_pragmas)
old_pragmas.update(new_pragmas)
else:
annotated_cls._pragmas = new_pragmas
for a, v in cls.__dict__.items():
Expand Down
10 changes: 3 additions & 7 deletions misc/codegen/loaders/schemaloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,12 @@ def _get_class(cls: type) -> schema.Class:
hideable=getattr(cls, "_hideable", False),
test_with=_get_name(getattr(cls, "_test_with", None)),
# in the following we don't use `getattr` to avoid inheriting
pragmas=cls.__dict__.get("_pragmas", []),
synth=cls.__dict__.get("_synth", None),
pragmas=cls.__dict__.get("_pragmas", {}),
properties=[
a | _PropertyNamer(n)
for n, a in cls.__dict__.get("__annotations__", {}).items()
],
doc=schema.split_doc(cls.__doc__),
default_doc_name=cls.__dict__.get("_doc_name"),
rust_doc_test_function=cls.__dict__.get("_rust_doc_test_function",
schema.Class.rust_doc_test_function)
)


Expand Down Expand Up @@ -103,8 +99,8 @@ def fill_is_synth(name: str):
fill_is_synth(root)

for name, cls in classes.items():
if cls.synth is None and is_synth[name]:
cls.synth = True
if is_synth[name]:
cls.mark_synth()


def _fill_hideable_information(classes: typing.Dict[str, schema.Class]):
Expand Down
6 changes: 3 additions & 3 deletions misc/codegen/test/test_cppgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,15 @@ def test_synth_classes_ignored(generate):
assert generate([
schema.Class(
name="W",
synth=schema.SynthInfo(),
pragmas={"synth": schema.SynthInfo()},
),
schema.Class(
name="X",
synth=schema.SynthInfo(from_class="A"),
pragmas={"synth": schema.SynthInfo(from_class="A")},
),
schema.Class(
name="Y",
synth=schema.SynthInfo(on_arguments={"a": "A", "b": "int"}),
pragmas={"synth": schema.SynthInfo(on_arguments={"a": "A", "b": "int"})},
),
schema.Class(
name="Z",
Expand Down
8 changes: 4 additions & 4 deletions misc/codegen/test/test_dbschemegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,9 +536,9 @@ def test_null_class(generate):

def test_synth_classes_ignored(generate):
assert generate([
schema.Class(name="A", synth=schema.SynthInfo()),
schema.Class(name="B", synth=schema.SynthInfo(from_class="A")),
schema.Class(name="C", synth=schema.SynthInfo(on_arguments={"x": "A"})),
schema.Class(name="A", pragmas={"synth": schema.SynthInfo()}),
schema.Class(name="B", pragmas={"synth": schema.SynthInfo(from_class="A")}),
schema.Class(name="C", pragmas={"synth": schema.SynthInfo(on_arguments={"x": "A"})}),
]) == dbscheme.Scheme(
src=schema_file.name,
includes=[],
Expand All @@ -549,7 +549,7 @@ def test_synth_classes_ignored(generate):
def test_synth_derived_classes_ignored(generate):
assert generate([
schema.Class(name="A", derived={"B", "C"}),
schema.Class(name="B", bases=["A"], synth=schema.SynthInfo()),
schema.Class(name="B", bases=["A"], pragmas={"synth": schema.SynthInfo()}),
schema.Class(name="C", bases=["A"]),
]) == dbscheme.Scheme(
src=schema_file.name,
Expand Down
6 changes: 3 additions & 3 deletions misc/codegen/test/test_qlgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ def test_property_on_class_with_default_doc_name(generate_classes):
assert generate_classes([
schema.Class("MyObject", properties=[
schema.SingleProperty("foo", "bar")],
default_doc_name="baz"),
pragmas={"ql_default_doc_name": "baz"}),
]) == {
"MyObject.qll": (a_ql_class_public(name="MyObject"),
a_ql_stub(name="MyObject"),
Expand All @@ -937,7 +937,7 @@ def test_property_on_class_with_default_doc_name(generate_classes):

def test_stub_on_class_with_synth_from_class(generate_classes):
assert generate_classes([
schema.Class("MyObject", synth=schema.SynthInfo(from_class="A"),
schema.Class("MyObject", pragmas={"synth": schema.SynthInfo(from_class="A")},
properties=[schema.SingleProperty("foo", "bar")]),
]) == {
"MyObject.qll": (a_ql_class_public(name="MyObject"), a_ql_stub(name="MyObject", synth_accessors=[
Expand All @@ -952,7 +952,7 @@ def test_stub_on_class_with_synth_from_class(generate_classes):

def test_stub_on_class_with_synth_on_arguments(generate_classes):
assert generate_classes([
schema.Class("MyObject", synth=schema.SynthInfo(on_arguments={"base": "A", "index": "int", "label": "string"}),
schema.Class("MyObject", pragmas={"synth": schema.SynthInfo(on_arguments={"base": "A", "index": "int", "label": "string"})},
properties=[schema.SingleProperty("foo", "bar")]),
]) == {
"MyObject.qll": (a_ql_class_public(name="MyObject"), a_ql_stub(name="MyObject", synth_accessors=[
Expand Down
Loading
Loading