Skip to content

Commit

Permalink
Reformat files
Browse files Browse the repository at this point in the history
  • Loading branch information
Artucuno committed Jan 28, 2024
1 parent 2b9c745 commit bb1708c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 25 deletions.
3 changes: 2 additions & 1 deletion integration_test/test_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

# This test isn't really nessessary


def extract_python_snippets(content):
# Regular expression pattern for finding Python code blocks
pattern = r'```python(.*?)```'
pattern = r"```python(.*?)```"
snippets = re.findall(pattern, content, re.DOTALL)
return snippets

Expand Down
53 changes: 35 additions & 18 deletions pydantic_mongo/abstract_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def to_document(model: T) -> dict:
:param model: Model to convert
:return: dict
"""
data = model.model_dump(mode='json')
data = model.model_dump(mode="json")
data.pop("id")
if model.id:
data["_id"] = model.id
Expand Down Expand Up @@ -121,7 +121,9 @@ def __map_sort(self, sort: Sort) -> str | list[tuple] | list[tuple[str | Any, An
raise Exception("Sort should be str, tuple or list of tuples")
return result

def to_model_custom(self, output_type: Type[OutputT], data: Union[dict, Mapping[str, Any]]) -> OutputT:
def to_model_custom(
self, output_type: Type[OutputT], data: Union[dict, Mapping[str, Any]]
) -> OutputT:
"""
Convert document to model with custom output type
"""
Expand Down Expand Up @@ -158,8 +160,9 @@ def save(self, model: T, **kwargs) -> Union[InsertOneResult, UpdateResult]:
model.id = result.inserted_id
return result

def save_many(self, models: Iterable[T], **kwargs) -> \
Union[Tuple[BulkWriteResult, None], Tuple[None, InsertManyResult]]:
def save_many(
self, models: Iterable[T], **kwargs
) -> Union[Tuple[BulkWriteResult, None], Tuple[None, InsertManyResult]]:
"""
Save multiple entities to database
:param models: Iterable of models to save
Expand All @@ -177,8 +180,7 @@ def save_many(self, models: Iterable[T], **kwargs) -> \
result = None
if len(models_to_insert) > 0:
result = self.get_collection().insert_many(
(self.to_document(model) for model in models_to_insert),
**kwargs
(self.to_document(model) for model in models_to_insert), **kwargs
)

for idx, inserted_id in enumerate(result.inserted_ids):
Expand Down Expand Up @@ -421,7 +423,7 @@ def to_document(model: T) -> dict:
:param model: Model to convert
:return: dict
"""
data = model.model_dump(mode='json')
data = model.model_dump(mode="json")
data.pop("id")
if model.id:
data["_id"] = model.id
Expand Down Expand Up @@ -454,7 +456,9 @@ def __map_sort(self, sort: Sort) -> str | list[tuple] | list[tuple[str | Any, An
raise Exception("Sort should be str, tuple or list of tuples")
return result

def to_model_custom(self, output_type: Type[OutputT], data: Union[dict, Mapping[str, Any]]) -> OutputT:
def to_model_custom(
self, output_type: Type[OutputT], data: Union[dict, Mapping[str, Any]]
) -> OutputT:
"""
Convert document to model with custom output type
"""
Expand All @@ -469,7 +473,9 @@ def get_collection(self) -> motor_asyncio.AsyncIOMotorCollection:
"""
return self.__database[self.__collection_name]

