diff --git a/docs/changelog/index.md b/docs/changelog/index.md index ec348389..f49f80c7 100644 --- a/docs/changelog/index.md +++ b/docs/changelog/index.md @@ -9,15 +9,17 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ### Added -* Row actions by [@jowilf](https://github.com/jowilf) in [#348](https://github.com/jowilf/starlette-admin/pull/348) +* Add Before and After Hooks for Create, Edit, and Delete Operations by [@jowilf](https://github.com/jowilf) + in [#327](https://github.com/jowilf/starlette-admin/pull/327) +* Row actions by [@jowilf](https://github.com/jowilf) in [#348](https://github.com/jowilf/starlette-admin/pull/348) * Add Support for Custom Sortable Field Mapping in SQLAlchemy ModelView by [@jowilf](https://github.com/jowilf) in [#328](https://github.com/jowilf/starlette-admin/pull/328) -??? usage +???+ usage - ```python + ```python hl_lines="12" class Post(Base): __tablename__ = "post" @@ -35,7 +37,7 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). * Add support for datatables [state saving](https://datatables.net/examples/basic_init/state_save.html) -??? usage +???+ usage ```python class MyModelView(ModelView): diff --git a/docs/tutorial/configurations/modelview/index.md b/docs/tutorial/configurations/modelview/index.md index 89340cae..1ecbe15a 100644 --- a/docs/tutorial/configurations/modelview/index.md +++ b/docs/tutorial/configurations/modelview/index.md @@ -204,3 +204,32 @@ fields. ``` ![Custom Select2 rendering](../../../images/tutorial/configurations/modelview/select2_customization.png){ width="300" } + +## Hooks + +Hooks are callback functions that give you an easy way to customize and extend the default CRUD functions. You can use +hooks to perform actions before or after specific operations such as item creation, editing, or deletion. + +The following hooks are available: + +- [before_create(request, data, obj)][starlette_admin.views.BaseModelView.before_create]: Called before a new object is + created + +- [after_create(request, obj)][starlette_admin.views.BaseModelView.after_create]: Called after a new object is created + +- [before_edit(request, data, obj)][starlette_admin.views.BaseModelView.before_edit]: Called before an existing object is + updated + +- [after_edit(request, obj)][starlette_admin.views.BaseModelView.after_edit]: Called after an existing object is updated + +- [before_delete(request, obj)][starlette_admin.views.BaseModelView.before_delete]: Called before an object is deleted + +- [after_delete(request, obj)][starlette_admin.views.BaseModelView.after_delete]: Called after an object is deleted + +### Example + +```python +class OrderView(ModelView): + async def after_create(self, request: Request, order: Order): + analytics.track_order_created(order) +``` diff --git a/starlette_admin/contrib/mongoengine/view.py b/starlette_admin/contrib/mongoengine/view.py index a2abe2d4..f7ee5c43 100644 --- a/starlette_admin/contrib/mongoengine/view.py +++ b/starlette_admin/contrib/mongoengine/view.py @@ -96,16 +96,24 @@ async def find_by_pks( ) -> Sequence[me.Document]: return self.document.objects(id__in=pks) - async def create(self, request: Request, data: Dict[str, Any]) -> None: + async def create(self, request: Request, data: Dict[str, Any]) -> Any: try: - return (await self._populate_obj(request, self.document(), data)).save() + obj = await self._populate_obj(request, self.document(), data) + await self.before_create(request, data, obj) + obj.save() + await self.after_create(request, obj) + return obj except Exception as e: self.handle_exception(e) async def edit(self, request: Request, pk: Any, data: Dict[str, Any]) -> Any: try: obj = await self.find_by_pk(request, pk) - return (await self._populate_obj(request, obj, data, True)).save() + obj = await self._populate_obj(request, obj, data, True) + await self.before_edit(request, data, obj) + obj.save() + await self.after_edit(request, obj) + return obj except Exception as e: self.handle_exception(e) @@ -201,7 +209,13 @@ async def _populate_obj( # noqa: C901 return obj async def delete(self, request: Request, pks: List[Any]) -> Optional[int]: - return self.document.objects(id__in=pks).delete() + objs = self.document.objects(id__in=pks) + for obj in objs: + await self.before_delete(request, obj) + deleted_count = objs.delete() + for obj in objs: + await self.after_delete(request, obj) + return deleted_count def handle_exception(self, exc: Exception) -> None: if isinstance(exc, ValidationError): diff --git a/starlette_admin/contrib/odmantic/view.py b/starlette_admin/contrib/odmantic/view.py index e35c231b..21c7eff6 100644 --- a/starlette_admin/contrib/odmantic/view.py +++ b/starlette_admin/contrib/odmantic/view.py @@ -134,9 +134,14 @@ async def create(self, request: Request, data: Dict) -> Any: session: Union[AIOSession, SyncSession] = request.state.session data = await self._arrange_data(request, data) try: + obj = self.model(**data) + await self.before_create(request, data, obj) if isinstance(session, AIOSession): - return await session.save(self.model(**data)) - return await anyio.to_thread.run_sync(session.save, self.model(**data)) + await session.save(obj) + else: + await anyio.to_thread.run_sync(session.save, obj) + await self.after_create(request, obj) + return obj except Exception as e: self.handle_exception(e) @@ -144,20 +149,31 @@ async def edit(self, request: Request, pk: Any, data: Dict[str, Any]) -> Any: session: Union[AIOSession, SyncSession] = request.state.session data = await self._arrange_data(request, data, is_edit=True) try: - instance = await self.find_by_pk(request, pk) - instance.update(data) + obj = await self.find_by_pk(request, pk) + obj.update(data) + await self.before_edit(request, data, obj) if isinstance(session, AIOSession): - return await session.save(instance) - return await anyio.to_thread.run_sync(session.save, instance) + obj = await session.save(obj) + else: + obj = await anyio.to_thread.run_sync(session.save, obj) + await self.after_edit(request, obj) + return obj except Exception as e: self.handle_exception(e) async def delete(self, request: Request, pks: List[Any]) -> Optional[int]: + objs = await self.find_by_pks(request, pks) pks = list(map(ObjectId, pks)) + for obj in objs: + await self.before_delete(request, obj) session: Union[AIOSession, SyncSession] = request.state.session if isinstance(session, AIOSession): - return await session.remove(self.model, self.model.id.in_(pks)) # type: ignore - return await anyio.to_thread.run_sync(session.remove, self.model, self.model.id.in_(pks)) # type: ignore + deleted_count = await session.remove(self.model, self.model.id.in_(pks)) # type: ignore + else: + deleted_count = await anyio.to_thread.run_sync(session.remove, self.model, self.model.id.in_(pks)) # type: ignore + for obj in objs: + await self.after_delete(request, obj) + return deleted_count def handle_exception(self, exc: Exception) -> None: if isinstance(exc, ValidationError): diff --git a/starlette_admin/contrib/sqla/view.py b/starlette_admin/contrib/sqla/view.py index a86bee1a..f1382672 100644 --- a/starlette_admin/contrib/sqla/view.py +++ b/starlette_admin/contrib/sqla/view.py @@ -360,12 +360,14 @@ async def create(self, request: Request, data: Dict[str, Any]) -> Any: session: Union[Session, AsyncSession] = request.state.session obj = await self._populate_obj(request, self.model(), data) session.add(obj) + await self.before_create(request, data, obj) if isinstance(session, AsyncSession): await session.commit() await session.refresh(obj) else: await anyio.to_thread.run_sync(session.commit) await anyio.to_thread.run_sync(session.refresh, obj) + await self.after_create(request, obj) return obj except Exception as e: return self.handle_exception(e) @@ -376,13 +378,16 @@ async def edit(self, request: Request, pk: Any, data: Dict[str, Any]) -> Any: await self.validate(request, data) session: Union[Session, AsyncSession] = request.state.session obj = await self.find_by_pk(request, pk) - session.add(await self._populate_obj(request, obj, data, True)) + await self._populate_obj(request, obj, data, True) + session.add(obj) + await self.before_edit(request, data, obj) if isinstance(session, AsyncSession): await session.commit() await session.refresh(obj) else: await anyio.to_thread.run_sync(session.commit) await anyio.to_thread.run_sync(session.refresh, obj) + await self.after_edit(request, obj) return obj except Exception as e: self.handle_exception(e) @@ -439,12 +444,16 @@ async def delete(self, request: Request, pks: List[Any]) -> Optional[int]: objs = await self.find_by_pks(request, pks) if isinstance(session, AsyncSession): for obj in objs: + await self.before_delete(request, obj) await session.delete(obj) await session.commit() else: for obj in objs: + await self.before_delete(request, obj) await anyio.to_thread.run_sync(session.delete, obj) await anyio.to_thread.run_sync(session.commit) + for obj in objs: + await self.after_delete(request, obj) return len(objs) async def build_full_text_search_query( @@ -468,8 +477,9 @@ def build_order_clauses( def handle_exception(self, exc: Exception) -> None: try: """Automatically handle sqlalchemy_file error""" - sqlalchemy_file = __import__("sqlalchemy_file") - if isinstance(exc, sqlalchemy_file.exceptions.ValidationError): + from sqlalchemy_file.exceptions import ValidationError + + if isinstance(exc, ValidationError): raise FormValidationError({exc.key: exc.msg}) except ImportError: # pragma: no cover pass diff --git a/starlette_admin/views.py b/starlette_admin/views.py index c1b3a36f..0f92e6a9 100644 --- a/starlette_admin/views.py +++ b/starlette_admin/views.py @@ -548,16 +548,6 @@ async def count( """ raise NotImplementedError() - @abstractmethod - async def delete(self, request: Request, pks: List[Any]) -> Optional[int]: - """ - Bulk delete items - Parameters: - request: The request being processed - pks: List of primary keys - """ - raise NotImplementedError() - @abstractmethod async def find_by_pk(self, request: Request, pk: Any) -> Any: """ @@ -578,6 +568,18 @@ async def find_by_pks(self, request: Request, pks: List[Any]) -> Sequence[Any]: """ raise NotImplementedError() + async def before_create( + self, request: Request, data: Dict[str, Any], obj: Any + ) -> None: + """ + This hook is called before a new item is created. + + Args: + request: The request being processed. + data: Dict values contained converted form data. + obj: The object about to be created. + """ + @abstractmethod async def create(self, request: Request, data: Dict) -> Any: """ @@ -590,6 +592,27 @@ async def create(self, request: Request, data: Dict) -> Any: """ raise NotImplementedError() + async def after_create(self, request: Request, obj: Any) -> None: + """ + This hook is called after a new item is successfully created. + + Args: + request: The request being processed. + obj: The newly created object. + """ + + async def before_edit( + self, request: Request, data: Dict[str, Any], obj: Any + ) -> None: + """ + This hook is called before an item is edited. + + Args: + request: The request being processed. + data: Dict values contained converted form data + obj: The object about to be edited. + """ + @abstractmethod async def edit(self, request: Request, pk: Any, data: Dict[str, Any]) -> Any: """ @@ -603,6 +626,43 @@ async def edit(self, request: Request, pk: Any, data: Dict[str, Any]) -> Any: """ raise NotImplementedError() + async def after_edit(self, request: Request, obj: Any) -> None: + """ + This hook is called after an item is successfully edited. + + Args: + request: The request being processed. + obj: The edited object. + """ + + async def before_delete(self, request: Request, obj: Any) -> None: + """ + This hook is called before an item is deleted. + + Args: + request: The request being processed. + obj: The object about to be deleted. + """ + + @abstractmethod + async def delete(self, request: Request, pks: List[Any]) -> Optional[int]: + """ + Bulk delete items + Parameters: + request: The request being processed + pks: List of primary keys + """ + raise NotImplementedError() + + async def after_delete(self, request: Request, obj: Any) -> None: + """ + This hook is called after an item is successfully deleted. + + Args: + request: The request being processed. + obj: The deleted object. + """ + def can_view_details(self, request: Request) -> bool: """Permission for viewing full details of Item. Return True by default""" return True diff --git a/tests/mongoengine/test_view.py b/tests/mongoengine/test_view.py index fda1fb87..d8342fc6 100644 --- a/tests/mongoengine/test_view.py +++ b/tests/mongoengine/test_view.py @@ -3,10 +3,12 @@ import json import tempfile from enum import Enum +from typing import Any, Dict import mongoengine as me import pytest from mongoengine import connect, disconnect +from requests import Request from starlette.applications import Starlette from starlette.testclient import TestClient from starlette_admin.contrib.mongoengine import Admin, ModelView @@ -42,6 +44,36 @@ class User(me.Document): store = me.ReferenceField("Store") +class ProductView(ModelView): + async def before_create( + self, request: Request, data: Dict[str, Any], obj: Any + ) -> None: + assert isinstance(obj, Product) + assert obj.id is None + + async def after_create(self, request: Request, obj: Any) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + async def before_edit( + self, request: Request, data: Dict[str, Any], obj: Any + ) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + async def after_edit(self, request: Request, obj: Any) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + async def before_delete(self, request: Request, obj: Any) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + async def after_delete(self, request: Request, obj: Any) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + class TestMongoBasic: def setup_method(self, method): connect(host=MONGO_URL, uuidRepresentation="standard") @@ -59,7 +91,7 @@ def teardown_method(self, method): def admin(self): admin = Admin() admin.add_view(ModelView(Store)) - admin.add_view(ModelView(Product)) + admin.add_view(ProductView(Product)) admin.add_view(ModelView(User)) return admin diff --git a/tests/odmantic/test_async_engine.py b/tests/odmantic/test_async_engine.py index c150be31..66238cbb 100644 --- a/tests/odmantic/test_async_engine.py +++ b/tests/odmantic/test_async_engine.py @@ -1,12 +1,13 @@ import json import sys from datetime import datetime -from typing import List, Optional +from typing import Any, Dict, List, Optional import pytest import pytest_asyncio from httpx import AsyncClient from odmantic import AIOEngine, EmbeddedModel, Field, Model +from requests import Request from starlette.applications import Starlette from starlette_admin.contrib.odmantic import Admin, ModelView @@ -37,6 +38,36 @@ class User(Model): birthday: Optional[datetime] +class UserView(ModelView): + async def before_create( + self, request: Request, data: Dict[str, Any], obj: Any + ) -> None: + assert isinstance(obj, User) + assert obj.id is not None + + async def after_create(self, request: Request, obj: Any) -> None: + assert isinstance(obj, User) + assert obj.id is not None + + async def before_edit( + self, request: Request, data: Dict[str, Any], obj: Any + ) -> None: + assert isinstance(obj, User) + assert obj.id is not None + + async def after_edit(self, request: Request, obj: Any) -> None: + assert isinstance(obj, User) + assert obj.id is not None + + async def before_delete(self, request: Request, obj: Any) -> None: + assert isinstance(obj, User) + assert obj.id is not None + + async def after_delete(self, request: Request, obj: Any) -> None: + assert isinstance(obj, User) + assert obj.id is not None + + @pytest_asyncio.fixture() async def prepare_database(aio_engine: AIOEngine): await aio_engine.remove(User) @@ -68,7 +99,7 @@ async def prepare_database(aio_engine: AIOEngine): async def client(prepare_database, aio_engine: AIOEngine): admin = Admin(aio_engine) app = Starlette() - admin.add_view(ModelView(User)) + admin.add_view(UserView(User)) admin.mount_to(app) async with AsyncClient(app=app, base_url="http://testserver") as c: yield c diff --git a/tests/sqla/test_async_engine.py b/tests/sqla/test_async_engine.py index 558c5714..1cc66d12 100644 --- a/tests/sqla/test_async_engine.py +++ b/tests/sqla/test_async_engine.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + import pytest import pytest_asyncio from httpx import AsyncClient @@ -5,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession from sqlalchemy.orm import declarative_base from starlette.applications import Starlette +from starlette.requests import Request from starlette_admin.contrib.sqla import Admin, ModelView from tests.sqla.utils import get_async_test_engine @@ -21,6 +24,37 @@ class Product(Base): title = Column(String(100)) +class ProductView(ModelView): + async def before_create( + self, request: Request, data: Dict[str, Any], obj: Any + ) -> None: + assert isinstance(obj, Product) + assert obj.id is None + assert obj.title == "Infinix INBOOK" + + async def after_create(self, request: Request, obj: Any) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + async def before_edit( + self, request: Request, data: Dict[str, Any], obj: Any + ) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + async def after_edit(self, request: Request, obj: Any) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + async def before_delete(self, request: Request, obj: Any) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + async def after_delete(self, request: Request, obj: Any) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + @pytest_asyncio.fixture() async def engine(): _engine = get_async_test_engine() @@ -41,7 +75,7 @@ async def session(engine: AsyncEngine) -> AsyncSession: @pytest_asyncio.fixture async def client(engine: AsyncEngine): admin = Admin(engine) - admin.add_view(ModelView(Product)) + admin.add_view(ProductView(Product)) app = Starlette() admin.mount_to(app) async with AsyncClient(app=app, base_url="http://testserver") as c: diff --git a/tests/sqla/test_sync_engine.py b/tests/sqla/test_sync_engine.py index c0ec77aa..f83617ca 100644 --- a/tests/sqla/test_sync_engine.py +++ b/tests/sqla/test_sync_engine.py @@ -1,5 +1,6 @@ import enum import json +from typing import Any, Dict import pytest import pytest_asyncio @@ -21,6 +22,7 @@ from sqlalchemy.orm import Session, declarative_base, relationship from sqlalchemy_file.storage import StorageManager from starlette.applications import Starlette +from starlette.requests import Request from starlette_admin.contrib.sqla import Admin from starlette_admin.contrib.sqla.view import ModelView @@ -59,6 +61,36 @@ class User(Base): products = relationship("Product", back_populates="user") +class ProductView(ModelView): + async def before_create( + self, request: Request, data: Dict[str, Any], obj: Any + ) -> None: + assert isinstance(obj, Product) + assert obj.id is None + + async def after_create(self, request: Request, obj: Any) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + async def before_edit( + self, request: Request, data: Dict[str, Any], obj: Any + ) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + async def after_edit(self, request: Request, obj: Any) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + async def before_delete(self, request: Request, obj: Any) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + async def after_delete(self, request: Request, obj: Any) -> None: + assert isinstance(obj, Product) + assert obj.id is not None + + class UserView(ModelView): form_include_pk = True @@ -98,7 +130,7 @@ def session(engine: Engine) -> Session: def admin(engine: Engine): admin = Admin(engine) admin.add_view(UserView(User)) - admin.add_view(ModelView(Product)) + admin.add_view(ProductView(Product)) return admin