Skip to content

Commit

Permalink
Change client submit API to be an AsyncIterable and support more plat…
Browse files Browse the repository at this point in the history
…forms (#8451)

* fix param name

* format

* save

* changes

* changes

* fix param name

* format

* switch to async iterable interface

* switch to async iterable interface

* changes

* add changeset

* fix

* fix param name

* format

* fixes

* fix checks

* fix checks

* add changeset

* fix checks

* add changeset

* add changeset

* fix checks

* fix param name

* format

* fix types

* cleanup comments

* add changeset

* fix param name

* format

* changes

* Refactor Cancelling Logic To Use /cancel (#8370)

* Cancel refactor

* add changeset

* add changeset

* types

* Add code

* Fix types

---------

Co-authored-by: gradio-pr-bot <[email protected]>

* fix param name

* format

* changes

* fix

* fix param name

* format

* fix test

* fix notebooks

* fix type

---------

Co-authored-by: gradio-pr-bot <[email protected]>
Co-authored-by: Freddy Boulton <[email protected]>
  • Loading branch information
3 people authored Jun 6, 2024
1 parent 6447dfa commit 9d2d605
Show file tree
Hide file tree
Showing 27 changed files with 589 additions and 331 deletions.
9 changes: 9 additions & 0 deletions .changeset/quick-melons-remain.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
"@gradio/app": patch
"@gradio/client": patch
"@gradio/file": patch
"@gradio/spaces-test": patch
"gradio": patch
---

fix:Change client submit API to be an AsyncIterable and support more platforms
39 changes: 39 additions & 0 deletions client/js/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
<!doctype html>

<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Client</title>
<script type="module">
import { Client } from "./dist/index.js";
console.log(Client);

const client = await Client.connect("pngwn/chatinterface_streaming_echo");
async function run(message, n) {
// console.log(client);
const req = client.submit("/chat", {
message
});
console.log("start");
for await (const c of req) {
if (c.type === "data") {
console.log(`${n}: ${c.data[0]}`);
}
}

console.log("end");

return "hi";
}

run("My name is frank", 1);
run("Hello there", 2);

console.log("boo");
</script>
</head>
<body>
<div id="app"></div>
</body>
</html>
5 changes: 4 additions & 1 deletion client/js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
"@types/eventsource": "^1.1.15",
"bufferutil": "^4.0.7",
"eventsource": "^2.0.2",
"fetch-event-stream": "^0.1.5",
"msw": "^2.2.1",
"semiver": "^1.1.0",
"textlinestream": "^1.1.1",
"typescript": "^5.0.0",
"ws": "^8.13.0"
},
Expand All @@ -31,7 +33,8 @@
"build": "pnpm bundle && pnpm generate_types",
"test": "pnpm test:client && pnpm test:client:node",
"test:client": "vitest run -c vite.config.js",
"test:client:node": "TEST_MODE=node vitest run -c vite.config.js"
"test:client:node": "TEST_MODE=node vitest run -c vite.config.js",
"preview:browser": "vite dev --mode=preview"
},
"engines": {
"node": ">=18.0.0"
Expand Down
61 changes: 36 additions & 25 deletions client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import type {
PredictReturn,
SpaceStatus,
Status,
SubmitReturn,
UploadResponse,
client_return
client_return,
SubmitIterable,
GradioEvent
} from "./types";
import { view_api } from "./utils/view_api";
import { upload_files } from "./utils/upload_files";
Expand All @@ -30,7 +31,7 @@ import {
parse_and_set_cookies
} from "./helpers/init_helpers";
import { check_space_status } from "./helpers/spaces";
import { open_stream } from "./utils/stream";
import { open_stream, readable_stream } from "./utils/stream";
import { API_INFO_ERROR_MSG, CONFIG_ERROR_MSG } from "./constants";

export class Client {
Expand All @@ -53,6 +54,8 @@ export class Client {
event_callbacks: Record<string, (data?: unknown) => Promise<void>> = {};
unclosed_events: Set<string> = new Set();
heartbeat_event: EventSource | null = null;
abort_controller: AbortController | null = null;
stream_instance: EventSource | null = null;

fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
const headers = new Headers(init?.headers || {});
Expand All @@ -63,18 +66,14 @@ export class Client {
return fetch(input, { ...init, headers });
}

async stream(url: URL): Promise<EventSource> {
if (typeof window === "undefined" || typeof EventSource === "undefined") {
try {
const EventSourceModule = await import("eventsource");
return new EventSourceModule.default(url.toString()) as EventSource;
} catch (error) {
console.error("Failed to load EventSource module:", error);
throw error;
}
} else {
return new EventSource(url.toString());
}
stream(url: URL): EventSource {
this.abort_controller = new AbortController();

this.stream_instance = readable_stream(url.toString(), {
signal: this.abort_controller.signal
});

return this.stream_instance;
}

view_api: () => Promise<ApiInfo<JsApiData>>;
Expand Down Expand Up @@ -104,7 +103,7 @@ export class Client {
data: unknown[] | Record<string, unknown>,
event_data?: unknown,
trigger_id?: number | null
) => SubmitReturn;
) => SubmitIterable<GradioEvent>;
predict: (
endpoint: string | number,
data: unknown[] | Record<string, unknown>,
Expand All @@ -113,8 +112,15 @@ export class Client {
open_stream: () => Promise<void>;
private resolve_config: (endpoint: string) => Promise<Config | undefined>;
private resolve_cookies: () => Promise<void>;
constructor(app_reference: string, options: ClientOptions = {}) {
constructor(
app_reference: string,
options: ClientOptions = { events: ["data"] }
) {
this.app_reference = app_reference;
if (!options.events) {
options.events = ["data"];
}

this.options = options;

this.view_api = view_api.bind(this);
Expand Down Expand Up @@ -184,16 +190,17 @@ export class Client {
}

// Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
if (!this.heartbeat_event)
this.heartbeat_event = await this.stream(heartbeat_url);
} else {
this.heartbeat_event?.close();
if (!this.heartbeat_event) {
this.heartbeat_event = this.stream(heartbeat_url);
}
}
}

static async connect(
app_reference: string,
options: ClientOptions = {}
options: ClientOptions = {
events: ["data"]
}
): Promise<Client> {
const client = new this(app_reference, options); // this refers to the class itself, not the instance
await client.init();
Expand All @@ -206,7 +213,9 @@ export class Client {

static async duplicate(
app_reference: string,
options: DuplicateOptions = {}
options: DuplicateOptions = {
events: ["data"]
}
): Promise<Client> {
return duplicate(app_reference, options);
}
Expand Down Expand Up @@ -253,7 +262,7 @@ export class Client {
): Promise<Config | client_return> {
this.config = _config;

if (typeof window !== "undefined") {
if (typeof window !== "undefined" && typeof document !== "undefined") {
if (window.location.protocol === "https:") {
this.config.root = this.config.root.replace("http://", "https://");
}
Expand Down Expand Up @@ -405,7 +414,9 @@ export class Client {
*/
export async function client(
app_reference: string,
options: ClientOptions = {}
options: ClientOptions = {
events: ["data"]
}
): Promise<Client> {
return await Client.connect(app_reference, options);
}
Expand Down
3 changes: 2 additions & 1 deletion client/js/src/helpers/spaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,10 @@ export async function discussions_enabled(space_id: string): Promise<boolean> {
method: "HEAD"
}
);

const error = r.headers.get("x-error-message");

if (error && RE_DISABLED_DISCUSSION.test(error)) return false;
if (!r.ok || (error && RE_DISABLED_DISCUSSION.test(error))) return false;
return true;
} catch (e) {
return false;
Expand Down
6 changes: 5 additions & 1 deletion client/js/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ export { handle_file } from "./helpers/data";

export type {
SpaceStatus,
StatusMessage,
Status,
client_return,
UploadResponse
UploadResponse,
RenderMessage,
LogMessage,
Payload
} from "./types";

// todo: remove in @gradio/client v1.0
Expand Down
10 changes: 9 additions & 1 deletion client/js/src/test/handlers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import {

const root_url = "https://huggingface.co";

const direct_space_url = "https://hmb-hello-world.hf.space";
export const direct_space_url = "https://hmb-hello-world.hf.space";
const private_space_url = "https://hmb-secret-world.hf.space";
const private_auth_space_url = "https://hmb-private-auth-space.hf.space";

Expand Down Expand Up @@ -431,6 +431,14 @@ export const handlers: RequestHandler[] = [
});
}),
// queue requests
http.get(`${direct_space_url}/queue/data`, () => {
return new HttpResponse(JSON.stringify({ event_id: "123" }), {
status: 200,
headers: {
"Content-Type": "application/json"
}
});
}),
http.post(`${direct_space_url}/queue/join`, () => {
return new HttpResponse(JSON.stringify({ event_id: "123" }), {
status: 200,
Expand Down
22 changes: 12 additions & 10 deletions client/js/src/test/stream.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { vi, type Mock } from "vitest";
import { Client } from "../client";
import { readable_stream } from "../utils/stream";
import { initialise_server } from "./server";
import { direct_space_url } from "./handlers.ts";

import {
describe,
Expand All @@ -11,27 +13,23 @@ import {
afterAll,
beforeEach
} from "vitest";
import "./mock_eventsource.ts";
import NodeEventSource from "eventsource";

const server = initialise_server();
const IS_NODE = process.env.TEST_MODE === "node";

beforeAll(() => server.listen());
afterEach(() => server.resetHandlers());
afterAll(() => server.close());

describe("open_stream", () => {
let mock_eventsource: any;
let app: Client;

beforeEach(async () => {
app = await Client.connect("hmb/hello_world");
app.stream = vi.fn().mockImplementation(() => {
mock_eventsource = IS_NODE
? new NodeEventSource("")
: new EventSource("");
return mock_eventsource;
app.stream_instance = readable_stream(
new URL(`${direct_space_url}/queue/data`)
);
return app.stream_instance;
});
});

Expand All @@ -58,8 +56,12 @@ describe("open_stream", () => {

expect(app.stream).toHaveBeenCalledWith(eventsource_mock_call);

const onMessageCallback = mock_eventsource.onmessage;
const onErrorCallback = mock_eventsource.onerror;
if (!app.stream_instance?.onmessage || !app.stream_instance?.onerror) {
throw new Error("stream instance is not defined");
}

const onMessageCallback = app.stream_instance.onmessage.bind(app);
const onErrorCallback = app.stream_instance.onerror.bind(app);

const message = { msg: "hello jerry" };

Expand Down
2 changes: 1 addition & 1 deletion client/js/src/test/upload_files.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ describe("upload_files", () => {
expect(response.files[0]).toBe("lion.jpg");
});

it("should handle a server error when connected to a running app and uploading files", async () => {
it.skip("should handle a server error when connected to a running app and uploading files", async () => {
const client = await Client.connect("hmb/server_test");

const root_url = "https://hmb-server-test.hf.space";
Expand Down
Loading

0 comments on commit 9d2d605

Please sign in to comment.