Skip to content

Commit

Permalink
query and find methods + proper loading of templates from db (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
viseshrp authored Sep 13, 2024
1 parent 36db6dd commit 95d63de
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 68 deletions.
27 changes: 18 additions & 9 deletions src/ansys/dynamicreporting/core/serverless/adr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path
import platform
import sys
from typing import Any, Optional, Type
from typing import Any, Optional, Type, Union

import django
from django.core import management
Expand Down Expand Up @@ -32,6 +32,7 @@ def __init__(
opts: dict = None,
request: HttpRequest = None,
logfile: str = None,
debug: bool = None,
) -> None:
self._db_directory = None
self._media_directory = None
Expand Down Expand Up @@ -68,6 +69,13 @@ def __init__(
if "CEI_NEXUS_LOCAL_STATIC_DIR" in os.environ:
self._static_directory = self._check_dir(os.environ["CEI_NEXUS_LOCAL_STATIC_DIR"])

if debug is not None:
self._debug = debug
os.environ["CEI_NEXUS_DEBUG"] = str(int(debug))
else:
if "CEI_NEXUS_DEBUG" in os.environ:
self._debug = bool(int(os.environ["CEI_NEXUS_DEBUG"]))

self._request = request # passed when used in the context of a webserver.
self._session = None
self._dataset = None
Expand Down Expand Up @@ -243,11 +251,12 @@ def render_report(self, context=None, query=None, **kwargs):
self._logger.error(f"{e}")
raise e

def query(self, query_type: str = Item, filter: Optional[str] = "") -> list:
...

def create(self, objects: list) -> None:
...

def delete(self, objects: list) -> None:
...
def query(
self,
query_type: Union[Session, Dataset, Type[Item], Type[Template]],
filter: Optional[str] = "",
) -> list:
if not issubclass(query_type, (Item, Template, Session, Dataset)):
self._logger.error(f"{query_type} is not valid")
raise TypeError(f"{query_type} is not valid")
return list(query_type.find(query=filter))
109 changes: 92 additions & 17 deletions src/ansys/dynamicreporting/core/serverless/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from abc import ABC, ABCMeta, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass, field
from dataclasses import fields as dataclass_fields
import importlib
import inspect
from itertools import chain
import shlex
from typing import Any
from typing import Any, get_args, get_origin
import uuid
from uuid import UUID

Expand All @@ -19,6 +20,7 @@
from django.db import DatabaseError
from django.db.models import Model, QuerySet
from django.db.models.base import subclass_exception
from django.db.models.manager import Manager

from ..exceptions import (
ADRException,
Expand Down Expand Up @@ -46,6 +48,10 @@ def wrapper(*args, **kwargs):
return wrapper


def is_generic_class(cls):
return not isinstance(cls, type) or get_origin(cls) is not None


class BaseMeta(ABCMeta):
_cls_registry: dict[str, type["BaseModel"]] = {}
_model_cls_registry: dict[str, type[Model]] = {}
Expand Down Expand Up @@ -138,15 +144,39 @@ def __post_init__(self):

def _validate_field_types(self):
for field_name, field_type in self._get_field_names(with_types=True):
value = getattr(self, field_name, None)
if value is None:
continue
# Type inference
# convert strings to classes
if isinstance(field_type, str):
type_ = self.__class__._cls_registry[field_type]
type_cls = self.__class__._cls_registry[field_type]
else:
type_ = field_type
if issubclass(type_, Validator):
type_cls = field_type
# Validators will validate by themselves, so this can be ignored.
# Will only work when the type is a proper class
if not is_generic_class(type_cls) and issubclass(type_cls, Validator):
continue
value = getattr(self, field_name, None)
if value is not None and not isinstance(value, type_):
raise TypeError(f"Expected {field_name} to be of type {type_}.")
# 'Generic' class types
if get_origin(type_cls) is not None:
# get any args
args = get_args(type_cls)
# update with the origin type
type_cls = get_origin(type_cls)
# validate with the 'arg' type:
# eg: 'Template' in list['Template']
if args:
content_type = args[0]
if isinstance(content_type, str):
content_type = self.__class__._cls_registry[content_type]
if isinstance(value, Iterable):
for elem in value:
if not isinstance(elem, content_type):
raise TypeError(
f"Expected '{field_name}' to contain items of type '{content_type}'."
)
if not isinstance(value, type_cls):
raise TypeError(f"Expected '{field_name}' to be of type '{type_cls}'.")

@staticmethod
def _add_quotes(input_str):
Expand Down Expand Up @@ -189,7 +219,7 @@ def _get_all_field_names(cls):
return tuple(property_fields) + cls._get_field_names()

@classmethod
def serialize_from_orm(cls, orm_instance):
def from_db(cls, orm_instance, parent=None):
cls_fields = dict(cls._get_field_names(with_types=True, include_private=True))
model_fields = cls._get_orm_field_names(orm_instance)
obj = cls()
Expand All @@ -201,12 +231,49 @@ def serialize_from_orm(cls, orm_instance):
attr = f"_{field_}"
else:
continue
value = getattr(orm_instance, field_, None)
# don't check for None here, we need everything as-is
value = getattr(orm_instance, field_, None)
field_type = cls_fields[attr]
# We must also serialize 'related' fields
if isinstance(value, Model):
type_ = cls_fields[attr]
value = type_.serialize_from_orm(value)
# convert the value to a type supported by the proxy
# for string definitions of the dataclass type, example - parent: 'Template'
if isinstance(field_type, str):
type_ = cls._cls_registry[field_type]
else:
type_ = field_type
if issubclass(type_, cls):
value = parent
else:
value = type_.from_db(value)
elif isinstance(value, Manager):
type_ = get_origin(field_type)
args = get_args(field_type)
if type_ is None or not issubclass(type_, Iterable) or len(args) != 1:
raise TypeError(
f"The field '{attr}' in the dataclass must be a generic iterable"
f" class containing exactly one type argument. For example: "
f"list['Template'] or tuple['Template']."
)
content_type = args[0]
if isinstance(content_type, str):
content_type = cls._cls_registry[content_type]
qs = value.all()
# content_type must match orm model class
if content_type._orm_model_cls != qs.model:
raise TypeError(
f"The field '{attr}' is of '{field_type}' but the "
f"actual content is of type '{qs.model}'"
)
if qs:
obj_set = ObjectSet(
_model=content_type, _orm_model=qs.model, _orm_queryset=qs, _parent=obj
)
value = type_(obj_set)
else:
value = type_()

# set the orm value on the proxy object
setattr(obj, attr, value)

obj._orm_instance = orm_instance
Expand Down Expand Up @@ -262,14 +329,20 @@ def get(cls, **kwargs):
except MultipleObjectsReturned:
raise cls.MultipleObjectsReturned

return cls.serialize_from_orm(orm_instance)
return cls.from_db(orm_instance)

@classmethod
@handle_field_errors
def filter(cls, **kwargs):
qs = cls._orm_model_cls.objects.filter(**kwargs)
return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs)

@classmethod
@handle_field_errors
def find(cls, query="", reverse=False, sort_tag="date"):
qs = cls._orm_model_cls.find(query=query, reverse=reverse, sort_tag=sort_tag)
return ObjectSet(_model=cls, _orm_model=cls._orm_model_cls, _orm_queryset=qs)

def get_tags(self):
return self.tags

Expand Down Expand Up @@ -303,13 +376,15 @@ class ObjectSet:
_saved: bool = field(init=False, compare=False, default=False)
_orm_model: type[Model] = field(compare=False, default=None)
_orm_queryset: QuerySet = field(compare=False, default=None)
_parent: BaseModel = field(compare=False, default=None)

def __post_init__(self):
if self._orm_queryset is not None:
self._saved = True
self._obj_set = [
self._model.serialize_from_orm(instance) for instance in self._orm_queryset
]
if self._orm_queryset is None:
return
self._saved = True
self._obj_set = [
self._model.from_db(instance, parent=self._parent) for instance in self._orm_queryset
]

def __repr__(self):
return f"<{self.__class__.__name__} {self._obj_set}>"
Expand Down
35 changes: 23 additions & 12 deletions src/ansys/dynamicreporting/core/serverless/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def process(self, value, obj):

# check file type
file_ext = Path(file.name).suffix.lower()
if file_ext.replace(".", "") not in self.ALLOWED_EXT:
if self.ALLOWED_EXT is not None and file_ext.replace(".", "") not in self.ALLOWED_EXT:
raise ValueError(f"File type {file_ext} is not supported by {obj.__class__}")
# check for empty files
if file.size == 0:
Expand Down Expand Up @@ -169,15 +169,15 @@ class SceneContent(FileValidator):


class FileContent(FileValidator):
ALLOWED_EXT = ("ens", "enc", "evsn")
ALLOWED_EXT = None


class SimplePayloadMixin:
@classmethod
def serialize_from_orm(cls, orm_instance):
def from_db(cls, orm_instance, **kwargs):
from data.extremely_ugly_hacks import safe_unpickle

obj = super().serialize_from_orm(orm_instance)
obj = super().from_db(orm_instance)
obj.content = safe_unpickle(obj._orm_instance.payloaddata)
return obj

Expand All @@ -190,8 +190,8 @@ class FilePayloadMixin:
_file: DjangoFile = field(init=False, compare=False, default=None)

@classmethod
def serialize_from_orm(cls, orm_instance):
obj = super().serialize_from_orm(orm_instance)
def from_db(cls, orm_instance, **kwargs):
obj = super().from_db(orm_instance)
obj.content = obj._orm_instance.payloadfile.path
return obj

Expand All @@ -205,7 +205,6 @@ def save(self, **kwargs):
super().save(**kwargs)


# todo: prevent instantiation
class Item(BaseModel):
name: str = field(compare=False, kw_only=True, default="")
date: datetime = field(compare=False, kw_only=True, default_factory=timezone.now)
Expand Down Expand Up @@ -237,15 +236,27 @@ def delete(self, **kwargs):
delete_item_media(self._orm_instance.guid)
return super().delete(**kwargs)

@classmethod
def get(cls, **kwargs):
new_kwargs = {"type": cls.type, **kwargs} if cls.type != "none" else kwargs
return super().get(**new_kwargs)

@classmethod
def filter(cls, **kwargs):
new_kwargs = {"type": cls.type, **kwargs} if cls.type != "none" else kwargs
return super().filter(**new_kwargs)

@classmethod
def get(cls, **kwargs):
new_kwargs = {"type": cls.type, **kwargs} if cls.type != "none" else kwargs
return super().get(**new_kwargs)
def find(cls, **kwargs):
if cls.type == "none":
return super().find(**kwargs)
query = kwargs.pop("query", "")
if "i_type|cont" in query:
raise ADRException(
extra_detail="The 'i_type' filter is not required if using a subclass of Item"
)
new_kwargs = {**kwargs, "query": f"A|i_type|cont|{cls.type};{query}"}
return super().find(**new_kwargs)

def render(self, context=None, request=None) -> Optional[str]:
if context is None:
Expand Down Expand Up @@ -281,10 +292,10 @@ class Table(Item):
_properties: tuple = table_attr

@classmethod
def serialize_from_orm(cls, orm_instance):
def from_db(cls, orm_instance, **kwargs):
from data.extremely_ugly_hacks import safe_unpickle

obj = super().serialize_from_orm(orm_instance)
obj = super().from_db(orm_instance)
payload = safe_unpickle(obj._orm_instance.payloaddata)
obj.content = payload.pop("array", None)
for prop in cls._properties:
Expand Down
Loading

0 comments on commit 95d63de

Please sign in to comment.