diff --git a/.changeset/young-poets-change.md b/.changeset/young-poets-change.md new file mode 100644 index 0000000000000..41ff2c264b836 --- /dev/null +++ b/.changeset/young-poets-change.md @@ -0,0 +1,8 @@ +--- +"@gradio/app": patch +"@gradio/client": patch +"@gradio/preview": patch +"gradio": patch +--- + +fix:Handle gradio apps using `state` in the JS Client diff --git a/client/js/src/client.ts b/client/js/src/client.ts index 3ff5e7f6a736f..752e80d03d5a3 100644 --- a/client/js/src/client.ts +++ b/client/js/src/client.ts @@ -6,6 +6,7 @@ import type { DuplicateOptions, EndpointInfo, JsApiData, + PredictReturn, SpaceStatus, Status, SubmitReturn, @@ -114,7 +115,7 @@ export class Client { endpoint: string | number, data: unknown[] | Record, event_data?: unknown - ) => Promise; + ) => Promise; open_stream: () => Promise; private resolve_config: (endpoint: string) => Promise; private resolve_cookies: () => Promise; diff --git a/client/js/src/helpers/data.ts b/client/js/src/helpers/data.ts index f1840ef766c20..88467e4043154 100644 --- a/client/js/src/helpers/data.ts +++ b/client/js/src/helpers/data.ts @@ -5,7 +5,9 @@ import type { Config, EndpointInfo, JsApiData, - DataType + DataType, + Dependency, + ComponentMeta } from "../types"; export function update_object( @@ -118,3 +120,62 @@ export function post_message( window.parent.postMessage(message, origin, [channel.port2]); }); } + +/** + * Handles the payload by filtering out state inputs and returning an array of resolved payload values. + * We send null values for state inputs to the server, but we don't want to include them in the resolved payload. + * + * @param resolved_payload - The resolved payload values received from the client or the server + * @param dependency - The dependency object. + * @param components - The array of component metadata. + * @param with_null_state - Optional. Specifies whether to include null values for state inputs. Default is false. + * @returns An array of resolved payload values, filtered based on the dependency and component metadata. + */ +export function handle_payload( + resolved_payload: unknown[], + dependency: Dependency, + components: ComponentMeta[], + type: "input" | "output", + with_null_state = false +): unknown[] { + if (type === "input" && !with_null_state) { + throw new Error("Invalid code path. Cannot skip state inputs for input."); + } + // data comes from the server with null state values so we skip + if (type === "output" && with_null_state) { + return resolved_payload; + } + + let updated_payload: unknown[] = []; + let payload_index = 0; + for (let i = 0; i < dependency.inputs.length; i++) { + const input_id = dependency.inputs[i]; + const component = components.find((c) => c.id === input_id); + + if (component?.type === "state") { + // input + with_null_state needs us to fill state with null values + if (with_null_state) { + if (resolved_payload.length === dependency.inputs.length) { + const value = resolved_payload[payload_index]; + updated_payload.push(value); + payload_index++; + } else { + updated_payload.push(null); + } + } else { + // this is output & !with_null_state, we skip state inputs + // the server payload always comes with null state values so we move along the payload index + payload_index++; + continue; + } + // input & !with_null_state isn't a case we care about, server needs null + continue; + } else { + const value = resolved_payload[payload_index]; + updated_payload.push(value); + payload_index++; + } + } + + return updated_payload; +} diff --git a/client/js/src/test/api_info.test.ts b/client/js/src/test/api_info.test.ts index bf6c413c5ef80..ad0a7538eaeef 100644 --- a/client/js/src/test/api_info.test.ts +++ b/client/js/src/test/api_info.test.ts @@ -16,7 +16,6 @@ import { initialise_server } from "./server"; import { transformed_api_info } from "./test_data"; const server = initialise_server(); -const IS_NODE = process.env.TEST_MODE === "node"; beforeAll(() => server.listen()); afterEach(() => server.resetHandlers()); diff --git a/client/js/src/test/data.test.ts b/client/js/src/test/data.test.ts index 82f8fb36e6db2..b2587fb38ecac 100644 --- a/client/js/src/test/data.test.ts +++ b/client/js/src/test/data.test.ts @@ -3,7 +3,8 @@ import { update_object, walk_and_store_blobs, skip_queue, - post_message + post_message, + handle_payload } from "../helpers/data"; import { NodeBlob } from "../client"; import { config_response, endpoint_info } from "./test_data"; @@ -276,3 +277,135 @@ describe("post_message", () => { ]); }); }); + +describe("handle_payload", () => { + it("should return an input payload with null in place of `state` when with_null_state is true", () => { + const resolved_payload = [2]; + const dependency = { + inputs: [1, 2] + }; + const components = [ + { id: 1, type: "number" }, + { id: 2, type: "state" } + ]; + const with_null_state = true; + const result = handle_payload( + resolved_payload, + // @ts-ignore + dependency, + components, + "input", + with_null_state + ); + expect(result).toEqual([2, null]); + }); + it("should return an input payload with null in place of two `state` components when with_null_state is true", () => { + const resolved_payload = ["hello", "goodbye"]; + const dependency = { + inputs: [1, 2, 3, 4] + }; + const components = [ + { id: 1, type: "textbox" }, + { id: 2, type: "state" }, + { id: 3, type: "textbox" }, + { id: 4, type: "state" } + ]; + const with_null_state = true; + const result = handle_payload( + resolved_payload, + // @ts-ignore + dependency, + components, + "input", + with_null_state + ); + expect(result).toEqual(["hello", null, "goodbye", null]); + }); + + it("should return an output payload without the state component value when with_null_state is false", () => { + const resolved_payload = ["hello", null]; + const dependency = { + inputs: [2, 3] + }; + const components = [ + { id: 2, type: "textbox" }, + { id: 3, type: "state" } + ]; + const with_null_state = false; + const result = handle_payload( + resolved_payload, + // @ts-ignore + dependency, + components, + "output", + with_null_state + ); + expect(result).toEqual(["hello"]); + }); + + it("should return an ouput payload without the two state component values when with_null_state is false", () => { + const resolved_payload = ["hello", null, "world", null]; + const dependency = { + inputs: [2, 3, 4, 5] + }; + const components = [ + { id: 2, type: "textbox" }, + { id: 3, type: "state" }, + { id: 4, type: "textbox" }, + { id: 5, type: "state" } + ]; + const with_null_state = false; + const result = handle_payload( + resolved_payload, + // @ts-ignore + dependency, + components, + "output", + with_null_state + ); + expect(result).toEqual(["hello", "world"]); + }); + + it("should return an ouput payload with the two state component values when with_null_state is true", () => { + const resolved_payload = ["hello", null, "world", null]; + const dependency = { + inputs: [2, 3, 4, 5] + }; + const components = [ + { id: 2, type: "textbox" }, + { id: 3, type: "state" }, + { id: 4, type: "textbox" }, + { id: 5, type: "state" } + ]; + const with_null_state = true; + const result = handle_payload( + resolved_payload, + // @ts-ignore + dependency, + components, + "output", + with_null_state + ); + expect(result).toEqual(["hello", null, "world", null]); + }); + + it("should return the same payload where no state components are defined", () => { + const resolved_payload = ["hello", "world"]; + const dependency = { + inputs: [2, 3] + }; + const components = [ + { id: 2, type: "textbox" }, + { id: 3, type: "textbox" } + ]; + const with_null_state = true; + const result = handle_payload( + resolved_payload, + // @ts-ignore + dependency, + components, + with_null_state + ); + expect(result).toEqual(["hello", "world"]); + }); +}); diff --git a/client/js/src/types.ts b/client/js/src/types.ts index 244bad47e1d68..9b8605423afb0 100644 --- a/client/js/src/types.ts +++ b/client/js/src/types.ts @@ -1,6 +1,8 @@ // API Data Types import { hardware_types } from "./helpers/spaces"; +import type { SvelteComponent } from "svelte"; +import type { ComponentType } from "svelte"; export interface ApiData { label: string; @@ -62,7 +64,7 @@ export type PredictFunction = ( endpoint: string | number, data: unknown[] | Record, event_data?: unknown -) => Promise; +) => Promise; // Event and Submission Types @@ -90,6 +92,14 @@ export type SubmitReturn = { destroy: () => void; }; +export type PredictReturn = { + type: EventType; + time: Date; + data: unknown; + endpoint: string; + fn_index: number; +}; + // Space Status Types export type SpaceStatus = SpaceStatusNormal | SpaceStatusError; @@ -128,7 +138,7 @@ export interface Config { analytics_enabled: boolean; connect_heartbeat: boolean; auth_message: string; - components: any[]; + components: ComponentMeta[]; css: string | null; js: string | null; head: string | null; @@ -153,6 +163,45 @@ export interface Config { max_file_size?: number; } +// todo: DRY up types +export interface ComponentMeta { + type: string; + id: number; + has_modes: boolean; + props: SharedProps; + instance: SvelteComponent; + component: ComponentType; + documentation?: Documentation; + children?: ComponentMeta[]; + parent?: ComponentMeta; + value?: any; + component_class_id: string; + key: string | number | null; + rendered_in?: number; +} + +interface SharedProps { + elem_id?: string; + elem_classes?: string[]; + components?: string[]; + server_fns?: string[]; + interactive: boolean; + [key: string]: unknown; + root_url?: string; +} + +export interface Documentation { + type?: TypeDescription; + description?: TypeDescription; + example_data?: string; +} + +interface TypeDescription { + input_payload?: string; + response_object?: string; + payload?: string; +} + export interface Dependency { id: number; targets: [number, string][]; @@ -218,6 +267,7 @@ export interface ClientOptions { hf_token?: `hf_${string}`; status_callback?: SpaceStatusCallback | null; auth?: [string, string] | null; + with_null_state?: boolean; } export interface FileData { diff --git a/client/js/src/utils/predict.ts b/client/js/src/utils/predict.ts index 74d89560759bc..a4bf47aa916d5 100644 --- a/client/js/src/utils/predict.ts +++ b/client/js/src/utils/predict.ts @@ -1,11 +1,11 @@ import { Client } from "../client"; -import type { Dependency, SubmitReturn } from "../types"; +import type { Dependency, PredictReturn } from "../types"; export async function predict( this: Client, endpoint: string | number, data: unknown[] | Record -): Promise { +): Promise { let data_returned = false; let status_complete = false; let dependency: Dependency; @@ -38,7 +38,7 @@ export async function predict( // if complete message comes before data, resolve here if (status_complete) { app.destroy(); - resolve(d as SubmitReturn); + resolve(d as PredictReturn); } data_returned = true; result = d; @@ -50,7 +50,7 @@ export async function predict( // if complete message comes after data, resolve here if (data_returned) { app.destroy(); - resolve(result as SubmitReturn); + resolve(result as PredictReturn); } } }); diff --git a/client/js/src/utils/submit.ts b/client/js/src/utils/submit.ts index 80ca94b7831fa..212ea55f66b18 100644 --- a/client/js/src/utils/submit.ts +++ b/client/js/src/utils/submit.ts @@ -14,7 +14,7 @@ import type { Dependency } from "../types"; -import { skip_queue, post_message } from "../helpers/data"; +import { skip_queue, post_message, handle_payload } from "../helpers/data"; import { resolve_root } from "../helpers/init_helpers"; import { handle_message, @@ -47,7 +47,8 @@ export function submit( pending_diff_streams, event_callbacks, unclosed_events, - post_data + post_data, + options } = this; if (!api_info) throw new Error("No API found"); @@ -193,8 +194,15 @@ export function submit( this.handle_blob(config.root, resolved_data, endpoint_info).then( async (_payload) => { + let input_data = handle_payload( + _payload, + dependency, + config.components, + "input", + true + ); payload = { - data: _payload || [], + data: input_data || [], event_data, fn_index, trigger_id @@ -225,7 +233,13 @@ export function submit( type: "data", endpoint: _endpoint, fn_index, - data: data, + data: handle_payload( + data, + dependency, + config.components, + "output", + options.with_null_state + ), time: new Date(), event_data, trigger_id @@ -359,7 +373,13 @@ export function submit( fire_event({ type: "data", time: new Date(), - data: data.data, + data: handle_payload( + data.data, + dependency, + config.components, + "output", + options.with_null_state + ), endpoint: _endpoint, fn_index, event_data, @@ -482,7 +502,13 @@ export function submit( fire_event({ type: "data", time: new Date(), - data: data.data, + data: handle_payload( + data.data, + dependency, + config.components, + "output", + options.with_null_state + ), endpoint: _endpoint, fn_index, event_data, @@ -633,7 +659,13 @@ export function submit( fire_event({ type: "data", time: new Date(), - data: data.data, + data: handle_payload( + data.data, + dependency, + config.components, + "output", + options.with_null_state + ), endpoint: _endpoint, fn_index }); diff --git a/js/app/src/Index.svelte b/js/app/src/Index.svelte index 87324dc777af8..eb806a18010f4 100644 --- a/js/app/src/Index.svelte +++ b/js/app/src/Index.svelte @@ -274,7 +274,8 @@ : host || space || src || location.origin; app = await Client.connect(api_url, { - status_callback: handle_status + status_callback: handle_status, + with_null_state: true }); if (!app.config) { diff --git a/js/preview/src/dev.ts b/js/preview/src/dev.ts index 543d07a0a48e8..2428fbf69c4bd 100644 --- a/js/preview/src/dev.ts +++ b/js/preview/src/dev.ts @@ -101,7 +101,7 @@ function find_frontend_folders(start_path: string): string[] { function to_posix(_path: string): string { const isExtendedLengthPath = /^\\\\\?\\/.test(_path); - const hasNonAscii = /[^\u0000-\u0080]+/.test(_path); // eslint-disable-line no-control-regex + const hasNonAscii = /[^\u0000-\u0080]+/.test(_path); if (isExtendedLengthPath || hasNonAscii) { return _path;