Skip to content

Commit

Permalink
[mypyc] Optimize calls to final classes (#17886)
Browse files Browse the repository at this point in the history
Fixes #9612

This change allows to gain more efficiency where classes are annotated
with `@final` bypassing entirely the vtable for method calls and
property accessors.

For example:
In
```python
@Final
class Vector:
    __slots__ = ("_x", "_y")
    def __init__(self, x: i32, y: i32) -> None:
        self._x = x
        self._y = y

    @Property
    def y(self) -> i32:
        return self._y

def test_vector() -> None:
    v3 = Vector(1, 2)
    assert v3.y == 2
```

The call will produce:

```c
...
cpy_r_r6 = CPyDef_Vector___y(cpy_r_r0);
...
```

Instead of:

```c
...
cpy_r_r1 = CPY_GET_ATTR(cpy_r_r0, CPyType_Vector, 2, farm_rush___engine___vectors2___VectorObject, int32_t); /* y */
...
```
(which uses vtable)
  • Loading branch information
jairov4 authored Oct 14, 2024
1 parent 395108d commit cbd96f9
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 26 deletions.
1 change: 1 addition & 0 deletions mypyc/codegen/emitclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,7 @@ def generate_setup_for_class(
emitter.emit_line("}")
else:
emitter.emit_line(f"self->vtable = {vtable_name};")

for i in range(0, len(cl.bitmap_attrs), BITMAP_BITS):
field = emitter.bitmap_field(i)
emitter.emit_line(f"self->{field} = 0;")
Expand Down
42 changes: 24 additions & 18 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from mypyc.ir.pprint import generate_names_for_ir
from mypyc.ir.rtypes import (
RArray,
RInstance,
RStruct,
RTuple,
RType,
Expand Down Expand Up @@ -362,20 +363,23 @@ def visit_get_attr(self, op: GetAttr) -> None:
prefer_method = cl.is_trait and attr_rtype.error_overlap
if cl.get_method(op.attr, prefer_method=prefer_method):
# Properties are essentially methods, so use vtable access for them.
version = "_TRAIT" if cl.is_trait else ""
self.emit_line(
"%s = CPY_GET_ATTR%s(%s, %s, %d, %s, %s); /* %s */"
% (
dest,
version,
obj,
self.emitter.type_struct_name(rtype.class_ir),
rtype.getter_index(op.attr),
rtype.struct_name(self.names),
self.ctype(rtype.attr_type(op.attr)),
op.attr,
if cl.is_method_final(op.attr):
self.emit_method_call(f"{dest} = ", op.obj, op.attr, [])
else:
version = "_TRAIT" if cl.is_trait else ""
self.emit_line(
"%s = CPY_GET_ATTR%s(%s, %s, %d, %s, %s); /* %s */"
% (
dest,
version,
obj,
self.emitter.type_struct_name(rtype.class_ir),
rtype.getter_index(op.attr),
rtype.struct_name(self.names),
self.ctype(rtype.attr_type(op.attr)),
op.attr,
)
)
)
else:
# Otherwise, use direct or offset struct access.
attr_expr = self.get_attr_expr(obj, op, decl_cl)
Expand Down Expand Up @@ -529,11 +533,13 @@ def visit_call(self, op: Call) -> None:
def visit_method_call(self, op: MethodCall) -> None:
"""Call native method."""
dest = self.get_dest_assign(op)
obj = self.reg(op.obj)
self.emit_method_call(dest, op.obj, op.method, op.args)

rtype = op.receiver_type
def emit_method_call(self, dest: str, op_obj: Value, name: str, op_args: list[Value]) -> None:
obj = self.reg(op_obj)
rtype = op_obj.type
assert isinstance(rtype, RInstance)
class_ir = rtype.class_ir
name = op.method
method = rtype.class_ir.get_method(name)
assert method is not None

Expand All @@ -547,7 +553,7 @@ def visit_method_call(self, op: MethodCall) -> None:
if method.decl.kind == FUNC_STATICMETHOD
else [f"(PyObject *)Py_TYPE({obj})"] if method.decl.kind == FUNC_CLASSMETHOD else [obj]
)
args = ", ".join(obj_args + [self.reg(arg) for arg in op.args])
args = ", ".join(obj_args + [self.reg(arg) for arg in op_args])
mtype = native_function_type(method, self.emitter)
version = "_TRAIT" if rtype.class_ir.is_trait else ""
if is_direct:
Expand All @@ -567,7 +573,7 @@ def visit_method_call(self, op: MethodCall) -> None:
rtype.struct_name(self.names),
mtype,
args,
op.method,
name,
)
)

Expand Down
10 changes: 7 additions & 3 deletions mypyc/ir/class_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@ def __init__(
is_generated: bool = False,
is_abstract: bool = False,
is_ext_class: bool = True,
is_final_class: bool = False,
) -> None:
self.name = name
self.module_name = module_name
self.is_trait = is_trait
self.is_generated = is_generated
self.is_abstract = is_abstract
self.is_ext_class = is_ext_class
self.is_final_class = is_final_class
# An augmented class has additional methods separate from what mypyc generates.
# Right now the only one is dataclasses.
self.is_augmented = False
Expand Down Expand Up @@ -199,7 +201,8 @@ def __repr__(self) -> str:
"ClassIR("
"name={self.name}, module_name={self.module_name}, "
"is_trait={self.is_trait}, is_generated={self.is_generated}, "
"is_abstract={self.is_abstract}, is_ext_class={self.is_ext_class}"
"is_abstract={self.is_abstract}, is_ext_class={self.is_ext_class}, "
"is_final_class={self.is_final_class}"
")".format(self=self)
)

