Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: 支持子依赖定义 Pydantic 类型校验 #2310

Merged
merged 2 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions nonebot/dependencies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class Param(abc.ABC, FieldInfo):
继承自 `pydantic.fields.FieldInfo`,用于描述参数信息(不包括参数名)。
"""

def __init__(self, *args, validate: bool = False, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.validate = validate

@classmethod
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type["Param"], ...]
Expand Down Expand Up @@ -206,10 +210,12 @@ async def check(self, **params: Any) -> None:
raise

async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any:
value = await cast(Param, field.field_info)._solve(**params)
param = cast(Param, field.field_info)
value = await param._solve(**params)
if value is Undefined:
value = field.get_default()
return check_field_type(field, value)
v = check_field_type(field, value)
return v if param.validate else value

async def solve(self, **params: Any) -> Dict[str, Any]:
# solve parameterless
Expand Down
10 changes: 4 additions & 6 deletions nonebot/dependencies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@
"""

import inspect
from typing import Any, Dict, TypeVar, Callable, ForwardRef
from typing import Any, Dict, Callable, ForwardRef

from loguru import logger
from pydantic.fields import ModelField
from pydantic.typing import evaluate_forwardref

from nonebot.exception import TypeMisMatch

V = TypeVar("V")


def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
"""获取可调用对象签名"""
Expand Down Expand Up @@ -49,10 +47,10 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) ->
return annotation


def check_field_type(field: ModelField, value: V) -> V:
def check_field_type(field: ModelField, value: Any) -> Any:
"""检查字段类型是否匹配"""

_, errs_ = field.validate(value, {}, loc=())
v, errs_ = field.validate(value, {}, loc=())
if errs_:
raise TypeMisMatch(field, value)
return value
return v
106 changes: 91 additions & 15 deletions nonebot/internal/params.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import asyncio
import inspect
from typing_extensions import Annotated
from typing_extensions import Self, Annotated, override
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast
from typing import (
TYPE_CHECKING,
Any,
Type,
Tuple,
Union,
Literal,
Callable,
Optional,
cast,
)

from pydantic.typing import get_args, get_origin
from pydantic.fields import Required, Undefined, ModelField
from pydantic.fields import Required, FieldInfo, Undefined, ModelField

from nonebot.dependencies.utils import check_field_type
from nonebot.dependencies import Param, Dependent, CustomConfig
Expand All @@ -24,33 +34,55 @@
from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Event

EXTRA_FIELD_INFO = (
"gt",
"lt",
"ge",
"le",
"multiple_of",
"allow_inf_nan",
"max_digits",
"decimal_places",
"min_items",
"max_items",
"unique_items",
"min_length",
"max_length",
"regex",
)


class DependsInner:
def __init__(
self,
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True,
validate: Union[bool, FieldInfo] = False,
) -> None:
self.dependency = dependency
self.use_cache = use_cache
self.validate = validate

def __repr__(self) -> str:
dep = get_name(self.dependency)
cache = "" if self.use_cache else ", use_cache=False"
return f"DependsInner({dep}{cache})"
validate = f", validate={self.validate}" if self.validate else ""
return f"DependsInner({dep}{cache}{validate})"


def Depends(
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True,
validate: Union[bool, FieldInfo] = False,
) -> Any:
"""子依赖装饰器

参数:
dependency: 依赖函数。默认为参数的类型注释。
use_cache: 是否使用缓存。默认为 `True`。
validate: 是否使用 Pydantic 类型校验。默认为 `False`。

用法:
```python
Expand All @@ -70,7 +102,7 @@ async def handler(
...
```
"""
return DependsInner(dependency, use_cache=use_cache)
return DependsInner(dependency, use_cache=use_cache, validate=validate)


class DependParam(Param):
Expand All @@ -85,37 +117,63 @@ def __repr__(self) -> str:
return f"Depends({self.extra['dependent']})"

@classmethod
def _from_field(
cls, sub_dependent: Dependent, use_cache: bool, validate: Union[bool, FieldInfo]
) -> Self:
kwargs = {}
if isinstance(validate, FieldInfo):
kwargs.update((k, getattr(validate, k)) for k in EXTRA_FIELD_INFO)

return cls(
Required,
validate=bool(validate),
**kwargs,
dependent=sub_dependent,
use_cache=use_cache,
)

@classmethod
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["DependParam"]:
) -> Optional[Self]:
type_annotation, depends_inner = param.annotation, None
# extract type annotation and dependency from Annotated
if get_origin(param.annotation) is Annotated:
type_annotation, *extra_args = get_args(param.annotation)
depends_inner = next(
(x for x in extra_args if isinstance(x, DependsInner)), None
)

