Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
jmaupetit committed Oct 15, 2024
1 parent 9e4d3be commit ec221f1
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 33 deletions.
6 changes: 3 additions & 3 deletions src/api/qualicharge/auth/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sqlmodel import Field, Relationship, SQLModel

from qualicharge.conf import settings
from qualicharge.schemas import BaseTimestampedSQLModel
from qualicharge.schemas import BaseAuditableSQLModel
from qualicharge.schemas.core import OperationalUnit


Expand Down Expand Up @@ -53,7 +53,7 @@ class ScopesEnum(StrEnum):


# -- Core schemas
class User(BaseTimestampedSQLModel, table=True):
class User(BaseAuditableSQLModel, table=True):
"""QualiCharge User."""

id: UUID = Field(default_factory=uuid4, primary_key=True)
Expand Down Expand Up @@ -103,7 +103,7 @@ def check_password(self, password: str) -> bool:
return settings.PASSWORD_CONTEXT.verify(password, self.password)


class Group(BaseTimestampedSQLModel, table=True):
class Group(BaseAuditableSQLModel, table=True):
"""QualiCharge Group."""

id: UUID = Field(default_factory=uuid4, primary_key=True)
Expand Down
79 changes: 74 additions & 5 deletions src/api/qualicharge/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,35 @@
"""QualiCharge schemas."""

from datetime import datetime, timezone
from typing import Any
from enum import StrEnum
from typing import Any, Optional
from uuid import UUID

from pydantic import PastDatetime
from sqlalchemy import CheckConstraint
from sqlalchemy.types import DateTime
from sqlalchemy import CheckConstraint, insert, inspect
from sqlalchemy.types import DateTime, JSON
from sqlmodel import Field, SQLModel