async def find_one_by_id(self, _id: Union[ObjectId, str], *args, **kwargs) -> Optional[T]:
async def find_one_by_id(
self, _id: Union[ObjectId, str], *args, **kwargs
) -> Optional[T]:
"""
Find entity by id
Expand All @@ -481,7 +487,9 @@ async def find_one_by(self, query: dict, *args, **kwargs) -> Optional[T]:
"""
Find entity by mongo query
"""
result = await self.get_collection().find_one(self.__map_id(query), *args, **kwargs)
result = await self.get_collection().find_one(
self.__map_id(query), *args, **kwargs
)
return self.to_model(result) if result else None

async def find_by_with_output_type(
Expand Down Expand Up @@ -512,7 +520,10 @@ async def find_by_with_output_type(
cursor.skip(skip)
if sort:
cursor.sort(mapped_sort)
return map(lambda doc: self.to_model_custom(output_type, doc), await cursor.to_list(limit))
return map(
lambda doc: self.to_model_custom(output_type, doc),
await cursor.to_list(limit),
)

async def find_by(
self,
Expand All @@ -534,7 +545,9 @@ async def find_by(
projection=projection,
)

async def paginate_simple(self, query: dict, limit: int = 25, page: int = 1, **kwargs):
async def paginate_simple(
self, query: dict, limit: int = 25, page: int = 1, **kwargs
):
"""
Paginate entities by mongo query using simple pagination
:param query:
Expand Down Expand Up @@ -570,8 +583,9 @@ async def save(self, model: T, **kwargs) -> Union[InsertOneResult, UpdateResult]
model.id = result.inserted_id
return result

async def save_many(self, models: Iterable[T], **kwargs) -> \
Union[Tuple[BulkWriteResult, None], Tuple[None, InsertManyResult]]:
async def save_many(
self, models: Iterable[T], **kwargs
) -> Union[Tuple[BulkWriteResult, None], Tuple[None, InsertManyResult]]:
"""
Save multiple entities to database
:param models: Iterable of models to save
Expand All @@ -589,8 +603,7 @@ async def save_many(self, models: Iterable[T], **kwargs) -> \
result = None
if len(models_to_insert) > 0:
result = await self.get_collection().insert_many(
(self.to_document(model) for model in models_to_insert),
**kwargs
(self.to_document(model) for model in models_to_insert), **kwargs
)

for idx, inserted_id in enumerate(result.inserted_ids):
Expand All @@ -611,14 +624,18 @@ async def save_many(self, models: Iterable[T], **kwargs) -> \
models_to_update[idx].id = inserted_id
return bw, result

async def delete_many(self, models: Iterable[T], **kwargs) -> Union[DeleteResult, None]:
async def delete_many(
self, models: Iterable[T], **kwargs
) -> Union[DeleteResult, None]:
"""
Delete multiple entities from database
"""
mongo_ids = [model.id for model in models]
if len(mongo_ids) == 0:
return None
return await self.get_collection().delete_many({"_id": {"$in": mongo_ids}}, **kwargs)
return await self.get_collection().delete_many(
{"_id": {"$in": mongo_ids}}, **kwargs
)

async def delete(self, model: T, **kwargs) -> DeleteResult:
"""
Expand Down
9 changes: 6 additions & 3 deletions pydantic_mongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __get_validators__(cls):

@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return handler(core_schema.str_schema())

Expand All @@ -31,8 +31,11 @@ def __get_pydantic_core_schema__(
return core_schema.json_or_python_schema(
json_schema=object_id_schema,
python_schema=core_schema.union_schema(
[core_schema.no_info_plain_validator_function(cls.validate), core_schema.is_instance_schema(ObjectId),
object_id_schema]
[
core_schema.no_info_plain_validator_function(cls.validate),
core_schema.is_instance_schema(ObjectId),
object_id_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda x: str(x)
Expand Down
4 changes: 1 addition & 3 deletions test/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ def test_save(self, database):
def test_save_upsert(self, database):
spam_repository = SpamRepository(database=database)
spam = Spam(
id=ObjectId("65012da68ea5a4798502f710"),
foo=Foo(count=1, size=1.0),
bars=[]
id=ObjectId("65012da68ea5a4798502f710"), foo=Foo(count=1, size=1.0), bars=[]
)
spam_repository.save(spam)

Expand Down

0 comments on commit bb1708c

Please sign in to comment.