From ac9f63efe9a7728c1da012059fcab0d436143940 Mon Sep 17 00:00:00 2001 From: Julien Barreau Date: Thu, 28 Nov 2024 15:25:38 +0100 Subject: [PATCH] fix(rename decorator): properly map relation when renaming pk field (#292) --- .../decorators/rename_field/collections.py | 44 +++++++++++++++---- .../test_rename_field_collection_decorator.py | 17 +++++++ 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/rename_field/collections.py b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/rename_field/collections.py index 71b4ca339..bbf9d8114 100644 --- a/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/rename_field/collections.py +++ b/src/datasource_toolkit/forestadmin/datasource_toolkit/decorators/rename_field/collections.py @@ -94,9 +94,9 @@ def computed_fields(tree: ConditionTree) -> ConditionTree: return _filter - async def list(self, caller: User, _filter: PaginatedFilter, projection: Projection) -> List[RecordsDataAlias]: + async def list(self, caller: User, filter_: PaginatedFilter, projection: Projection) -> List[RecordsDataAlias]: child_projection = projection.replace(lambda field_name: self._path_to_child_collection(field_name)) - records: List[RecordsDataAlias] = await super().list(caller, _filter, child_projection) # type: ignore + records: List[RecordsDataAlias] = await super().list(caller, filter_, child_projection) # type: ignore if child_projection == projection: return records @@ -108,15 +108,15 @@ async def create(self, caller: User, data: List[RecordsDataAlias]) -> List[Recor ) return [self._record_from_child_collection(record) for record in records] - async def update(self, caller: User, _filter: Optional[Filter], patch: RecordsDataAlias) -> None: + async def update(self, caller: User, filter_: Optional[Filter], patch: RecordsDataAlias) -> None: refined_patch = self._record_to_child_collection(patch) - return await super().update(caller, _filter, refined_patch) # type: ignore + return await super().update(caller, filter_, refined_patch) # type: ignore async def aggregate( - self, caller: User, _filter: Optional[Filter], aggregation: Aggregation, limit: Optional[int] = None + self, caller: User, filter_: Optional[Filter], aggregation: Aggregation, limit: Optional[int] = None ) -> List[AggregateResult]: rows: List[AggregateResult] = await super().aggregate( # type: ignore - caller, _filter, aggregation.replace_fields(lambda name: self._path_to_child_collection(name)), limit + caller, filter_, aggregation.replace_fields(lambda name: self._path_to_child_collection(name)), limit ) return [AggregateResult(value=row["value"], group=self._build_group_aggregate(row)) for row in rows] @@ -130,24 +130,50 @@ def _refine_schema(self, sub_schema: CollectionSchema) -> CollectionSchema: new_fields_schema = {} datasource = cast(Datasource[BoundCollection], self.datasource) - for old_field_name, field_schema in sub_schema["fields"].items(): + for old_field_name, old_schema in sub_schema["fields"].items(): + field_schema = {**old_schema} if is_many_to_one(field_schema): + relation = cast( + RenameFieldCollectionDecorator, + self.datasource.get_collection(field_schema["foreign_collection"]), + ) field_schema["foreign_key"] = self._from_child_collection.get( field_schema["foreign_key"], field_schema["foreign_key"] ) + field_schema["foreign_key_target"] = relation._from_child_collection.get( + field_schema["foreign_key_target"], field_schema["foreign_key_target"] + ) elif is_one_to_many(field_schema) or is_one_to_one(field_schema): - relation = datasource.get_collection(field_schema["foreign_collection"]) + relation = cast( + RenameFieldCollectionDecorator, + self.datasource.get_collection(field_schema["foreign_collection"]), + ) field_schema["origin_key"] = relation._from_child_collection.get( field_schema["origin_key"], field_schema["origin_key"] ) + field_schema["origin_key_target"] = self._from_child_collection.get( + field_schema["origin_key_target"], field_schema["origin_key_target"] + ) elif is_many_to_many(field_schema): - through = datasource.get_collection(field_schema["through_collection"]) + relation = cast( + RenameFieldCollectionDecorator, + self.datasource.get_collection(field_schema["foreign_collection"]), + ) + through = cast( + RenameFieldCollectionDecorator, datasource.get_collection(field_schema["through_collection"]) + ) field_schema["foreign_key"] = through._from_child_collection.get( field_schema["foreign_key"], field_schema["foreign_key"] ) field_schema["origin_key"] = through._from_child_collection.get( field_schema["origin_key"], field_schema["origin_key"] ) + field_schema["origin_key_target"] = self._from_child_collection.get( + field_schema["origin_key_target"], field_schema["origin_key_target"] + ) + field_schema["foreign_key_target"] = relation._from_child_collection.get( + field_schema["foreign_key_target"], field_schema["foreign_key_target"] + ) # we don't handle schema modification for polymorphic many to one and reverse relations because # we forbid to rename foreign key and type fields on polymorphic many to one diff --git a/src/datasource_toolkit/tests/decorators/rename_field/test_rename_field_collection_decorator.py b/src/datasource_toolkit/tests/decorators/rename_field/test_rename_field_collection_decorator.py index 74946ebff..f254fe8e7 100644 --- a/src/datasource_toolkit/tests/decorators/rename_field/test_rename_field_collection_decorator.py +++ b/src/datasource_toolkit/tests/decorators/rename_field/test_rename_field_collection_decorator.py @@ -411,3 +411,20 @@ def test_relation_should_be_updated_in_all_collection_when_renaming_fk(self): assert person_book_fields["book"]["foreign_key"] == "novel_id" assert person_book_fields["person"]["foreign_key"] == "author_id" assert person_fields["book"]["origin_key"] == "author_id" + + def test_when_renaming_pk_should_update_all_collections(self): + self.decorated_collection_book.rename_field("id", "new_book_id") + self.decorated_collection_person.rename_field("id", "new_person_id") + + book_fields = self.decorated_collection_book.schema["fields"] + person_fields = self.decorated_collection_person.schema["fields"] + person_book_fields = self.decorated_collection_person_book.schema["fields"] + + self.assertEqual(book_fields["persons"]["foreign_key_target"], "new_person_id") + self.assertEqual(book_fields["persons"]["origin_key_target"], "new_book_id") + + self.assertEqual(book_fields["author"]["foreign_key_target"], "new_person_id") + self.assertEqual(person_fields["book"]["origin_key_target"], "new_person_id") + + self.assertEqual(person_book_fields["book"]["foreign_key_target"], "new_book_id") + self.assertEqual(person_book_fields["person"]["foreign_key_target"], "new_person_id")