Expand Down Expand Up @@ -248,8 +251,7 @@ def has_method(self, name: str) -> bool:
def is_method_final(self, name: str) -> bool:
subs = self.subclasses()
if subs is None:
# TODO: Look at the final attribute!
return False
return self.is_final_class

if self.has_method(name):
method_decl = self.method_decl(name)
Expand Down Expand Up @@ -349,6 +351,7 @@ def serialize(self) -> JsonDict:
"is_abstract": self.is_abstract,
"is_generated": self.is_generated,
"is_augmented": self.is_augmented,
"is_final_class": self.is_final_class,
"inherits_python": self.inherits_python,
"has_dict": self.has_dict,
"allow_interpreted_subclasses": self.allow_interpreted_subclasses,
Expand Down Expand Up @@ -404,6 +407,7 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR:
ir.is_abstract = data["is_abstract"]
ir.is_ext_class = data["is_ext_class"]
ir.is_augmented = data["is_augmented"]
ir.is_final_class = data["is_final_class"]
ir.inherits_python = data["inherits_python"]
ir.has_dict = data["has_dict"]
ir.allow_interpreted_subclasses = data["allow_interpreted_subclasses"]
Expand Down
2 changes: 1 addition & 1 deletion mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class RType:

@abstractmethod
def accept(self, visitor: RTypeVisitor[T]) -> T:
raise NotImplementedError
raise NotImplementedError()

def short_name(self) -> str:
return short_name(self.name)
Expand Down
4 changes: 2 additions & 2 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1889,7 +1889,7 @@ def primitive_op(
# Does this primitive map into calling a Python C API
# or an internal mypyc C API function?
if desc.c_function_name:
# TODO: Generate PrimitiOps here and transform them into CallC
# TODO: Generate PrimitiveOps here and transform them into CallC
# ops only later in the lowering pass
c_desc = CFunctionDescription(
desc.name,
Expand All @@ -1908,7 +1908,7 @@ def primitive_op(
)
return self.call_c(c_desc, args, line, result_type)

# This primitve gets transformed in a lowering pass to
# This primitive gets transformed in a lowering pass to
# lower-level IR ops using a custom transform function.

coerced = []
Expand Down
6 changes: 5 additions & 1 deletion mypyc/irbuild/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ def build_type_map(
# references even if there are import cycles.
for module, cdef in classes:
class_ir = ClassIR(
cdef.name, module.fullname, is_trait(cdef), is_abstract=cdef.info.is_abstract
cdef.name,
module.fullname,
is_trait(cdef),
is_abstract=cdef.info.is_abstract,
is_final_class=cdef.info.is_final,
)
class_ir.is_ext_class = is_extension_class(cdef)
if class_ir.is_ext_class:
Expand Down
11 changes: 10 additions & 1 deletion mypyc/irbuild/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,16 @@
UnaryExpr,
Var,
)
from mypy.semanal import refers_to_fullname
from mypy.types import FINAL_DECORATOR_NAMES

DATACLASS_DECORATORS = {"dataclasses.dataclass", "attr.s", "attr.attrs"}


def is_final_decorator(d: Expression) -> bool:
return refers_to_fullname(d, FINAL_DECORATOR_NAMES)


def is_trait_decorator(d: Expression) -> bool:
return isinstance(d, RefExpr) and d.fullname == "mypy_extensions.trait"

Expand Down Expand Up @@ -119,7 +125,10 @@ def get_mypyc_attrs(stmt: ClassDef | Decorator) -> dict[str, Any]:

def is_extension_class(cdef: ClassDef) -> bool:
if any(
not is_trait_decorator(d) and not is_dataclass_decorator(d) and not get_mypyc_attr_call(d)
not is_trait_decorator(d)
and not is_dataclass_decorator(d)
and not get_mypyc_attr_call(d)
and not is_final_decorator(d)
for d in cdef.decorators
):
return False
Expand Down
75 changes: 75 additions & 0 deletions mypyc/test-data/run-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -2519,3 +2519,78 @@ class C:
def test_final_attribute() -> None:
assert C.A == -1
assert C.a == [-1]

[case testClassWithFinalDecorator]
from typing import final

@final
class C:
def a(self) -> int:
return 1

def test_class_final_attribute() -> None:
assert C().a() == 1


[case testClassWithFinalDecoratorCtor]
from typing import final

@final
class C:
def __init__(self) -> None:
self.a = 1

def b(self) -> int:
return 2

@property
def c(self) -> int:
return 3

def test_class_final_attribute() -> None:
assert C().a == 1
assert C().b() == 2
assert C().c == 3

[case testClassWithFinalDecoratorInheritedWithProperties]
from typing import final

class B:
def a(self) -> int:
return 2

@property
def b(self) -> int:
return self.a() + 2

@property
def c(self) -> int:
return 3

def test_class_final_attribute_basic() -> None:
assert B().a() == 2
assert B().b == 4
assert B().c == 3

@final
class C(B):
def a(self) -> int:
return 1

@property
def b(self) -> int:
return self.a() + 1

def fn(cl: B) -> int:
return cl.a()

def test_class_final_attribute_inherited() -> None:
assert C().a() == 1
assert fn(C()) == 1
assert B().a() == 2
assert fn(B()) == 2

assert B().b == 4
assert C().b == 2
assert B().c == 3
assert C().c == 3

0 comments on commit cbd96f9

Please sign in to comment.