Skip to content

Commit

Permalink
Remove deprecated parameters from Python Client (#8444)
Browse files Browse the repository at this point in the history
* deprecate

* add changeset

* file -> handle_file

* more updates

* format

* add changeset

* fix connect

* fix tests

* fix more tests

* remove outdated test

* serialize

* address review comments

* fix dir

---------

Co-authored-by: gradio-pr-bot <[email protected]>
  • Loading branch information
abidlabs and gradio-pr-bot authored Jun 5, 2024
1 parent 3dbce6b commit 2cd02ff
Show file tree
Hide file tree
Showing 22 changed files with 87 additions and 114 deletions.
7 changes: 7 additions & 0 deletions .changeset/neat-trains-repair.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/app": minor
"gradio": minor
"gradio_client": minor
---

feat:Remove deprecated parameters from Python Client
3 changes: 2 additions & 1 deletion client/python/gradio_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from gradio_client.client import Client
from gradio_client.utils import __version__, file
from gradio_client.utils import __version__, file, handle_file

__all__ = [
"Client",
"file",
"handle_file",
"__version__",
]
70 changes: 21 additions & 49 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,39 +77,26 @@ def __init__(
src: str,
hf_token: str | None = None,
max_workers: int = 40,
serialize: bool | None = None, # TODO: remove in 1.0
output_dir: str
| Path = DEFAULT_TEMP_DIR, # Maybe this can be combined with `download_files` in 1.0
verbose: bool = True,
auth: tuple[str, str] | None = None,
*,
headers: dict[str, str] | None = None,
upload_files: bool = True, # TODO: remove and hardcode to False in 1.0
download_files: bool = True, # TODO: consider setting to False in 1.0
_skip_components: bool = True, # internal parameter to skip values certain components (e.g. State) that do not need to be displayed to users.
download_files: str | Path | Literal[False] = DEFAULT_TEMP_DIR,
ssl_verify: bool = True,
_skip_components: bool = True, # internal parameter to skip values certain components (e.g. State) that do not need to be displayed to users.
):
"""
Parameters:
src: Either the name of the Hugging Face Space to load, (e.g. "abidlabs/whisper-large-v2") or the full URL (including "http" or "https") of the hosted Gradio app to load (e.g. "http://mydomain.com/app" or "https://bec81a83-5b5c-471e.gradio.live/").
hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI. Obtain from: https://huggingface.co/settings/token
max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
serialize: Deprecated. Please use the equivalent `upload_files` parameter instead.
output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine.
verbose: Whether the client should print statements to the console.
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys.
upload_files: Whether the client should treat input string filepath as files and upload them to the remote server. If False, the client will treat input string filepaths as strings always and not modify them, and files should be passed in explicitly using `gradio_client.file("path/to/file/or/url")` instead. This parameter will be deleted and False will become the default in a future version.
download_files: Whether the client should download output files from the remote API and return them as string filepaths on the local machine. If False, the client will return a FileData dataclass object with the filepath on the remote machine instead.
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. This parameter will override the default headers if they have the same keys.
download_files: Directory where the client should download output files on the local machine from the remote API. By default, uses the value of the GRADIO_TEMP_DIR environment variable which, if not set by the user, is a temporary directory on your machine. If False, the client does not download files and returns a FileData dataclass object with the filepath on the remote machine instead.
ssl_verify: If False, skips certificate validation which allows the client to connect to Gradio apps that are using self-signed certificates.
"""
self.verbose = verbose
self.hf_token = hf_token
if serialize is not None:
warnings.warn(
"The `serialize` parameter is deprecated and will be removed. Please use the equivalent `upload_files` parameter instead."
)
upload_files = serialize
self.upload_files = upload_files
self.download_files = download_files
self._skip_components = _skip_components
self.headers = build_hf_headers(
Expand All @@ -122,9 +109,14 @@ def __init__(
self.ssl_verify = ssl_verify
self.space_id = None
self.cookies: dict[str, str] = {}
self.output_dir = (
str(output_dir) if isinstance(output_dir, Path) else output_dir
)
if isinstance(self.download_files, (str, Path)):
if not os.path.exists(self.download_files):
os.makedirs(self.download_files, exist_ok=True)
if not os.path.isdir(self.download_files):
raise ValueError(f"Path: {self.download_files} is not a directory.")
self.output_dir = str(self.download_files)
else:
self.output_dir = DEFAULT_TEMP_DIR

if src.startswith("http://") or src.startswith("https://"):
_src = src if src.endswith("/") else src + "/"
Expand Down Expand Up @@ -554,10 +546,7 @@ def fn(future):
return job

def _get_api_info(self):
if self.upload_files:
api_info_url = urllib.parse.urljoin(self.src, utils.API_INFO_URL)
else:
api_info_url = urllib.parse.urljoin(self.src, utils.RAW_API_INFO_URL)
api_info_url = urllib.parse.urljoin(self.src, utils.RAW_API_INFO_URL)
if self.app_version > version.Version("3.36.1"):
r = httpx.get(
api_info_url,
Expand All @@ -574,7 +563,7 @@ def _get_api_info(self):
utils.SPACE_FETCHER_URL,
json={
"config": json.dumps(self.config),
"serialize": self.upload_files,
"serialize": False,
},
)
if fetch.is_success:
Expand Down Expand Up @@ -737,7 +726,7 @@ def _render_endpoints_info(
default_value = info.get("parameter_default")
default_value = utils.traverse(
default_value,
lambda x: f"file(\"{x['url']}\")",
lambda x: f"handle_file(\"{x['url']}\")",
utils.is_file_obj_with_meta,
)
default_info = (
Expand Down Expand Up @@ -1273,20 +1262,11 @@ def insert_empty_state(self, *data) -> tuple:
def process_input_files(self, *data) -> tuple:
data_ = []
for i, d in enumerate(data):
if self.client.upload_files and self.input_component_types[i].value_is_file:
d = utils.traverse(
d,
partial(self._upload_file, data_index=i),
lambda f: utils.is_filepath(f)
or utils.is_file_obj_with_meta(f)
or utils.is_http_url_like(f),
)
elif not self.client.upload_files:
d = utils.traverse(
d,
partial(self._upload_file, data_index=i),
utils.is_file_obj_with_meta,
)
d = utils.traverse(
d,
partial(self._upload_file, data_index=i),
utils.is_file_obj_with_meta,
)
data_.append(d)
return tuple(data_)

Expand Down Expand Up @@ -1329,15 +1309,7 @@ def reduce_singleton_output(self, *data) -> Any:
return data

def _upload_file(self, f: str | dict, data_index: int) -> dict[str, str]:
if isinstance(f, str):
warnings.warn(
f'The Client is treating: "{f}" as a file path. In future versions, this behavior will not happen automatically. '
f'\n\nInstead, please provide file path or URLs like this: gradio_client.file("{f}"). '
"\n\nNote: to stop treating strings as filepaths unless file() is used, set upload_files=False in Client()."
)
file_path = f
else:
file_path = f["path"]
file_path = f["path"]
orig_name = Path(file_path)
if not utils.is_http_url_like(file_path):
component_id = self.dependency["inputs"][data_index]
Expand Down
3 changes: 1 addition & 2 deletions client/python/gradio_client/compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ def _inner(*data):
if not self.is_valid:
raise utils.InvalidAPIEndpointError()
data = self.insert_state(*data)
if self.client.upload_files:
data = self.serialize(*data)
data = self.serialize(*data)
predictions = _predict(*data)
predictions = self.process_predictions(*predictions)
# Append final output only if not already present
Expand Down
9 changes: 8 additions & 1 deletion client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,7 @@ def is_file_obj_with_url(d) -> bool:
}


def file(filepath_or_url: str | Path):
def handle_file(filepath_or_url: str | Path):
s = str(filepath_or_url)
data = {"path": s, "meta": {"_type": "gradio.FileData"}}
if is_http_url_like(s):
Expand All @@ -1093,6 +1093,13 @@ def file(filepath_or_url: str | Path):
)


def file(filepath_or_url: str | Path):
warnings.warn(
"file() is deprecated and will be removed in a future version. Use handle_file() instead."
)
return handle_file(filepath_or_url)


def construct_args(
parameters_info: list[ParameterInfo] | None, args: tuple, kwargs: dict
) -> list:
Expand Down
36 changes: 12 additions & 24 deletions client/python/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from huggingface_hub import HfFolder
from huggingface_hub.utils import RepositoryNotFoundError

from gradio_client import Client, file
from gradio_client import Client, handle_file
from gradio_client.client import DEFAULT_TEMP_DIR
from gradio_client.exceptions import AppError, AuthenticationError
from gradio_client.utils import (
Expand All @@ -37,13 +37,12 @@
@contextmanager
def connect(
demo: gr.Blocks,
serialize: bool = True,
output_dir: str = DEFAULT_TEMP_DIR,
download_files: str = DEFAULT_TEMP_DIR,
**kwargs,
):
_, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs)
try:
yield Client(local_url, serialize=serialize, output_dir=output_dir)
yield Client(local_url, download_files=download_files)
finally:
# A more verbose version of .close()
# because we should set a timeout
Expand Down Expand Up @@ -92,11 +91,11 @@ def test_raise_error_max_file_size(self, max_file_size_demo):
with connect(max_file_size_demo, max_file_size="15kb") as client:
with pytest.raises(ValueError, match="exceeds the maximum file size"):
client.predict(
file(Path(__file__).parent / "files" / "cheetah1.jpg"),
handle_file(Path(__file__).parent / "files" / "cheetah1.jpg"),
api_name="/upload_1b",
)
client.predict(
file(Path(__file__).parent / "files" / "alphabet.txt"),
handle_file(Path(__file__).parent / "files" / "alphabet.txt"),
api_name="/upload_1b",
)

Expand Down Expand Up @@ -254,17 +253,11 @@ def test_raises_exception(self, calculator_demo):
job = client.submit("foo", "add", 9, fn_index=0)
job.result()

def test_raises_exception_no_queue(self, sentiment_classification_demo):
with pytest.raises(Exception):
with connect(sentiment_classification_demo) as client:
job = client.submit([5], api_name="/sleep")
job.result()

def test_job_output_video(self, video_component):
with connect(video_component) as client:
job = client.submit(
{
"video": file(
"video": handle_file(
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
)
},
Expand All @@ -277,10 +270,10 @@ def test_job_output_video(self, video_component):
)

temp_dir = tempfile.mkdtemp()
with connect(video_component, output_dir=temp_dir) as client:
with connect(video_component, download_files=temp_dir) as client:
job = client.submit(
{
"video": file(
"video": handle_file(
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
)
},
Expand Down Expand Up @@ -430,13 +423,15 @@ def test_cancel_subsequent_jobs_state_reset(self, yield_demo):
def test_stream_audio(self, stream_audio):
with connect(stream_audio) as client:
job1 = client.submit(
file("https://gradio-builds.s3.amazonaws.com/demo-files/bark_demo.mp4"),
handle_file(
"https://gradio-builds.s3.amazonaws.com/demo-files/bark_demo.mp4"
),
api_name="/predict",
)
assert Path(job1.result()).exists()

job2 = client.submit(
file(
handle_file(
"https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav"
),
api_name="/predict",
Expand Down Expand Up @@ -552,13 +547,6 @@ def test_upload_file_upload_route_does_not_exist(self):
client.submit(1, "foo", f.name, fn_index=0).result()
serialize.assert_called_once_with(1, "foo", f.name)

def test_state_without_serialize(self, stateful_chatbot):
with connect(stateful_chatbot, serialize=False) as client:
initial_history = [["", None]]
message = "Hello"
ret = client.predict(message, initial_history, api_name="/submit")
assert ret == ("", [["", None], ["Hello", "I love you"]])

def test_does_not_upload_dir(self, stateful_chatbot):
with connect(stateful_chatbot) as client:
initial_history = [["", None]]
Expand Down
4 changes: 2 additions & 2 deletions gradio/_simple_templates/simpleimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pathlib import Path
from typing import Any

from gradio_client import file
from gradio_client import handle_file
from gradio_client.documentation import document

from gradio.components.base import Component
Expand Down Expand Up @@ -102,7 +102,7 @@ def postprocess(self, value: str | Path | None) -> FileData | None:
return FileData(path=str(value), orig_name=Path(value).name)

def example_payload(self) -> Any:
return file(
return handle_file(
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
)

Expand Down
4 changes: 2 additions & 2 deletions gradio/components/annotated_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import gradio_client.utils as client_utils
import numpy as np
import PIL.Image
from gradio_client import file
from gradio_client import handle_file
from gradio_client.documentation import document

from gradio import processing_utils, utils
Expand Down Expand Up @@ -217,7 +217,7 @@ def hex_to_rgb(value):

def example_payload(self) -> Any:
return {
"image": file(
"image": handle_file(
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
),
"annotations": [],
Expand Down
4 changes: 2 additions & 2 deletions gradio/components/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import httpx
import numpy as np
from gradio_client import file
from gradio_client import handle_file
from gradio_client import utils as client_utils
from gradio_client.documentation import document

Expand Down Expand Up @@ -186,7 +186,7 @@ def __init__(
)

def example_payload(self) -> Any:
return file(
return handle_file(
"https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav"
)

Expand Down
4 changes: 2 additions & 2 deletions gradio/components/download_button.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path
from typing import Callable, Literal

from gradio_client import file
from gradio_client import handle_file
from gradio_client.documentation import document

from gradio.components.base import Component
Expand Down Expand Up @@ -104,7 +104,7 @@ def postprocess(self, value: str | Path | None) -> FileData | None:
return FileData(path=str(value))

def example_payload(self) -> dict:
return file(
return handle_file(
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
)

Expand Down
6 changes: 3 additions & 3 deletions gradio/components/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any, Callable, Literal

import gradio_client.utils as client_utils
from gradio_client import file
from gradio_client import handle_file
from gradio_client.documentation import document

from gradio import processing_utils
Expand Down Expand Up @@ -203,12 +203,12 @@ def process_example(self, input_data: str | list | None) -> str:

def example_payload(self) -> Any:
if self.file_count == "single":
return file(
return handle_file(
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
)
else:
return [
file(
handle_file(
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
)
]
Expand Down
Loading

0 comments on commit 2cd02ff

Please sign in to comment.