# param default value takes higher priority
depends_inner = (
param.default if isinstance(param.default, DependsInner) else depends_inner
)
# not a dependent
if depends_inner is None:
return

dependency: T_Handler
# sub dependency is not specified, use type annotation
if depends_inner.dependency is None:
assert (
type_annotation is not inspect.Signature.empty
), "Dependency cannot be empty"
dependency = type_annotation
else:
dependency = depends_inner.dependency
# parse sub dependency
sub_dependent = Dependent[Any].parse(
call=dependency,
allow_types=allow_types,
)
return cls(Required, use_cache=depends_inner.use_cache, dependent=sub_dependent)

return cls._from_field(
sub_dependent, depends_inner.use_cache, depends_inner.validate
)

@classmethod
@override
def _check_parameterless(
cls, value: Any, allow_types: Tuple[Type[Param], ...]
) -> Optional["Param"]:
Expand All @@ -124,8 +182,9 @@ def _check_parameterless(
dependent = Dependent[Any].parse(
call=value.dependency, allow_types=allow_types
)
return cls(Required, use_cache=value.use_cache, dependent=dependent)
return cls._from_field(dependent, value.use_cache, value.validate)

@override
async def _solve(
self,
stack: Optional[AsyncExitStack] = None,
Expand Down Expand Up @@ -169,6 +228,7 @@ async def _solve(
dependency_cache[call] = task
return await task

@override
async def _check(self, **kwargs: Any) -> None:
# run sub dependent pre-checkers
sub_dependent: Dependent = self.extra["dependent"]
Expand All @@ -195,9 +255,10 @@ def __repr__(self) -> str:
)

@classmethod
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["BotParam"]:
) -> Optional[Self]:
from nonebot.adapters import Bot

# param type is Bot(s) or subclass(es) of Bot or None
Expand All @@ -217,9 +278,11 @@ def _check_param(
elif param.annotation == param.empty and param.name == "bot":
return cls(Required)

@override
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
return bot

@override
async def _check(self, bot: "Bot", **kwargs: Any) -> None:
if checker := self.extra.get("checker"):
check_field_type(checker, bot)
Expand All @@ -245,9 +308,10 @@ def __repr__(self) -> str:
)

@classmethod
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["EventParam"]:
) -> Optional[Self]:
from nonebot.adapters import Event

# param type is Event(s) or subclass(es) of Event or None
Expand All @@ -267,9 +331,11 @@ def _check_param(
elif param.annotation == param.empty and param.name == "event":
return cls(Required)

@override
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
return event

@override
async def _check(self, event: "Event", **kwargs: Any) -> Any:
if checker := self.extra.get("checker", None):
check_field_type(checker, event)
Expand All @@ -287,16 +353,18 @@ def __repr__(self) -> str:
return "StateParam()"

@classmethod
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["StateParam"]:
) -> Optional[Self]:
# param type is T_State
if param.annotation is T_State:
return cls(Required)
# legacy: param is named "state" and has no type annotation
elif param.annotation == param.empty and param.name == "state":
return cls(Required)

@override
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
return state

Expand All @@ -313,9 +381,10 @@ def __repr__(self) -> str:
return "MatcherParam()"

@classmethod
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["MatcherParam"]:
) -> Optional[Self]:
from nonebot.matcher import Matcher

# param type is Matcher(s) or subclass(es) of Matcher or None
Expand All @@ -335,9 +404,11 @@ def _check_param(
elif param.annotation == param.empty and param.name == "matcher":
return cls(Required)

@override
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
return matcher

@override
async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any:
if checker := self.extra.get("checker", None):
check_field_type(checker, matcher)
Expand Down Expand Up @@ -382,9 +453,10 @@ def __repr__(self) -> str:
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"

@classmethod
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["ArgParam"]:
) -> Optional[Self]:
if isinstance(param.default, ArgInner):
return cls(
Required, key=param.default.key or param.name, type=param.default.type
Expand Down Expand Up @@ -419,16 +491,18 @@ def __repr__(self) -> str:
return "ExceptionParam()"

@classmethod
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["ExceptionParam"]:
) -> Optional[Self]:
# param type is Exception(s) or subclass(es) of Exception or None
if generic_check_issubclass(param.annotation, Exception):
return cls(Required)
# legacy: param is named "exception" and has no type annotation
elif param.annotation == param.empty and param.name == "exception":
return cls(Required)

@override
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
return exception

Expand All @@ -445,12 +519,14 @@ def __repr__(self) -> str:
return f"DefaultParam(default={self.default!r})"

@classmethod
@override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["DefaultParam"]:
) -> Optional[Self]:
if param.default != param.empty:
return cls(param.default)

@override
async def _solve(self, **kwargs: Any) -> Any:
return Undefined

Expand Down
Loading
Loading