Skip to content

Commit

Permalink
fix: Stub generation testing and fixing of miscellaneous bugs (#76)
Browse files Browse the repository at this point in the history
### Summary of Changes

This PR adds the stub files generated from the
[Library](https://github.com/Safe-DS/Library) project, uses them to test
stub generation, and fixes any bugs that this testing uncovers.

- [X] Collection[T] and Sequence[T] are recognized as TypeVars if used
as superclasses
- [x] Mapping[] can be handled if used as superclass
- [x] Omit any class completely that inherits directly or transitively
from ```Exception```
- [x] Rework how reexports are handled for the is_public method
(currently not working correctly)
- [x] Fix stub package names: "The shortest public "reexport" should be
used based on the ```__init__.py``` files."

---------

Co-authored-by: megalinter-bot <[email protected]>
Co-authored-by: Lars Reimann <[email protected]>
  • Loading branch information
3 people authored Mar 12, 2024
1 parent 27659ee commit 97b0ab3
Show file tree
Hide file tree
Showing 60 changed files with 2,233 additions and 985 deletions.
5 changes: 5 additions & 0 deletions src/safeds_stubgen/api_analyzer/_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum as PythonEnum
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -48,6 +49,8 @@ def __init__(self, distribution: str, package: str, version: str) -> None:
self.attributes_: dict[str, Attribute] = {}
self.parameters_: dict[str, Parameter] = {}

self.reexport_map: dict[str, set[Module]] = defaultdict(set)

def add_module(self, module: Module) -> None:
self.modules[module.id] = module

Expand Down Expand Up @@ -170,6 +173,7 @@ class Class:
docstring: ClassDocstring
constructor: Function | None = None
constructor_fulldocstring: str = ""
inherits_from_exception: bool = False
reexported_by: list[Module] = field(default_factory=list)
attributes: list[Attribute] = field(default_factory=list)
methods: list[Function] = field(default_factory=list)
Expand All @@ -184,6 +188,7 @@ def to_dict(self) -> dict[str, Any]:
"is_public": self.is_public,
"superclasses": self.superclasses,
"constructor": self.constructor.to_dict() if self.constructor is not None else None,
"inherits_from_exception": self.inherits_from_exception,
"reexported_by": [module.id for module in self.reexported_by],
"attributes": [attribute.id for attribute in self.attributes],
"methods": [method.id for method in self.methods],
Expand Down
212 changes: 142 additions & 70 deletions src/safeds_stubgen/api_analyzer/_ast_visitor.py

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion src/safeds_stubgen/api_analyzer/_ast_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __walk(self, node: MypyFile | ClassDef | Decorator | FuncDef | AssignmentStm
if isinstance(node, Decorator):
node = node.func

if node in visited_nodes:
if node in visited_nodes: # pragma: no cover
raise AssertionError("Node visited twice")
visited_nodes.add(node)

Expand Down Expand Up @@ -71,6 +71,9 @@ def __walk(self, node: MypyFile | ClassDef | Decorator | FuncDef | AssignmentStm
if isinstance(node, FuncDef) and node.name != "__init__":
continue

if isinstance(child_node, FuncDef) and isinstance(node, FuncDef):
continue

self.__walk(child_node, visited_nodes)
self.__leave(node)

Expand Down
64 changes: 24 additions & 40 deletions src/safeds_stubgen/api_analyzer/_get_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING

import mypy.build as mypy_build
import mypy.main as mypy_main
Expand All @@ -16,6 +16,9 @@
from ._ast_walker import ASTWalker
from ._package_metadata import distribution, distribution_version, package_root

if TYPE_CHECKING:
from pathlib import Path


def get_api(
package_name: str,
Expand Down Expand Up @@ -44,7 +47,7 @@ def get_api(
if file_path.parts[-1] == "__init__.py":
# if a directory contains an __init__.py file it's a package
package_paths.append(
file_path.parent,
str(file_path.parent),
)
continue

Expand All @@ -59,7 +62,7 @@ def get_api(

# Get mypy ast and aliases
build_result = _get_mypy_build(walkable_files)
mypy_asts = _get_mypy_asts(build_result, walkable_files, package_paths, root)
mypy_asts = _get_mypy_asts(build_result, walkable_files, package_paths)
aliases = _get_aliases(build_result.types, package_name)

# Setup api walker
Expand Down Expand Up @@ -91,44 +94,25 @@ def _get_mypy_build(files: list[str]) -> mypy_build.BuildResult:
def _get_mypy_asts(
build_result: mypy_build.BuildResult,
files: list[str],
package_paths: list[Path],
root: Path,
package_paths: list[str],
) -> list[mypy_nodes.MypyFile]:
# Check mypy data key root start
parts = root.parts
graph_keys = list(build_result.graph.keys())
root_start_after = -1
for i in range(len(parts)):
if ".".join(parts[i:]) in graph_keys:
root_start_after = i
break

# Create the keys for getting the corresponding data
packages = [
".".join(
package_path.parts[root_start_after:],
).replace(".py", "")
for package_path in package_paths
]

modules = [
".".join(
Path(file).parts[root_start_after:],
).replace(".py", "")
for file in files
]

# Get the needed data from mypy. The packages need to be checked first, since we have
# to get the reexported data first
all_paths = packages + modules

asts = []
for path_key in all_paths:
tree = build_result.graph[path_key].tree
if tree is not None:
asts.append(tree)

return asts
package_ast = []
module_ast = []
for graph_key in build_result.graph:
ast = build_result.graph[graph_key].tree

if ast is None: # pragma: no cover
raise ValueError

if ast.path.endswith("__init__.py"):
ast_package_path = ast.path.split("__init__.py")[0][:-1]
if ast_package_path in package_paths:
package_ast.append(ast)
elif ast.path in files:
module_ast.append(ast)

# The packages need to be checked first, since we have to get the reexported data first
return package_ast + module_ast


def _get_aliases(result_types: dict, package_name: str) -> dict[str, set[str]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
class AbstractDocstringParser(ABC):
@abstractmethod
def get_class_documentation(self, class_node: nodes.ClassDef) -> ClassDocstring:
pass
pass # pragma: no cover

@abstractmethod
def get_function_documentation(self, function_node: nodes.FuncDef) -> FunctionDocstring:
pass
pass # pragma: no cover

@abstractmethod
def get_parameter_documentation(
Expand All @@ -34,16 +34,16 @@ def get_parameter_documentation(
parameter_assigned_by: ParameterAssignment,
parent_class: Class | None,
) -> ParameterDocstring:
pass
pass # pragma: no cover

@abstractmethod
def get_attribute_documentation(
self,
parent_class: Class,
attribute_name: str,
) -> AttributeDocstring:
pass
pass # pragma: no cover

@abstractmethod
def get_result_documentation(self, function_node: nodes.FuncDef) -> ResultDocstring:
pass
pass # pragma: no cover
3 changes: 2 additions & 1 deletion src/safeds_stubgen/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The entrypoint to the program."""

from __future__ import annotations

import time
Expand All @@ -16,5 +17,5 @@ def main() -> None:
print(f"Program ran in {time.time() - start_time}s") # noqa: T201


if __name__ == "__main__":
if __name__ == "__main__": # pragma: no cover
main()
74 changes: 65 additions & 9 deletions src/safeds_stubgen/stubs_generator/_generate_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def generate_stubs(api: API, out_path: Path, convert_identifiers: bool) -> None:
naming_convention = NamingConvention.SAFE_DS if convert_identifiers else NamingConvention.PYTHON
stubs_generator = StubsStringGenerator(api, naming_convention)
stubs_data = _generate_stubs_data(api, out_path, stubs_generator)
_generate_stubs_files(stubs_data, api, out_path, stubs_generator, naming_convention)
_generate_stubs_files(stubs_data, out_path, stubs_generator, naming_convention)


def _generate_stubs_data(
Expand Down Expand Up @@ -76,13 +76,10 @@ def _generate_stubs_data(

def _generate_stubs_files(
stubs_data: list[tuple[Path, str, str]],
api: API,
out_path: Path,
stubs_generator: StubsStringGenerator,
naming_convention: NamingConvention,
) -> None:
Path(out_path / api.package).mkdir(parents=True, exist_ok=True)

for module_dir, module_name, module_text in stubs_data:
# Create module dir
module_dir.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -179,7 +176,7 @@ def __call__(self, module: Module) -> str:

def _create_module_string(self, module: Module) -> str:
# Create package info
package_info = module.id.replace("/", ".")
package_info = self._get_shortest_public_reexport()
package_info_camel_case = _convert_name_to_convention(package_info, self.naming_convention)
module_name_info = ""
module_text = ""
Expand All @@ -194,7 +191,7 @@ def _create_module_string(self, module: Module) -> str:

# Create classes, class attr. & class methods
for class_ in module.classes:
if class_.is_public:
if class_.is_public and not class_.inherits_from_exception:
module_text += f"\n{self._create_class_string(class_)}\n"

# Create enums & enum instances
Expand Down Expand Up @@ -280,14 +277,18 @@ def _create_class_string(self, class_: Class, class_indentation: str = "") -> st
# Type parameters
constraints_info = ""
variance_info = ""
if class_.type_parameters:

constructor_type_vars = None
if class_.constructor:
constructor_type_vars = class_.constructor.type_var_types

if class_.type_parameters or constructor_type_vars:
# We collect the class generics for the methods later
self.class_generics = []
out = "out "
for variance in class_.type_parameters:
variance_direction = {
VarianceKind.INVARIANT.name: "",
VarianceKind.COVARIANT.name: out,
VarianceKind.COVARIANT.name: "out ",
VarianceKind.CONTRAVARIANT.name: "in ",
}[variance.variance.name]

Expand All @@ -300,6 +301,11 @@ def _create_class_string(self, class_: Class, class_indentation: str = "") -> st
variance_item = f"{variance_item} sub {self._create_type_string(variance.type.to_dict())}"
self.class_generics.append(variance_item)

if constructor_type_vars:
for constructor_type_var in constructor_type_vars:
if constructor_type_var.name not in self.class_generics:
self.class_generics.append(constructor_type_var.name)

if self.class_generics:
variance_info = f"<{', '.join(self.class_generics)}>"

Expand Down Expand Up @@ -729,6 +735,8 @@ def _create_type_string(self, type_data: dict | None) -> str:
if len(types) == 2 and none_type_name in types:
# if None is at least one of the two possible types, we can remove the None and just return the
# other type with a question mark
if types[0] == none_type_name:
return f"{types[1]}?"
return f"{types[0]}?"

# If the union contains only one type, return the type instead of creating a union
Expand Down Expand Up @@ -861,6 +869,54 @@ def _get_class_in_module(self, class_name: str) -> Class:

raise LookupError(f"Expected finding class '{class_name}' in module '{self.module.id}'.") # pragma: no cover

def _get_shortest_public_reexport(self) -> str:
    """Return the dotted package name under which the current module should be emitted.

    Every entry of ``self.api.reexport_map`` whose key mentions the current module
    (``self.module.name``) is inspected. A module listed under such a key is a
    candidate if one of its qualified imports (``qualified_name``) or wildcard
    imports (``module_name``) also mentions the current module. Among the
    candidates, the id with the fewest ``/``-separated parts wins.

    Returns
    -------
    str
        The winning candidate id with ``/`` replaced by ``.``, or the current
        module's own id in dotted form if no candidate exists.
    """
    module_name = self.module.name

    def _module_name_check(name: str, string: str) -> bool:
        # Wrapping both sides in dots makes this a match on whole dotted components.
        # A bare substring test would also accept paths in which `name` is only a
        # prefix of one component and a suffix of another.
        return f".{name}." in f".{string}."

    module_ids: set[str] = set()
    for reexport_key, reexporting_modules in self.api.reexport_map.items():
        if not _module_name_check(module_name, reexport_key):
            continue

        for module in reexporting_modules:
            imported_paths = [qualified_import.qualified_name for qualified_import in module.qualified_imports]
            imported_paths += [wildcard_import.module_name for wildcard_import in module.wildcard_imports]
            if any(_module_name_check(module_name, imported_path) for imported_path in imported_paths):
                module_ids.add(module.id)

    if not module_ids:
        # Not reexported anywhere: fall back to the module's own qualified name
        return self.module.id.replace("/", ".")

    # Fewest path parts wins. Ties are broken alphabetically so that the result does not
    # depend on set iteration order, which can differ between interpreter runs.
    shortest_id = min(module_ids, key=lambda module_id: (module_id.count("/"), module_id))
    return shortest_id.replace("/", ".")


def _callable_type_name_generator() -> Generator:
"""Generate a name for callable type parameters starting from 'a' until 'zz'."""
Expand Down
8 changes: 8 additions & 0 deletions tests/data/various_modules_package/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,22 @@
from ._reexport_module_3 import *
from ._reexport_module_4 import FourthReexportClass
from ._reexport_module_4 import _reexported_function_4
from ._reexport_module_4 import _reexported_function_4_alias as reexported_function_4_alias
from ._reexport_module_4 import _two_times_reexported
from ._reexport_module_4 import _two_times_reexported as two_times_reexported
from .enum_module import _ReexportedEmptyEnum
from file_creation._module_3 import Reexported

__all__ = [
"reex_1",
"ReexportClass",
"reexported_function_2",
"reexported_function_3",
"_reexported_function_4",
"reexported_function_4_alias",
"_two_times_reexported",
"two_times_reexported",
"FourthReexportClass",
"_ReexportedEmptyEnum",
"Reexported",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from tests.data.various_modules_package._reexport_module_5 import reexported_in_an_internal_package
8 changes: 8 additions & 0 deletions tests/data/various_modules_package/_reexport_module_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@ def _unreexported_function() -> None:

# Underscore-prefixed function that the package __init__ imports and lists in
# __all__ under its original name.
def _reexported_function_4() -> None:
pass


# Underscore-prefixed function that the package __init__ imports under the alias
# `reexported_function_4_alias`; only the alias is listed in __all__.
def _reexported_function_4_alias() -> None:
pass


# Imported twice by the package __init__: once under its own name and once as
# `two_times_reexported`. Both names are listed in __all__.
def _two_times_reexported() -> None:
pass
1 change: 1 addition & 0 deletions tests/data/various_modules_package/_reexport_module_5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Unannotated function with an elided body. Another module in this change imports it
# via the absolute path `tests.data.various_modules_package._reexport_module_5`.
def reexported_in_an_internal_package(): ...
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ class AliasingModuleClassC(_some_alias_a):
typed_alias_attr2: AliasModule2
infer_alias_attr2 = AliasModule2

infer_alias_attr3 = _some_alias_a

alias_list: list[_some_alias_a | some_alias_b, AliasModule2, ImportMeAlias]
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# NOTE(review): presumably a test-data module for reexport handling; the file path is
# not shown in this view, so confirm against the commit.

# Function annotated to return `str`, with an elided body.
def reexported_in_another_package_function() -> str: ...


# Class with an empty body.
class ReexportedInAnotherPackageClass:
pass


# Underscore-prefixed, unannotated function with an elided body. The name suggests it
# is expected to stay non-public; verify against the tests that use it.
def _still_internal_function(): ...
Loading

0 comments on commit 97b0ab3

Please sign in to comment.