Skip to content

Commit

Permalink
Handle special arguments when extracting parameter names for view API…
Browse files Browse the repository at this point in the history
… page (#8400)

* fix special args

* add changeset

* format

* ignore typecheck

---------

Co-authored-by: gradio-pr-bot <[email protected]>
  • Loading branch information
abidlabs and gradio-pr-bot authored Jun 4, 2024
1 parent d393a4a commit 33c8081
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 39 deletions.
5 changes: 5 additions & 0 deletions .changeset/fair-items-sort.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Handle special arguments when extracting parameter names for view API page
70 changes: 31 additions & 39 deletions gradio/chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import functools
import inspect
from typing import AsyncGenerator, Callable, Literal, Union, cast

Expand Down Expand Up @@ -448,7 +449,36 @@ def _setup_stop_events(
)

def _setup_api(self) -> None:
api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
if self.is_generator:

@functools.wraps(self.fn)
async def api_fn(message, history, *args, **kwargs): # type: ignore
if self.is_async:
generator = self.fn(message, history, *args, **kwargs)
else:
generator = await anyio.to_thread.run_sync(
self.fn, message, history, *args, **kwargs, limiter=self.limiter
)
generator = SyncToAsyncIterator(generator, self.limiter)
try:
first_response = await async_iteration(generator)
yield first_response, history + [[message, first_response]]
except StopIteration:
yield None, history + [[message, None]]
async for response in generator:
yield response, history + [[message, response]]
else:

@functools.wraps(self.fn)
async def api_fn(message, history, *args, **kwargs):
if self.is_async:
response = await self.fn(message, history, *args, **kwargs)
else:
response = await anyio.to_thread.run_sync(
self.fn, message, history, *args, **kwargs, limiter=self.limiter
)
history.append([message, response])
return response, history

self.fake_api_btn.click(
api_fn,
Expand Down Expand Up @@ -575,44 +605,6 @@ async def _stream_fn(
update = history + [[message, response]]
yield update, update

async def _api_submit_fn(
self, message: str, history: list[list[str | None]], request: Request, *args
) -> tuple[str, list[list[str | None]]]:
inputs, _, _ = special_args(
self.fn, inputs=[message, history, *args], request=request
)

if self.is_async:
response = await self.fn(*inputs)
else:
response = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
history.append([message, response])
return response, history

async def _api_stream_fn(
self, message: str, history: list[list[str | None]], request: Request, *args
) -> AsyncGenerator:
inputs, _, _ = special_args(
self.fn, inputs=[message, history, *args], request=request
)

if self.is_async:
generator = self.fn(*inputs)
else:
generator = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
generator = SyncToAsyncIterator(generator, self.limiter)
try:
first_response = await async_iteration(generator)
yield first_response, history + [[message, first_response]]
except StopIteration:
yield None, history + [[message, None]]
async for response in generator:
yield response, history + [[message, response]]

async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)

Expand Down
7 changes: 7 additions & 0 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,14 +1302,21 @@ def get_upload_folder() -> str:


def get_function_params(func: Callable) -> list[tuple[str, bool, Any]]:
"""
Gets the parameters of a function as a list of tuples of the form (name, has_default, default_value).
Excludes *args and **kwargs, as well as args that are Gradio-specific, such as gr.Request, gr.EventData, gr.OAuthProfile, and gr.OAuthToken.
"""
params_info = []
signature = inspect.signature(func)
type_hints = get_type_hints(func)
for name, parameter in signature.parameters.items():
if parameter.kind in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
break
if is_special_typed_parameter(name, type_hints):
continue
if parameter.default is inspect.Parameter.empty:
params_info.append((name, False, None))
else:
Expand Down
6 changes: 6 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,12 @@ def func(a, **kwargs):

assert get_function_params(func) == [("a", False, None)]

def test_function_with_special_args(self):
def func(a, r: Request, b=10):
pass

assert get_function_params(func) == [("a", False, None), ("b", True, 10)]

def test_class_method_skip_first_param(self):
class MyClass:
def method(self, arg1, arg2=42):
Expand Down

0 comments on commit 33c8081

Please sign in to comment.