Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Before and After Hooks for Create, Edit, and Delete Operations #327

Merged
merged 6 commits into from
Oct 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions docs/changelog/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

### Added

* Row actions by [@jowilf](https://github.com/jowilf) in [#348](https://github.com/jowilf/starlette-admin/pull/348)
* Add Before and After Hooks for Create, Edit, and Delete Operations by [@jowilf](https://github.com/jowilf)
in [#327](https://github.com/jowilf/starlette-admin/pull/327)

* Row actions by [@jowilf](https://github.com/jowilf) in [#348](https://github.com/jowilf/starlette-admin/pull/348)

* Add Support for Custom Sortable Field Mapping in SQLAlchemy ModelView by [@jowilf](https://github.com/jowilf)
in [#328](https://github.com/jowilf/starlette-admin/pull/328)

??? usage
???+ usage

```python
```python hl_lines="12"
class Post(Base):
__tablename__ = "post"

Expand All @@ -35,7 +37,7 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

* Add support for datatables [state saving](https://datatables.net/examples/basic_init/state_save.html)

??? usage
???+ usage

```python
class MyModelView(ModelView):
Expand Down
29 changes: 29 additions & 0 deletions docs/tutorial/configurations/modelview/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,32 @@ fields.
```

![Custom Select2 rendering](../../../images/tutorial/configurations/modelview/select2_customization.png){ width="300" }

## Hooks

Hooks are callback functions that give you an easy way to customize and extend the default CRUD functions. You can use
hooks to perform actions before or after specific operations such as item creation, editing, or deletion.

The following hooks are available:

- [before_create(request, data, obj)][starlette_admin.views.BaseModelView.before_create]: Called before a new object is
created

- [after_create(request, obj)][starlette_admin.views.BaseModelView.after_create]: Called after a new object is created

- [before_edit(request, data, obj)][starlette_admin.views.BaseModelView.before_edit]: Called before an existing object is
updated

- [after_edit(request, obj)][starlette_admin.views.BaseModelView.after_edit]: Called after an existing object is updated

- [before_delete(request, obj)][starlette_admin.views.BaseModelView.before_delete]: Called before an object is deleted

- [after_delete(request, obj)][starlette_admin.views.BaseModelView.after_delete]: Called after an object is deleted

### Example

```python
class OrderView(ModelView):
async def after_create(self, request: Request, order: Order):
analytics.track_order_created(order)
```
22 changes: 18 additions & 4 deletions starlette_admin/contrib/mongoengine/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,24 @@ async def find_by_pks(
) -> Sequence[me.Document]:
return self.document.objects(id__in=pks)

async def create(self, request: Request, data: Dict[str, Any]) -> None:
async def create(self, request: Request, data: Dict[str, Any]) -> Any:
try:
return (await self._populate_obj(request, self.document(), data)).save()
obj = await self._populate_obj(request, self.document(), data)
await self.before_create(request, data, obj)
obj.save()
await self.after_create(request, obj)
return obj
except Exception as e:
self.handle_exception(e)

async def edit(self, request: Request, pk: Any, data: Dict[str, Any]) -> Any:
try:
obj = await self.find_by_pk(request, pk)
return (await self._populate_obj(request, obj, data, True)).save()
obj = await self._populate_obj(request, obj, data, True)
await self.before_edit(request, data, obj)
obj.save()
await self.after_edit(request, obj)
return obj
except Exception as e:
self.handle_exception(e)

Expand Down Expand Up @@ -201,7 +209,13 @@ async def _populate_obj( # noqa: C901
return obj

async def delete(self, request: Request, pks: List[Any]) -> Optional[int]:
return self.document.objects(id__in=pks).delete()
objs = self.document.objects(id__in=pks)
for obj in objs:
await self.before_delete(request, obj)
deleted_count = objs.delete()
for obj in objs:
await self.after_delete(request, obj)
return deleted_count

def handle_exception(self, exc: Exception) -> None:
if isinstance(exc, ValidationError):
Expand Down
32 changes: 24 additions & 8 deletions starlette_admin/contrib/odmantic/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,30 +134,46 @@ async def create(self, request: Request, data: Dict) -> Any:
session: Union[AIOSession, SyncSession] = request.state.session
data = await self._arrange_data(request, data)
try:
obj = self.model(**data)
await self.before_create(request, data, obj)
if isinstance(session, AIOSession):
return await session.save(self.model(**data))
return await anyio.to_thread.run_sync(session.save, self.model(**data))
await session.save(obj)
else:
await anyio.to_thread.run_sync(session.save, obj)
await self.after_create(request, obj)
return obj
except Exception as e:
self.handle_exception(e)

async def edit(self, request: Request, pk: Any, data: Dict[str, Any]) -> Any:
session: Union[AIOSession, SyncSession] = request.state.session
data = await self._arrange_data(request, data, is_edit=True)
try:
instance = await self.find_by_pk(request, pk)
instance.update(data)
obj = await self.find_by_pk(request, pk)
obj.update(data)
await self.before_edit(request, data, obj)
if isinstance(session, AIOSession):
return await session.save(instance)
return await anyio.to_thread.run_sync(session.save, instance)
obj = await session.save(obj)
else:
obj = await anyio.to_thread.run_sync(session.save, obj)
await self.after_edit(request, obj)
return obj
except Exception as e:
self.handle_exception(e)

async def delete(self, request: Request, pks: List[Any]) -> Optional[int]:
objs = await self.find_by_pks(request, pks)
pks = list(map(ObjectId, pks))
for obj in objs:
await self.before_delete(request, obj)
session: Union[AIOSession, SyncSession] = request.state.session
if isinstance(session, AIOSession):
return await session.remove(self.model, self.model.id.in_(pks)) # type: ignore
return await anyio.to_thread.run_sync(session.remove, self.model, self.model.id.in_(pks)) # type: ignore
deleted_count = await session.remove(self.model, self.model.id.in_(pks)) # type: ignore
else:
deleted_count = await anyio.to_thread.run_sync(session.remove, self.model, self.model.id.in_(pks)) # type: ignore
for obj in objs:
await self.after_delete(request, obj)
return deleted_count

def handle_exception(self, exc: Exception) -> None:
if isinstance(exc, ValidationError):
Expand Down
16 changes: 13 additions & 3 deletions starlette_admin/contrib/sqla/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,14 @@ async def create(self, request: Request, data: Dict[str, Any]) -> Any:
session: Union[Session, AsyncSession] = request.state.session
obj = await self._populate_obj(request, self.model(), data)
session.add(obj)
await self.before_create(request, data, obj)
if isinstance(session, AsyncSession):
await session.commit()
await session.refresh(obj)
else:
await anyio.to_thread.run_sync(session.commit)
await anyio.to_thread.run_sync(session.refresh, obj)
await self.after_create(request, obj)
return obj
except Exception as e:
return self.handle_exception(e)
Expand All @@ -376,13 +378,16 @@ async def edit(self, request: Request, pk: Any, data: Dict[str, Any]) -> Any:
await self.validate(request, data)
session: Union[Session, AsyncSession] = request.state.session
obj = await self.find_by_pk(request, pk)
session.add(await self._populate_obj(request, obj, data, True))
await self._populate_obj(request, obj, data, True)
session.add(obj)
await self.before_edit(request, data, obj)
if isinstance(session, AsyncSession):
await session.commit()
await session.refresh(obj)
else:
await anyio.to_thread.run_sync(session.commit)
await anyio.to_thread.run_sync(session.refresh, obj)
await self.after_edit(request, obj)
return obj
except Exception as e:
self.handle_exception(e)
Expand Down Expand Up @@ -439,12 +444,16 @@ async def delete(self, request: Request, pks: List[Any]) -> Optional[int]:
objs = await self.find_by_pks(request, pks)
if isinstance(session, AsyncSession):
for obj in objs:
await self.before_delete(request, obj)
await session.delete(obj)
await session.commit()
else:
for obj in objs:
await self.before_delete(request, obj)
await anyio.to_thread.run_sync(session.delete, obj)
await anyio.to_thread.run_sync(session.commit)
for obj in objs:
await self.after_delete(request, obj)
return len(objs)

async def build_full_text_search_query(
Expand All @@ -468,8 +477,9 @@ def build_order_clauses(
def handle_exception(self, exc: Exception) -> None:
try:
"""Automatically handle sqlalchemy_file error"""
sqlalchemy_file = __import__("sqlalchemy_file")
if isinstance(exc, sqlalchemy_file.exceptions.ValidationError):
from sqlalchemy_file.exceptions import ValidationError

if isinstance(exc, ValidationError):
raise FormValidationError({exc.key: exc.msg})
except ImportError: # pragma: no cover
pass
Expand Down
80 changes: 70 additions & 10 deletions starlette_admin/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,16 +548,6 @@ async def count(
"""
raise NotImplementedError()

@abstractmethod
async def delete(self, request: Request, pks: List[Any]) -> Optional[int]:
"""
Bulk delete items
Parameters:
request: The request being processed
pks: List of primary keys
"""
raise NotImplementedError()

@abstractmethod
async def find_by_pk(self, request: Request, pk: Any) -> Any:
"""
Expand All @@ -578,6 +568,18 @@ async def find_by_pks(self, request: Request, pks: List[Any]) -> Sequence[Any]:
"""
raise NotImplementedError()

async def before_create(
self, request: Request, data: Dict[str, Any], obj: Any
) -> None:
"""
This hook is called before a new item is created.

Args:
request: The request being processed.
data: Dict values contained converted form data.
obj: The object about to be created.
"""

@abstractmethod
async def create(self, request: Request, data: Dict) -> Any:
"""
Expand All @@ -590,6 +592,27 @@ async def create(self, request: Request, data: Dict) -> Any:
"""
raise NotImplementedError()

async def after_create(self, request: Request, obj: Any) -> None:
"""
This hook is called after a new item is successfully created.

Args:
request: The request being processed.
obj: The newly created object.
"""

async def before_edit(
self, request: Request, data: Dict[str, Any], obj: Any
) -> None:
"""
This hook is called before an item is edited.

Args:
request: The request being processed.
data: Dict values contained converted form data
obj: The object about to be edited.
"""

@abstractmethod
async def edit(self, request: Request, pk: Any, data: Dict[str, Any]) -> Any:
"""
Expand All @@ -603,6 +626,43 @@ async def edit(self, request: Request, pk: Any, data: Dict[str, Any]) -> Any:
"""
raise NotImplementedError()

async def after_edit(self, request: Request, obj: Any) -> None:
"""
This hook is called after an item is successfully edited.

Args:
request: The request being processed.
obj: The edited object.
"""

async def before_delete(self, request: Request, obj: Any) -> None:
"""
This hook is called before an item is deleted.

Args:
request: The request being processed.
obj: The object about to be deleted.
"""

@abstractmethod
async def delete(self, request: Request, pks: List[Any]) -> Optional[int]:
"""
Bulk delete items
Parameters:
request: The request being processed
pks: List of primary keys
"""
raise NotImplementedError()

async def after_delete(self, request: Request, obj: Any) -> None:
"""
This hook is called after an item is successfully deleted.

Args:
request: The request being processed.
obj: The deleted object.
"""

def can_view_details(self, request: Request) -> bool:
"""Permission for viewing full details of Item. Return True by default"""
return True
Expand Down
34 changes: 33 additions & 1 deletion tests/mongoengine/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import json
import tempfile
from enum import Enum
from typing import Any, Dict

import mongoengine as me
import pytest
from mongoengine import connect, disconnect
from requests import Request
from starlette.applications import Starlette
from starlette.testclient import TestClient
from starlette_admin.contrib.mongoengine import Admin, ModelView
Expand Down Expand Up @@ -42,6 +44,36 @@ class User(me.Document):
store = me.ReferenceField("Store")


class ProductView(ModelView):
async def before_create(
self, request: Request, data: Dict[str, Any], obj: Any
) -> None:
assert isinstance(obj, Product)
assert obj.id is None

async def after_create(self, request: Request, obj: Any) -> None:
assert isinstance(obj, Product)
assert obj.id is not None

async def before_edit(
self, request: Request, data: Dict[str, Any], obj: Any
) -> None:
assert isinstance(obj, Product)
assert obj.id is not None

async def after_edit(self, request: Request, obj: Any) -> None:
assert isinstance(obj, Product)
assert obj.id is not None

async def before_delete(self, request: Request, obj: Any) -> None:
assert isinstance(obj, Product)
assert obj.id is not None

async def after_delete(self, request: Request, obj: Any) -> None:
assert isinstance(obj, Product)
assert obj.id is not None


class TestMongoBasic:
def setup_method(self, method):
connect(host=MONGO_URL, uuidRepresentation="standard")
Expand All @@ -59,7 +91,7 @@ def teardown_method(self, method):
def admin(self):
admin = Admin()
admin.add_view(ModelView(Store))
admin.add_view(ModelView(Product))
admin.add_view(ProductView(Product))
admin.add_view(ModelView(User))
return admin

Expand Down
Loading