From ec221f1948cb864077b238515843adfbb2358704 Mon Sep 17 00:00:00 2001 From: Julien Maupetit Date: Tue, 15 Oct 2024 18:44:28 +0200 Subject: [PATCH] WIP --- src/api/qualicharge/auth/schemas.py | 6 +- src/api/qualicharge/schemas/__init__.py | 79 +++++++++++++++++++++++-- src/api/qualicharge/schemas/core.py | 28 ++++----- src/api/qualicharge/schemas/geo.py | 4 +- src/api/qualicharge/schemas/sql.py | 18 +++--- 5 files changed, 102 insertions(+), 33 deletions(-) diff --git a/src/api/qualicharge/auth/schemas.py b/src/api/qualicharge/auth/schemas.py index a72d7a36..cfbaf31a 100644 --- a/src/api/qualicharge/auth/schemas.py +++ b/src/api/qualicharge/auth/schemas.py @@ -9,7 +9,7 @@ from sqlmodel import Field, Relationship, SQLModel from qualicharge.conf import settings -from qualicharge.schemas import BaseTimestampedSQLModel +from qualicharge.schemas import BaseAuditableSQLModel from qualicharge.schemas.core import OperationalUnit @@ -53,7 +53,7 @@ class ScopesEnum(StrEnum): # -- Core schemas -class User(BaseTimestampedSQLModel, table=True): +class User(BaseAuditableSQLModel, table=True): """QualiCharge User.""" id: UUID = Field(default_factory=uuid4, primary_key=True) @@ -103,7 +103,7 @@ def check_password(self, password: str) -> bool: return settings.PASSWORD_CONTEXT.verify(password, self.password) -class Group(BaseTimestampedSQLModel, table=True): +class Group(BaseAuditableSQLModel, table=True): """QualiCharge Group.""" id: UUID = Field(default_factory=uuid4, primary_key=True) diff --git a/src/api/qualicharge/schemas/__init__.py b/src/api/qualicharge/schemas/__init__.py index 4290e078..a014d2f4 100644 --- a/src/api/qualicharge/schemas/__init__.py +++ b/src/api/qualicharge/schemas/__init__.py @@ -1,21 +1,35 @@ """QualiCharge schemas.""" from datetime import datetime, timezone -from typing import Any +from enum import StrEnum +from typing import Any, Optional +from uuid import UUID from pydantic import PastDatetime -from sqlalchemy import CheckConstraint -from sqlalchemy.types import DateTime +from sqlalchemy import CheckConstraint, insert, inspect +from sqlalchemy.types import DateTime, JSON from sqlmodel import Field, SQLModel -class BaseTimestampedSQLModel(SQLModel): - """A base class for SQL models with timestamp fields. +class AuditableFieldBlackListEnum(StrEnum): + """Fields black listed for auditability.""" + + PASSWORD = "password" + + +class BaseAuditableSQLModel(SQLModel): + """A base class for auditable SQL models. This class provides two timestamp fields, `created_at` and `updated_at`, which are automatically managed. The `created_at` field is set to the current UTC time when a new record is created, and the `updated_at` field is updated to the current UTC time whenever the record is modified. + + The two `created_by_id`, `updated_by_id` foreign keys points to the auth.User model. + Both keys are optional and need to be explicitly set by your code. + + To fully track changes, you need to connect the `track_model_changes` utility (see + below) to SQLAlchemy events. """ __table_args__: Any = ( @@ -33,3 +47,58 @@ class BaseTimestampedSQLModel(SQLModel): default_factory=lambda: datetime.now(timezone.utc), description="The timestamp indicating when the record was last updated.", ) # type: ignore + created_by_id: Optional[UUID] = Field(default=None, foreign_key="user.id") + updated_by_id: Optional[UUID] = Field(default=None, foreign_key="user.id") + + +class Audit(SQLModel, table=True): + """Model changes record for auditability.""" + + table: str + author_id: UUID = Field(default=None, foreign_key="user.id") + updated_at: PastDatetime = Field(sa_type=DateTime(timezone=True)) + changes: dict = Field(sa_type=JSON) + + +def track_model_changes(mapper, connection, target): + """Track model changes for auditability. + + Models to track are supposed to inherit from the `BaseAuditableSQLModel`. You + should listen to "after_update" events on your model to add full auditability + support, e.g.: + + ```python + from sqlalchemy import event + + from myapp.schemas import MyModel + from ..schemas.audit import track_model_changes + + + event.listen(MyModel, "after_update", track_model_changes) + ``` + + For each changed field, the previous value is stored along with the modification + date and the author. For fields with sensitive information (_e.g._ passwords or + tokens), a null value is stored. + """ + state = inspect(target) + + # Get changes + changes = {} + for attr in state.attrs: + if attr.key in AuditableFieldBlackListEnum: + continue + history = attr.load_history() + if not history.has_changes(): + continue + changes[attr.key] = [history.deleted, history.added] + + # Log changes + connection.execute( + insert(Audit).values( + table="", + author_id=target.updated_by_id, + updated_at=target.updated_at, + changes=changes, + ) + ) diff --git a/src/api/qualicharge/schemas/core.py b/src/api/qualicharge/schemas/core.py index e6272f11..55418f33 100644 --- a/src/api/qualicharge/schemas/core.py +++ b/src/api/qualicharge/schemas/core.py @@ -37,7 +37,7 @@ ImplantationStationEnum, RaccordementEnum, ) -from . import BaseTimestampedSQLModel +from . import BaseAuditableSQLModel if TYPE_CHECKING: from qualicharge.auth.schemas import Group @@ -50,10 +50,10 @@ class OperationalUnitTypeEnum(IntEnum): MOBILITY = 2 -class Amenageur(BaseTimestampedSQLModel, table=True): +class Amenageur(BaseAuditableSQLModel, table=True): """Amenageur table.""" - __table_args__ = BaseTimestampedSQLModel.__table_args__ + ( + __table_args__ = BaseAuditableSQLModel.__table_args__ + ( UniqueConstraint("nom_amenageur", "siren_amenageur", "contact_amenageur"), ) @@ -73,10 +73,10 @@ def __eq__(self, other) -> bool: return all(getattr(self, field) == getattr(other, field) for field in fields) -class Operateur(BaseTimestampedSQLModel, table=True): +class Operateur(BaseAuditableSQLModel, table=True): """Operateur table.""" - __table_args__ = BaseTimestampedSQLModel.__table_args__ + ( + __table_args__ = BaseAuditableSQLModel.__table_args__ + ( UniqueConstraint("nom_operateur", "contact_operateur", "telephone_operateur"), ) @@ -96,7 +96,7 @@ def __eq__(self, other) -> bool: return all(getattr(self, field) == getattr(other, field) for field in fields) -class Enseigne(BaseTimestampedSQLModel, table=True): +class Enseigne(BaseAuditableSQLModel, table=True): """Enseigne table.""" model_config = SQLModelConfig(validate_assignment=True) @@ -113,7 +113,7 @@ def __eq__(self, other) -> bool: return all(getattr(self, field) == getattr(other, field) for field in fields) -class Localisation(BaseTimestampedSQLModel, table=True): +class Localisation(BaseAuditableSQLModel, table=True): """Localisation table.""" model_config = SQLModelConfig( @@ -178,7 +178,7 @@ def serialize_coordonneesXY( return self._wkb_to_coordinates(value) -class OperationalUnit(BaseTimestampedSQLModel, table=True): +class OperationalUnit(BaseAuditableSQLModel, table=True): """OperationalUnit table.""" id: UUID = Field(default_factory=uuid4, primary_key=True) @@ -217,7 +217,7 @@ def create_stations_fk(self, session: SMSession): session.commit() -class Station(BaseTimestampedSQLModel, table=True): +class Station(BaseAuditableSQLModel, table=True): """Station table.""" model_config = SQLModelConfig(validate_assignment=True) @@ -281,7 +281,7 @@ def link_station_to_operational_unit(mapper, connection, target): target.operational_unit_id = operational_unit.id -class PointDeCharge(BaseTimestampedSQLModel, table=True): +class PointDeCharge(BaseAuditableSQLModel, table=True): """PointDeCharge table.""" model_config = SQLModelConfig(validate_assignment=True) @@ -322,10 +322,10 @@ def __eq__(self, other) -> bool: statuses: List["Status"] = Relationship(back_populates="point_de_charge") -class Session(BaseTimestampedSQLModel, SessionBase, table=True): +class Session(BaseAuditableSQLModel, SessionBase, table=True): """IRVE recharge session.""" - __table_args__ = BaseTimestampedSQLModel.__table_args__ + ( + __table_args__ = BaseAuditableSQLModel.__table_args__ + ( {"timescaledb_hypertable": {"time_column_name": "start"}}, ) @@ -346,10 +346,10 @@ class Session(BaseTimestampedSQLModel, SessionBase, table=True): point_de_charge: PointDeCharge = Relationship(back_populates="sessions") -class Status(BaseTimestampedSQLModel, StatusBase, table=True): +class Status(BaseAuditableSQLModel, StatusBase, table=True): """IRVE recharge session.""" - __table_args__ = BaseTimestampedSQLModel.__table_args__ + ( + __table_args__ = BaseAuditableSQLModel.__table_args__ + ( {"timescaledb_hypertable": {"time_column_name": "horodatage"}}, ) diff --git a/src/api/qualicharge/schemas/geo.py b/src/api/qualicharge/schemas/geo.py index 50fcc024..8d46be42 100644 --- a/src/api/qualicharge/schemas/geo.py +++ b/src/api/qualicharge/schemas/geo.py @@ -7,10 +7,10 @@ from sqlmodel import Field, Relationship from sqlmodel.main import SQLModelConfig -from . import BaseTimestampedSQLModel +from . import BaseAuditableSQLModel -class BaseAdministrativeBoundaries(BaseTimestampedSQLModel): +class BaseAdministrativeBoundaries(BaseAuditableSQLModel): """Base administrative boundaries model.""" model_config = SQLModelConfig( diff --git a/src/api/qualicharge/schemas/sql.py b/src/api/qualicharge/schemas/sql.py index 674669de..e7d61c59 100644 --- a/src/api/qualicharge/schemas/sql.py +++ b/src/api/qualicharge/schemas/sql.py @@ -19,7 +19,7 @@ from ..exceptions import ObjectDoesNotExist, ProgrammingError from ..models.static import Statique -from . import BaseTimestampedSQLModel +from . import BaseAuditableSQLModel from .core import ( AccessibilitePMREnum, Amenageur, @@ -45,7 +45,7 @@ def __init__(self, df: pd.DataFrame, connection: Connection): self._statique: pd.DataFrame = self._fix_enums(df) self._statique_with_fk: pd.DataFrame = self._statique.copy() - self._saved_schemas: list[type[BaseTimestampedSQLModel]] = [] + self._saved_schemas: list[type[BaseAuditableSQLModel]] = [] self._amenageur: Optional[pd.DataFrame] = None self._enseigne: Optional[pd.DataFrame] = None @@ -64,7 +64,7 @@ def __len__(self): @staticmethod def _add_timestamped_model_fields(df: pd.DataFrame): - """Add required fields for a BaseTimestampedSQLModel.""" + """Add required fields for a BaseAuditableSQLModel.""" df["id"] = df.apply(lambda x: uuid.uuid4(), axis=1) now = pd.Timestamp.now(tz="utc") df["created_at"] = now @@ -72,12 +72,12 @@ def _add_timestamped_model_fields(df: pd.DataFrame): return df @staticmethod - def _schema_fk(schema: type[BaseTimestampedSQLModel]) -> str: + def _schema_fk(schema: type[BaseAuditableSQLModel]) -> str: """Get expected schema foreign key name.""" return f"{schema.__table__.name}_id" # type: ignore[attr-defined] @staticmethod - def _get_schema_fks(schema: type[BaseTimestampedSQLModel]) -> list[str]: + def _get_schema_fks(schema: type[BaseAuditableSQLModel]) -> list[str]: """Get foreign key field names from a schema.""" return [ fk.parent.name @@ -103,7 +103,7 @@ def _fix_enums(self, df: pd.DataFrame) -> pd.DataFrame: return df.replace(to_replace=src, value=target) def _get_fields_for_schema( - self, schema: type[BaseTimestampedSQLModel], with_fk: bool = False + self, schema: type[BaseAuditableSQLModel], with_fk: bool = False ) -> list[str]: """Get Statique fields from a core schema.""" fields = list( @@ -115,7 +115,7 @@ def _get_fields_for_schema( def _get_dataframe_for_schema( self, - schema: type[BaseTimestampedSQLModel], + schema: type[BaseAuditableSQLModel], subset: Optional[str] = None, with_fk: bool = False, ): @@ -127,7 +127,7 @@ def _get_dataframe_for_schema( return df def _add_fk_from_saved_schema( - self, saved: pd.DataFrame, schema: type[BaseTimestampedSQLModel] + self, saved: pd.DataFrame, schema: type[BaseAuditableSQLModel] ): """Add foreign keys to the statique DataFrame using saved schema.""" fields = self._get_fields_for_schema(schema) @@ -214,7 +214,7 @@ def station(self) -> pd.DataFrame: def _save_schema( self, df: pd.DataFrame, - schema: type[BaseTimestampedSQLModel], + schema: type[BaseAuditableSQLModel], constraint: Optional[str] = None, index_elements: Optional[list[str]] = None, chunksize: int = 1000,