class BaseTimestampedSQLModel(SQLModel):
"""A base class for SQL models with timestamp fields.
class AuditableFieldBlackListEnum(StrEnum):
"""Fields black listed for auditability."""

PASSWORD = "password"


class BaseAuditableSQLModel(SQLModel):
"""A base class for auditable SQL models.
This class provides two timestamp fields, `created_at` and `updated_at`, which are
automatically managed. The `created_at` field is set to the current UTC time when
a new record is created, and the `updated_at` field is updated to the current UTC
time whenever the record is modified.
The two `created_by_id`, `updated_by_id` foreign keys points to the auth.User model.
Both keys are optional and need to be explicitly set by your code.
To fully track changes, you need to connect the `track_model_changes` utility (see
below) to SQLAlchemy events.
"""

__table_args__: Any = (
Expand All @@ -33,3 +47,58 @@ class BaseTimestampedSQLModel(SQLModel):
default_factory=lambda: datetime.now(timezone.utc),
description="The timestamp indicating when the record was last updated.",
) # type: ignore
created_by_id: Optional[UUID] = Field(default=None, foreign_key="user.id")
updated_by_id: Optional[UUID] = Field(default=None, foreign_key="user.id")


class Audit(SQLModel, table=True):
"""Model changes record for auditability."""

table: str
author_id: UUID = Field(default=None, foreign_key="user.id")
updated_at: PastDatetime = Field(sa_type=DateTime(timezone=True))
changes: dict = Field(sa_type=JSON)


def track_model_changes(mapper, connection, target):
"""Track model changes for auditability.
Models to track are supposed to inherit from the `BaseAuditableSQLModel`. You
should listen to "after_update" events on your model to add full auditability
support, e.g.:
```python
from sqlalchemy import event
from myapp.schemas import MyModel
from ..schemas.audit import track_model_changes
event.listen(MyModel, "after_update", track_model_changes)
```
For each changed field, the previous value is stored along with the modification
date and the author. For fields with sensitive information (_e.g._ passwords or
tokens), a null value is stored.
"""
state = inspect(target)

# Get changes
changes = {}
for attr in state.attrs:
if attr.key in AuditableFieldBlackListEnum:
continue
history = attr.load_history()
if not history.has_changes():
continue
changes[attr.key] = [history.deleted, history.added]

# Log changes
connection.execute(
insert(Audit).values(
table="",
author_id=target.updated_by_id,
updated_at=target.updated_at,
changes=changes,
)
)
28 changes: 14 additions & 14 deletions src/api/qualicharge/schemas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
ImplantationStationEnum,
RaccordementEnum,
)
from . import BaseTimestampedSQLModel
from . import BaseAuditableSQLModel

if TYPE_CHECKING:
from qualicharge.auth.schemas import Group
Expand All @@ -50,10 +50,10 @@ class OperationalUnitTypeEnum(IntEnum):
MOBILITY = 2


class Amenageur(BaseTimestampedSQLModel, table=True):
class Amenageur(BaseAuditableSQLModel, table=True):
"""Amenageur table."""

__table_args__ = BaseTimestampedSQLModel.__table_args__ + (
__table_args__ = BaseAuditableSQLModel.__table_args__ + (
UniqueConstraint("nom_amenageur", "siren_amenageur", "contact_amenageur"),
)

Expand All @@ -73,10 +73,10 @@ def __eq__(self, other) -> bool:
return all(getattr(self, field) == getattr(other, field) for field in fields)


class Operateur(BaseTimestampedSQLModel, table=True):
class Operateur(BaseAuditableSQLModel, table=True):
"""Operateur table."""

__table_args__ = BaseTimestampedSQLModel.__table_args__ + (
__table_args__ = BaseAuditableSQLModel.__table_args__ + (
UniqueConstraint("nom_operateur", "contact_operateur", "telephone_operateur"),
)

Expand All @@ -96,7 +96,7 @@ def __eq__(self, other) -> bool:
return all(getattr(self, field) == getattr(other, field) for field in fields)


class Enseigne(BaseTimestampedSQLModel, table=True):
class Enseigne(BaseAuditableSQLModel, table=True):
"""Enseigne table."""

model_config = SQLModelConfig(validate_assignment=True)
Expand All @@ -113,7 +113,7 @@ def __eq__(self, other) -> bool:
return all(getattr(self, field) == getattr(other, field) for field in fields)


class Localisation(BaseTimestampedSQLModel, table=True):
class Localisation(BaseAuditableSQLModel, table=True):
"""Localisation table."""

model_config = SQLModelConfig(
Expand Down Expand Up @@ -178,7 +178,7 @@ def serialize_coordonneesXY(
return self._wkb_to_coordinates(value)


class OperationalUnit(BaseTimestampedSQLModel, table=True):
class OperationalUnit(BaseAuditableSQLModel, table=True):
"""OperationalUnit table."""

id: UUID = Field(default_factory=uuid4, primary_key=True)
Expand Down Expand Up @@ -217,7 +217,7 @@ def create_stations_fk(self, session: SMSession):
session.commit()


class Station(BaseTimestampedSQLModel, table=True):
class Station(BaseAuditableSQLModel, table=True):
"""Station table."""

model_config = SQLModelConfig(validate_assignment=True)
Expand Down Expand Up @@ -281,7 +281,7 @@ def link_station_to_operational_unit(mapper, connection, target):
target.operational_unit_id = operational_unit.id


class PointDeCharge(BaseTimestampedSQLModel, table=True):
class PointDeCharge(BaseAuditableSQLModel, table=True):
"""PointDeCharge table."""

model_config = SQLModelConfig(validate_assignment=True)
Expand Down Expand Up @@ -322,10 +322,10 @@ def __eq__(self, other) -> bool:
statuses: List["Status"] = Relationship(back_populates="point_de_charge")


class Session(BaseTimestampedSQLModel, SessionBase, table=True):
class Session(BaseAuditableSQLModel, SessionBase, table=True):
"""IRVE recharge session."""

__table_args__ = BaseTimestampedSQLModel.__table_args__ + (
__table_args__ = BaseAuditableSQLModel.__table_args__ + (
{"timescaledb_hypertable": {"time_column_name": "start"}},
)

Expand All @@ -346,10 +346,10 @@ class Session(BaseTimestampedSQLModel, SessionBase, table=True):
point_de_charge: PointDeCharge = Relationship(back_populates="sessions")


class Status(BaseTimestampedSQLModel, StatusBase, table=True):
class Status(BaseAuditableSQLModel, StatusBase, table=True):
"""IRVE recharge session."""

__table_args__ = BaseTimestampedSQLModel.__table_args__ + (
__table_args__ = BaseAuditableSQLModel.__table_args__ + (
{"timescaledb_hypertable": {"time_column_name": "horodatage"}},
)

Expand Down
4 changes: 2 additions & 2 deletions src/api/qualicharge/schemas/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from sqlmodel import Field, Relationship
from sqlmodel.main import SQLModelConfig

from . import BaseTimestampedSQLModel
from . import BaseAuditableSQLModel


class BaseAdministrativeBoundaries(BaseTimestampedSQLModel):
class BaseAdministrativeBoundaries(BaseAuditableSQLModel):
"""Base administrative boundaries model."""

model_config = SQLModelConfig(
Expand Down
18 changes: 9 additions & 9 deletions src/api/qualicharge/schemas/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ..exceptions import ObjectDoesNotExist, ProgrammingError
from ..models.static import Statique
from . import BaseTimestampedSQLModel
from . import BaseAuditableSQLModel
from .core import (
AccessibilitePMREnum,
Amenageur,
Expand All @@ -45,7 +45,7 @@ def __init__(self, df: pd.DataFrame, connection: Connection):

self._statique: pd.DataFrame = self._fix_enums(df)
self._statique_with_fk: pd.DataFrame = self._statique.copy()
self._saved_schemas: list[type[BaseTimestampedSQLModel]] = []
self._saved_schemas: list[type[BaseAuditableSQLModel]] = []

self._amenageur: Optional[pd.DataFrame] = None
self._enseigne: Optional[pd.DataFrame] = None
Expand All @@ -64,20 +64,20 @@ def __len__(self):

@staticmethod
def _add_timestamped_model_fields(df: pd.DataFrame):
"""Add required fields for a BaseTimestampedSQLModel."""
"""Add required fields for a BaseAuditableSQLModel."""
df["id"] = df.apply(lambda x: uuid.uuid4(), axis=1)
now = pd.Timestamp.now(tz="utc")
df["created_at"] = now
df["updated_at"] = now
return df

@staticmethod
def _schema_fk(schema: type[BaseTimestampedSQLModel]) -> str:
def _schema_fk(schema: type[BaseAuditableSQLModel]) -> str:
"""Get expected schema foreign key name."""
return f"{schema.__table__.name}_id" # type: ignore[attr-defined]

@staticmethod
def _get_schema_fks(schema: type[BaseTimestampedSQLModel]) -> list[str]:
def _get_schema_fks(schema: type[BaseAuditableSQLModel]) -> list[str]:
"""Get foreign key field names from a schema."""
return [
fk.parent.name
Expand All @@ -103,7 +103,7 @@ def _fix_enums(self, df: pd.DataFrame) -> pd.DataFrame:
return df.replace(to_replace=src, value=target)

def _get_fields_for_schema(
self, schema: type[BaseTimestampedSQLModel], with_fk: bool = False
self, schema: type[BaseAuditableSQLModel], with_fk: bool = False
) -> list[str]:
"""Get Statique fields from a core schema."""
fields = list(
Expand All @@ -115,7 +115,7 @@ def _get_fields_for_schema(

def _get_dataframe_for_schema(
self,
schema: type[BaseTimestampedSQLModel],
schema: type[BaseAuditableSQLModel],
subset: Optional[str] = None,
with_fk: bool = False,
):
Expand All @@ -127,7 +127,7 @@ def _get_dataframe_for_schema(
return df

def _add_fk_from_saved_schema(
self, saved: pd.DataFrame, schema: type[BaseTimestampedSQLModel]
self, saved: pd.DataFrame, schema: type[BaseAuditableSQLModel]
):
"""Add foreign keys to the statique DataFrame using saved schema."""
fields = self._get_fields_for_schema(schema)
Expand Down Expand Up @@ -214,7 +214,7 @@ def station(self) -> pd.DataFrame:
def _save_schema(
self,
df: pd.DataFrame,
schema: type[BaseTimestampedSQLModel],
schema: type[BaseAuditableSQLModel],
constraint: Optional[str] = None,
index_elements: Optional[list[str]] = None,
chunksize: int = 1000,
Expand Down

0 comments on commit ec221f1

Please sign in to comment.