Skip to content

Commit

Permalink
fix(rename decorator): properly map relation when renaming pk field (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jbarreau authored Nov 28, 2024
1 parent c73544b commit ac9f63e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def computed_fields(tree: ConditionTree) -> ConditionTree:

return _filter

async def list(self, caller: User, _filter: PaginatedFilter, projection: Projection) -> List[RecordsDataAlias]:
async def list(self, caller: User, filter_: PaginatedFilter, projection: Projection) -> List[RecordsDataAlias]:
child_projection = projection.replace(lambda field_name: self._path_to_child_collection(field_name))
records: List[RecordsDataAlias] = await super().list(caller, _filter, child_projection) # type: ignore
records: List[RecordsDataAlias] = await super().list(caller, filter_, child_projection) # type: ignore
if child_projection == projection:
return records

Expand All @@ -108,15 +108,15 @@ async def create(self, caller: User, data: List[RecordsDataAlias]) -> List[Recor
)
return [self._record_from_child_collection(record) for record in records]

async def update(self, caller: User, _filter: Optional[Filter], patch: RecordsDataAlias) -> None:
async def update(self, caller: User, filter_: Optional[Filter], patch: RecordsDataAlias) -> None:
refined_patch = self._record_to_child_collection(patch)
return await super().update(caller, _filter, refined_patch) # type: ignore
return await super().update(caller, filter_, refined_patch) # type: ignore

async def aggregate(
self, caller: User, _filter: Optional[Filter], aggregation: Aggregation, limit: Optional[int] = None
self, caller: User, filter_: Optional[Filter], aggregation: Aggregation, limit: Optional[int] = None
) -> List[AggregateResult]:
rows: List[AggregateResult] = await super().aggregate( # type: ignore
caller, _filter, aggregation.replace_fields(lambda name: self._path_to_child_collection(name)), limit
caller, filter_, aggregation.replace_fields(lambda name: self._path_to_child_collection(name)), limit
)
return [AggregateResult(value=row["value"], group=self._build_group_aggregate(row)) for row in rows]

Expand All @@ -130,24 +130,50 @@ def _refine_schema(self, sub_schema: CollectionSchema) -> CollectionSchema:
new_fields_schema = {}
datasource = cast(Datasource[BoundCollection], self.datasource)

for old_field_name, field_schema in sub_schema["fields"].items():
for old_field_name, old_schema in sub_schema["fields"].items():
field_schema = {**old_schema}
if is_many_to_one(field_schema):
relation = cast(
RenameFieldCollectionDecorator,
self.datasource.get_collection(field_schema["foreign_collection"]),
)
field_schema["foreign_key"] = self._from_child_collection.get(
field_schema["foreign_key"], field_schema["foreign_key"]
)
field_schema["foreign_key_target"] = relation._from_child_collection.get(
field_schema["foreign_key_target"], field_schema["foreign_key_target"]
)
elif is_one_to_many(field_schema) or is_one_to_one(field_schema):
relation = datasource.get_collection(field_schema["foreign_collection"])
relation = cast(
RenameFieldCollectionDecorator,
self.datasource.get_collection(field_schema["foreign_collection"]),
)
field_schema["origin_key"] = relation._from_child_collection.get(
field_schema["origin_key"], field_schema["origin_key"]
)
field_schema["origin_key_target"] = self._from_child_collection.get(
field_schema["origin_key_target"], field_schema["origin_key_target"]
)
elif is_many_to_many(field_schema):
through = datasource.get_collection(field_schema["through_collection"])
relation = cast(
RenameFieldCollectionDecorator,
self.datasource.get_collection(field_schema["foreign_collection"]),
)
through = cast(
RenameFieldCollectionDecorator, datasource.get_collection(field_schema["through_collection"])
)
field_schema["foreign_key"] = through._from_child_collection.get(
field_schema["foreign_key"], field_schema["foreign_key"]
)
field_schema["origin_key"] = through._from_child_collection.get(
field_schema["origin_key"], field_schema["origin_key"]
)
field_schema["origin_key_target"] = self._from_child_collection.get(
field_schema["origin_key_target"], field_schema["origin_key_target"]
)
field_schema["foreign_key_target"] = relation._from_child_collection.get(
field_schema["foreign_key_target"], field_schema["foreign_key_target"]
)
# we don't handle schema modification for polymorphic many to one and reverse relations because
# we forbid to rename foreign key and type fields on polymorphic many to one

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,20 @@ def test_relation_should_be_updated_in_all_collection_when_renaming_fk(self):
assert person_book_fields["book"]["foreign_key"] == "novel_id"
assert person_book_fields["person"]["foreign_key"] == "author_id"
assert person_fields["book"]["origin_key"] == "author_id"

def test_when_renaming_pk_should_update_all_collections(self):
self.decorated_collection_book.rename_field("id", "new_book_id")
self.decorated_collection_person.rename_field("id", "new_person_id")

book_fields = self.decorated_collection_book.schema["fields"]
person_fields = self.decorated_collection_person.schema["fields"]
person_book_fields = self.decorated_collection_person_book.schema["fields"]

self.assertEqual(book_fields["persons"]["foreign_key_target"], "new_person_id")
self.assertEqual(book_fields["persons"]["origin_key_target"], "new_book_id")

self.assertEqual(book_fields["author"]["foreign_key_target"], "new_person_id")
self.assertEqual(person_fields["book"]["origin_key_target"], "new_person_id")

self.assertEqual(person_book_fields["book"]["foreign_key_target"], "new_book_id")
self.assertEqual(person_book_fields["person"]["foreign_key_target"], "new_person_id")

0 comments on commit ac9f63e

Please sign in to comment.