Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: 为子依赖添加 PEP593 Annotated 支持 #1832

Merged
merged 7 commits into from
Mar 22, 2023
40 changes: 27 additions & 13 deletions nonebot/internal/params.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import inspect
from typing_extensions import Annotated
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast

from pydantic.typing import get_args, get_origin
from pydantic.fields import Required, Undefined, ModelField

from nonebot.dependencies.utils import check_field_type
Expand Down Expand Up @@ -78,21 +80,33 @@ def __repr__(self) -> str:
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["DependParam"]:
if isinstance(param.default, DependsInner):
dependency: T_Handler
if param.default.dependency is None:
assert param.annotation is not param.empty, "Dependency cannot be empty"
dependency = param.annotation
else:
dependency = param.default.dependency
sub_dependent = Dependent[Any].parse(
call=dependency,
allow_types=allow_types,
)
return cls(
Required, use_cache=param.default.use_cache, dependent=sub_dependent
type_annotation, depends_inner = param.annotation, None
if get_origin(param.annotation) is Annotated:
type_annotation, *extra_args = get_args(param.annotation)
depends_inner = next(
(x for x in extra_args if isinstance(x, DependsInner)), None
)

depends_inner = (
param.default if isinstance(param.default, DependsInner) else depends_inner
)
if depends_inner is None:
return

dependency: T_Handler
if depends_inner.dependency is None:
assert (
type_annotation is not inspect.Signature.empty
), "Dependency cannot be empty"
dependency = type_annotation
else:
dependency = depends_inner.dependency
sub_dependent = Dependent[Any].parse(
call=dependency,
allow_types=allow_types,
)
return cls(Required, use_cache=depends_inner.use_cache, dependent=sub_dependent)

@classmethod
def _check_parameterless(
cls, value: Any, allow_types: Tuple[Type[Param], ...]
Expand Down
15 changes: 15 additions & 0 deletions tests/plugins/param/param_depend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing_extensions import Annotated

from nonebot import on_message
from nonebot.params import Depends
Expand Down Expand Up @@ -47,3 +48,17 @@ async def depends_cache(y: int = Depends(dependency, use_cache=True)):

async def class_depend(c: ClassDependency = Depends()):
return c


async def annotated_depend(x: Annotated[int, Depends(dependency)]):
return x


async def annotated_class_depend(c: Annotated[ClassDependency, Depends()]):
return c


async def annotated_prior_depend(
x: Annotated[int, Depends(lambda: 2)] = Depends(dependency)
):
return x
17 changes: 17 additions & 0 deletions tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ async def test_depend(app: App):
depends,
class_depend,
test_depends,
annotated_depend,
annotated_class_depend,
annotated_prior_depend,
)

async with app.test_dependent(depends, allow_types=[DependParam]) as ctx:
Expand All @@ -63,6 +66,20 @@ async def test_depend(app: App):
async with app.test_dependent(class_depend, allow_types=[DependParam]) as ctx:
ctx.should_return(ClassDependency(x=1, y=2))

async with app.test_dependent(annotated_depend, allow_types=[DependParam]) as ctx:
ctx.should_return(1)

async with app.test_dependent(
annotated_prior_depend, allow_types=[DependParam]
) as ctx:
ctx.should_return(1)
assert runned == [1, 1]

async with app.test_dependent(
annotated_class_depend, allow_types=[DependParam]
) as ctx:
ctx.should_return(ClassDependency(x=1, y=2))


@pytest.mark.asyncio
async def test_bot(app: App):
Expand Down