Skip to content

Commit

Permalink
Refactoring: xml parsing bound to CustomElement instead of SVG
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Panchenko committed May 29, 2024
1 parent 711d388 commit 2bde1bb
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 78 deletions.
86 changes: 9 additions & 77 deletions src/penai/svg.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from collections import defaultdict
from copy import deepcopy
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any, Self, overload
from typing import TYPE_CHECKING, Any, Self

from lxml import etree
from pptree import print_tree

from penai.types import PathLike, RecursiveStrDict
from penai.utils.dict import apply_func_to_nested_keys
from penai.utils.svg import strip_penpot_from_tree, url_to_data_uri, validate_uri
from penai.xml import BetterElement

_CustomElementBaseAnnotationClass: Any = object
if TYPE_CHECKING:
Expand All @@ -23,67 +23,8 @@
_CustomElementBaseAnnotationClass = XMLElement


class BetterElement(etree.ElementBase):
    """lxml element whose query methods default to the element's own namespace map.

    ``find``, ``findall`` and ``xpath`` fall back to a map derived from ``self.nsmap``
    whenever the caller passes no (or an empty) ``namespaces`` argument.
    """

    @cached_property
    def query_compatible_nsmap(self) -> dict[str, str]:
        """Return a copy of ``self.nsmap`` with the default namespace keyed by ``""``.

        ``self.nsmap`` keeps the default namespace under the key ``None``; here it is
        moved to the empty-string key. The result is cached on first access.
        """
        nsmap = dict(self.nsmap)
        # Not every element has a default namespace; an unconditional pop(None)
        # would raise KeyError for those.
        if None in nsmap:
            nsmap[""] = nsmap.pop(None)
        return nsmap

    def find(self, path: str, namespaces: dict[str, str] | None = None) -> etree._Element | None:
        """Return the first match for ``path``, or ``None`` if nothing matches.

        :param path: ElementPath expression.
        :param namespaces: prefix -> URI map; when omitted or empty,
            ``query_compatible_nsmap`` is used instead.
        """
        return super().find(path, namespaces=namespaces or self.query_compatible_nsmap)

    def findall(self, path: str, namespaces: dict[str, str] | None = None) -> list[etree._Element]:
        """Return all matches for ``path``.

        :param path: ElementPath expression.
        :param namespaces: prefix -> URI map; when omitted or empty,
            ``query_compatible_nsmap`` is used instead.
        """
        return super().findall(path, namespaces=namespaces or self.query_compatible_nsmap)

    def xpath(
        self,
        path: str,
        namespaces: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> list[etree._Element]:
        """Evaluate the XPath expression ``path`` with this element as context node.

        :param path: XPath expression.
        :param namespaces: prefix -> URI map; when omitted or empty, the prefixed
            entries of ``query_compatible_nsmap`` are used instead.
        :param kwargs: forwarded unchanged to lxml's ``xpath``.
        """
        if not namespaces:
            # lxml's XPath raises TypeError for empty namespace prefixes, so the
            # default-namespace entry must be left out of the fallback map.
            # A filtered copy is built so the cached map is never mutated.
            namespaces = {
                prefix: uri for prefix, uri in self.query_compatible_nsmap.items() if prefix
            }
        return super().xpath(path, namespaces=namespaces, **kwargs)

    @overload
    def get_namespaced_key(self, key: str) -> str:
        ...

    @overload
    def get_namespaced_key(self, namespace: str, key: str) -> str:
        ...

    # Note: Is there a better way to handle the overload here?
    def get_namespaced_key(self, arg1: str, arg2: str | None = None) -> str:  # type: ignore[misc]
        """Return an XML key (tag or attribute name) qualified with its namespace URI.

        Called with one argument, that argument is the key and the default
        namespace (``None`` in ``self.nsmap``) is looked up. Called with two
        arguments, they are ``(namespace, key)``.

        The key is returned unchanged only when no namespace is given and
        ``self.nsmap`` is empty. Otherwise the result has the form ``{uri}key``.

        :raises ValueError: if the namespace has no (non-empty) entry in ``self.nsmap``.
        """
        if arg2 is None:
            key = arg1
            namespace = None
        else:
            namespace, key = arg1, arg2

        if not self.nsmap and namespace is None:
            return key

        if not (namespace_uri := self.nsmap.get(namespace)):
            raise ValueError(
                f"No namespace with name {namespace}. Known namespaces are {list(self.nsmap)}",
            )
        return f"{{{namespace_uri}}}{key}"

    @cached_property
    def localname(self) -> str:
        """The element's tag name without namespace; cached on first access."""
        return etree.QName(self).localname


class SVG:
"""A simple wrapper around the `ElementTree` class for now.
"""A simple wrapper around an `ElementTree` that is based on `BetterElement` as nodes in the tree.
In the long-term (lol never), we might extend this a full-fledged SVG implementation.
"""
Expand All @@ -95,7 +36,7 @@ def __init__(self, dom: etree.ElementTree):
def from_root_element(
cls,
element: BetterElement,
namespace_map: dict | None = None,
nsmap: dict | None = None,
svg_attribs: dict[str, str] | None = None,
) -> Self:
"""Create an SVG object from a given root element.
Expand All @@ -107,14 +48,14 @@ def from_root_element(
if not isinstance(element, BetterElement):
raise TypeError(f"Expected an BetterElement, got {type(element)}")

namespace_map = namespace_map or dict(element.nsmap)
namespace_map = {None: "http://www.w3.org/2000/svg", **namespace_map}
nsmap = nsmap or dict(element.nsmap)
nsmap = {None: "http://www.w3.org/2000/svg", **nsmap}

# As recommended in https://lxml.de/tutorial.html, create a deep copy of the element.
element = deepcopy(element)

if element.localname != "svg":
root = BetterElement("svg", nsmap=namespace_map)
root = BetterElement("svg", nsmap=nsmap)
root.append(element)
else:
root = element
Expand All @@ -124,22 +65,13 @@ def from_root_element(

return cls(etree.ElementTree(root))

@staticmethod
def get_parser() -> etree.XMLParser:
parser_lookup = etree.ElementDefaultClassLookup(element=BetterElement)
parser = etree.XMLParser()
parser.set_element_class_lookup(parser_lookup)
return parser

@classmethod
def from_file(cls, path: PathLike) -> Self:
parser = cls.get_parser()
return cls(etree.parse(path, parser))
return cls(dom=BetterElement.parse_file(path))

@classmethod
def from_string(cls, string: str) -> Self:
parser = cls.get_parser()
return cls(etree.ElementTree(etree.fromstring(string, parser)))
return cls(dom=BetterElement.parse_string(string))

def strip_penpot_tags(self) -> None:
"""Strip all Penpot-specific nodes from the SVG tree.
Expand Down
122 changes: 122 additions & 0 deletions src/penai/xml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, Self, overload

from lxml import etree
from lxml.etree import ElementBase as Element
from lxml.etree import ElementTree
from overrides import override

from penai.types import PathLike

if TYPE_CHECKING:
# Trick to get proper type hints and docstrings for lxml's etree.
# lxml.etree has (largely) the same API as the standard library's xml.etree, but the latter is written in Python
# and is properly typed and documented, whereas the former is much faster but only a stub as far as tooling is
# concerned. So at type-checking time, we use the standard-library version of etree.
from xml import etree
from xml.etree.ElementTree import Element, ElementTree


class CustomElement(Element):
    """Base class for element types that should be produced directly by the XML parser.

    Parsers obtained from `get_parser` (and therefore `parse_file` / `parse_string`) build
    trees whose nodes are instances of the class the method was called on. As a consequence,
    the query methods of such trees return (collections of) elements of that same class.
    """

    nsmap: dict

    @classmethod
    def get_parser(cls) -> etree.XMLParser:
        """Create a new XML parser that instantiates `cls` for every element."""
        xml_parser = etree.XMLParser()
        xml_parser.set_element_class_lookup(etree.ElementDefaultClassLookup(element=cls))
        return xml_parser

    @classmethod
    def parse_file(cls, path: PathLike) -> ElementTree:
        """Parse the XML file at `path` into an ElementTree made of `cls` elements."""
        return etree.parse(path, cls.get_parser())

    @classmethod
    def parse_string(cls, string: str) -> ElementTree:
        """Parse XML given as a string into an ElementTree made of `cls` elements."""
        root = etree.fromstring(string, cls.get_parser())
        return etree.ElementTree(root)

    # The overrides below only delegate to the base class; they exist so that the
    # return types are expressed in terms of the custom element class.

    def find(self, path: str, namespaces: dict[str, str] | None = None) -> Self | None:
        """Return the first element matching `path`, or None."""
        return super().find(path, namespaces=namespaces)

    def findall(self, path: str, namespaces: dict[str, str] | None = None) -> list[Self]:
        """Return all elements matching `path`."""
        return super().findall(path, namespaces=namespaces)

    def xpath(
        self,
        path: str,
        namespaces: dict[str, str] | None = None,
        **kwargs: dict[str, Any],
    ) -> list[Self]:
        """Evaluate the XPath expression `path`; extra keyword arguments are passed through."""
        return super().xpath(path, namespaces=namespaces, **kwargs)


class BetterElement(CustomElement):
    """Simplifies handling of namespaces in ElementTree.

    ``find``, ``findall`` and ``xpath`` fall back to a map derived from ``self.nsmap``
    whenever the caller passes no (or an empty) ``namespaces`` argument.
    """

    @cached_property
    def query_compatible_nsmap(self) -> dict[str, str]:
        """Return a copy of ``self.nsmap`` with the default namespace keyed by ``""``.

        ``self.nsmap`` keeps the default namespace under the key ``None``; here it is
        moved to the empty-string key. The result is cached on first access.
        """
        nsmap = dict(self.nsmap)
        # Not every element has a default namespace; an unconditional pop(None)
        # would raise KeyError for those.
        if None in nsmap:
            nsmap[""] = nsmap.pop(None)
        return nsmap

    @override
    def find(self, path: str, namespaces: dict[str, str] | None = None) -> Self | None:
        """Return the first match for ``path``, or ``None`` if nothing matches.

        :param path: ElementPath expression.
        :param namespaces: prefix -> URI map; when omitted or empty,
            ``query_compatible_nsmap`` is used instead.
        """
        return super().find(path, namespaces=namespaces or self.query_compatible_nsmap)

    @override
    def findall(self, path: str, namespaces: dict[str, str] | None = None) -> list[Self]:
        """Return all matches for ``path``.

        :param path: ElementPath expression.
        :param namespaces: prefix -> URI map; when omitted or empty,
            ``query_compatible_nsmap`` is used instead.
        """
        return super().findall(path, namespaces=namespaces or self.query_compatible_nsmap)

    @override
    def xpath(
        self,
        path: str,
        namespaces: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> list[Self]:
        """Evaluate the XPath expression ``path`` with this element as context node.

        :param path: XPath expression.
        :param namespaces: prefix -> URI map; when omitted or empty, the prefixed
            entries of ``query_compatible_nsmap`` are used instead.
        :param kwargs: forwarded unchanged to the parent ``xpath``.
        """
        if not namespaces:
            # lxml's XPath raises TypeError for empty namespace prefixes, so the
            # default-namespace entry must be left out of the fallback map.
            # A filtered copy is built so the cached map is never mutated.
            namespaces = {
                prefix: uri for prefix, uri in self.query_compatible_nsmap.items() if prefix
            }
        return super().xpath(path, namespaces=namespaces, **kwargs)

    @overload
    def get_namespaced_key(self, key: str) -> str:
        ...

    @overload
    def get_namespaced_key(self, namespace: str, key: str) -> str:
        ...

    # Note: Is there a better way to handle the overload here?
    def get_namespaced_key(self, arg1: str, arg2: str | None = None) -> str:  # type: ignore[misc]
        """Return an XML key (tag or attribute name) qualified with its namespace URI.

        Called with one argument, that argument is the key and the default
        namespace (``None`` in ``self.nsmap``) is looked up. Called with two
        arguments, they are ``(namespace, key)``.

        The key is returned unchanged only when no namespace is given and
        ``self.nsmap`` is empty. Otherwise the result has the form ``{uri}key``.

        :raises ValueError: if the namespace has no (non-empty) entry in ``self.nsmap``.
        """
        if arg2 is None:
            key = arg1
            namespace = None
        else:
            namespace, key = arg1, arg2

        if not self.nsmap and namespace is None:
            return key

        if not (namespace_uri := self.nsmap.get(namespace)):
            raise ValueError(
                f"No namespace with name {namespace}. Known namespaces are {list(self.nsmap)}",
            )
        return f"{{{namespace_uri}}}{key}"

    @cached_property
    def localname(self) -> str:
        """The element's tag name without namespace; cached on first access."""
        return etree.QName(self).localname
2 changes: 1 addition & 1 deletion test/penai/test_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _test_svg_renderer(
)


class TestChromeSVGRenderers:
class TestSVGRenderers:
@pytest.mark.parametrize("renderer", [ChromeSVGRenderer(), ResvgRenderer()])
def test_rendering_works(
self,
Expand Down

0 comments on commit 2bde1bb

Please sign in to comment.