Skip to content

Commit

Permalink
fix(json-schema): correct handling of nested recursive schemas (#992)
Browse files Browse the repository at this point in the history
* Fix zod to json schema with nested and recursive objects

* minor style updates

* add an iteration limit

---------

Co-authored-by: Zijia Zhang <[email protected]>
  • Loading branch information
RobertCraigie and ZijiaZhang authored Aug 13, 2024
1 parent d486d27 commit ac309ab
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 27 deletions.
31 changes: 18 additions & 13 deletions src/_vendor/zod-to-json-schema/Options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ export type Options<Target extends Targets = 'jsonSchema7'> = {
openaiStrictMode?: boolean;
};

export const defaultOptions: Options = {
const defaultOptions: Omit<Options, 'definitions' | 'basePath'> = {
name: undefined,
$refStrategy: 'root',
basePath: ['#'],
effectStrategy: 'input',
pipeStrategy: 'all',
dateStrategy: 'format:date-time',
Expand All @@ -51,7 +50,6 @@ export const defaultOptions: Options = {
definitionPath: 'definitions',
target: 'jsonSchema7',
strictUnions: false,
definitions: {},
errorMessages: false,
markdownDescription: false,
patternStrategy: 'escape',
Expand All @@ -63,13 +61,20 @@ export const defaultOptions: Options = {

export const getDefaultOptions = <Target extends Targets>(
options: Partial<Options<Target>> | string | undefined,
) =>
(typeof options === 'string' ?
{
...defaultOptions,
name: options,
}
: {
...defaultOptions,
...options,
}) as Options<Target>;
) => {
// We need to add `definitions` here as we may mutate it
return (
typeof options === 'string' ?
{
...defaultOptions,
basePath: ['#'],
definitions: {},
name: options,
}
: {
...defaultOptions,
basePath: ['#'],
definitions: {},
...options,
}) as Options<Target>;
};
28 changes: 21 additions & 7 deletions src/_vendor/zod-to-json-schema/zodToJsonSchema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,28 @@ const zodToJsonSchema = <Target extends Targets = 'jsonSchema7'>(
}

const definitions: Record<string, any> = {};
const processedDefinitions = new Set();

for (const [name, zodSchema] of Object.entries(refs.definitions)) {
definitions[name] =
parseDef(
zodDef(zodSchema),
{ ...refs, currentPath: [...refs.basePath, refs.definitionPath, name] },
true,
) ?? {};
// the call to `parseDef()` here might itself add more entries to `.definitions`
// so we need to continually evaluate definitions until we've resolved all of them
//
// we have a generous iteration limit here to avoid blowing up the stack if there
// are any bugs that would otherwise result in us iterating indefinitely
for (let i = 0; i < 500; i++) {
const newDefinitions = Object.entries(refs.definitions).filter(
([key]) => !processedDefinitions.has(key),
);
if (newDefinitions.length === 0) break;

for (const [key, schema] of newDefinitions) {
definitions[key] =
parseDef(
zodDef(schema),
{ ...refs, currentPath: [...refs.basePath, refs.definitionPath, key] },
true,
) ?? {};
processedDefinitions.add(key);
}
}

return definitions;
Expand Down
28 changes: 28 additions & 0 deletions tests/lib/__snapshots__/parser.test.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,34 @@ exports[`.parse() zod nested schema extraction 2`] = `
"
`;

exports[`.parse() zod recursive schema extraction 2`] = `
"{
"id": "chatcmpl-9vdbw9dekyUSEsSKVQDhTxA2RCxcK",
"object": "chat.completion",
"created": 1723523988,
"model": "gpt-4o-2024-08-06",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "{\\"linked_list\\":{\\"value\\":1,\\"next\\":{\\"value\\":2,\\"next\\":{\\"value\\":3,\\"next\\":{\\"value\\":4,\\"next\\":{\\"value\\":5,\\"next\\":null}}}}}}",
"refusal": null
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 40,
"completion_tokens": 38,
"total_tokens": 78
},
"system_fingerprint": "fp_2a322c9ffc"
}
"
`;

exports[`.parse() zod top-level recursive schemas 1`] = `
"{
"id": "chatcmpl-9uLhw79ArBF4KsQQOlsoE68m6vh6v",
Expand Down
182 changes: 175 additions & 7 deletions tests/lib/parser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -525,13 +525,6 @@ describe('.parse()', () => {
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": false,
"definitions": {
"contactPerson_properties_person1_properties_name": {
"type": "string",
},
"contactPerson_properties_person1_properties_phone_number": {
"nullable": true,
"type": "string",
},
"query": {
"additionalProperties": false,
"properties": {
Expand Down Expand Up @@ -616,6 +609,21 @@ describe('.parse()', () => {
},
],
},
"query_properties_fields_items_anyOf_0_properties_metadata_anyOf_0": {
"additionalProperties": false,
"properties": {
"foo": {
"$ref": "#/definitions/query_properties_fields_items_anyOf_0_properties_metadata_anyOf_0_properties_foo",
},
},
"required": [
"foo",
],
"type": "object",
},
"query_properties_fields_items_anyOf_0_properties_metadata_anyOf_0_properties_foo": {
"type": "string",
},
},
"properties": {
"fields": {
Expand Down Expand Up @@ -783,5 +791,165 @@ describe('.parse()', () => {
}
`);
});

test('recursive schema extraction', async () => {
const baseLinkedListNodeSchema = z.object({
value: z.number(),
});
type LinkedListNode = z.infer<typeof baseLinkedListNodeSchema> & {
next: LinkedListNode | null;
};
const linkedListNodeSchema: z.ZodType<LinkedListNode> = baseLinkedListNodeSchema.extend({
next: z.lazy(() => z.union([linkedListNodeSchema, z.null()])),
});

// Define the main schema
const mainSchema = z.object({
linked_list: linkedListNodeSchema,
});

expect(zodResponseFormat(mainSchema, 'query').json_schema.schema).toMatchInlineSnapshot(`
{
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": false,
"definitions": {
"query": {
"additionalProperties": false,
"properties": {
"linked_list": {
"additionalProperties": false,
"properties": {
"next": {
"anyOf": [
{
"$ref": "#/definitions/query_properties_linked_list",
},
{
"type": "null",
},
],
},
"value": {
"type": "number",
},
},
"required": [
"value",
"next",
],
"type": "object",
},
},
"required": [
"linked_list",
],
"type": "object",
},
"query_properties_linked_list": {
"additionalProperties": false,
"properties": {
"next": {
"$ref": "#/definitions/query_properties_linked_list_properties_next",
},
"value": {
"$ref": "#/definitions/query_properties_linked_list_properties_value",
},
},
"required": [
"value",
"next",
],
"type": "object",
},
"query_properties_linked_list_properties_next": {
"anyOf": [
{
"$ref": "#/definitions/query_properties_linked_list",
},
{
"type": "null",
},
],
},
"query_properties_linked_list_properties_value": {
"type": "number",
},
},
"properties": {
"linked_list": {
"additionalProperties": false,
"properties": {
"next": {
"anyOf": [
{
"$ref": "#/definitions/query_properties_linked_list",
},
{
"type": "null",
},
],
},
"value": {
"type": "number",
},
},
"required": [
"value",
"next",
],
"type": "object",
},
},
"required": [
"linked_list",
],
"type": "object",
}
`);

const completion = await makeSnapshotRequest(
(openai) =>
openai.beta.chat.completions.parse({
model: 'gpt-4o-2024-08-06',
messages: [
{
role: 'system',
content:
"You are a helpful assistant. Generate a data model according to the user's instructions.",
},
{ role: 'user', content: 'create a linklist from 1 to 5' },
],
response_format: zodResponseFormat(mainSchema, 'query'),
}),
2,
);

expect(completion.choices[0]?.message).toMatchInlineSnapshot(`
{
"content": "{"linked_list":{"value":1,"next":{"value":2,"next":{"value":3,"next":{"value":4,"next":{"value":5,"next":null}}}}}}",
"parsed": {
"linked_list": {
"next": {
"next": {
"next": {
"next": {
"next": null,
"value": 5,
},
"value": 4,
},
"value": 3,
},
"value": 2,
},
"value": 1,
},
},
"refusal": null,
"role": "assistant",
"tool_calls": [],
}
`);
});
});
});

0 comments on commit ac309ab

Please sign in to comment.