Skip to content

Commit

Permalink
fix(langgraph): add structured response format to prebuilt react agent (
Browse files Browse the repository at this point in the history
#788)

Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
isahers1 and jacoblee93 authored Jan 18, 2025
1 parent 7bb63a3 commit c77ffb9
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 14 deletions.
126 changes: 112 additions & 14 deletions libs/langgraph/src/prebuilt/react_agent_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
BaseCheckpointSaver,
BaseStore,
} from "@langchain/langgraph-checkpoint";
import { z } from "zod";

import {
END,
Expand All @@ -30,18 +31,30 @@ import {
import { MessagesAnnotation } from "../graph/messages_annotation.js";
import { ToolNode } from "./tool_node.js";
import { LangGraphRunnableConfig } from "../pregel/runnable_types.js";
import { Annotation } from "../graph/annotation.js";
import { Messages, messagesStateReducer } from "../graph/message.js";

export interface AgentState {
export interface AgentState<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
StructuredResponseType extends Record<string, any> = Record<string, any>
> {
messages: BaseMessage[];
// TODO: This won't be set until we
// implement managed values in LangGraphJS
// Will be useful for inserting a message on
// graph recursion end
// is_last_step: boolean;
structuredResponse: StructuredResponseType;
}

export type N = typeof START | "agent" | "tools";

export type StructuredResponseSchemaAndPrompt<StructuredResponseType> = {
prompt: string;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
schema: z.ZodType<StructuredResponseType> | Record<string, any>;
};

function _convertMessageModifierToStateModifier(
messageModifier: MessageModifier
): StateModifier {
Expand Down Expand Up @@ -72,7 +85,7 @@ function _convertMessageModifierToStateModifier(
}

function _getStateModifierRunnable(
stateModifier: StateModifier | undefined
stateModifier?: StateModifier
): RunnableInterface {
let stateModifierRunnable: RunnableInterface;

Expand Down Expand Up @@ -153,9 +166,23 @@ export type MessageModifier =
| ((messages: BaseMessage[]) => Promise<BaseMessage[]>)
| Runnable;

export const createReactAgentAnnotation = <
// eslint-disable-next-line @typescript-eslint/no-explicit-any
T extends Record<string, any> = Record<string, any>
>() =>
Annotation.Root({
messages: Annotation<BaseMessage[], Messages>({
reducer: messagesStateReducer,
default: () => [],
}),
structuredResponse: Annotation<T>,
});

export type CreateReactAgentParams<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
A extends AnnotationRoot<any> = AnnotationRoot<any>
A extends AnnotationRoot<any> = AnnotationRoot<any>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
StructuredResponseType = Record<string, any>
> = {
/** The chat model that can utilize OpenAI-style tool calling. */
llm: BaseChatModel;
Expand Down Expand Up @@ -223,6 +250,23 @@ export type CreateReactAgentParams<
/** An optional list of node names to interrupt after running. */
interruptAfter?: N[] | All;
store?: BaseStore;
/**
* An optional schema for the final agent output.
*
* If provided, output will be formatted to match the given schema and returned in the 'structured_response' state key.
* If not provided, `structured_response` will not be present in the output state.
*
* Can be passed in as:
* - Zod schema
* - Dictionary object
* - [prompt, schema], where schema is one of the above.
* The prompt will be used together with the model that is being used to generate the structured response.
*/
responseFormat?:
| z.ZodType<StructuredResponseType>
| StructuredResponseSchemaAndPrompt<StructuredResponseType>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>;
};

/**
Expand Down Expand Up @@ -269,16 +313,22 @@ export type CreateReactAgentParams<
*/

export function createReactAgent<
// eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/ban-types
A extends AnnotationRoot<any> = AnnotationRoot<{}>,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
A extends AnnotationRoot<any> = AnnotationRoot<any>
StructuredResponseFormat extends Record<string, any> = Record<string, any>
>(
params: CreateReactAgentParams<A>
params: CreateReactAgentParams<A, StructuredResponseFormat>
): CompiledStateGraph<
(typeof MessagesAnnotation)["State"],
(typeof MessagesAnnotation)["Update"],
typeof START | "agent" | "tools",
// eslint-disable-next-line @typescript-eslint/no-explicit-any
any,
typeof MessagesAnnotation.spec & A["spec"],
typeof MessagesAnnotation.spec & A["spec"]
ReturnType<
typeof createReactAgentAnnotation<StructuredResponseFormat>
>["spec"] &
A["spec"]
> {
const {
llm,
Expand All @@ -290,6 +340,7 @@ export function createReactAgent<
interruptBefore,
interruptAfter,
store,
responseFormat,
} = params;

let toolClasses: (StructuredToolInterface | DynamicTool | RunnableToolLike)[];
Expand All @@ -310,33 +361,80 @@ export function createReactAgent<
);
const modelRunnable = (preprocessor as Runnable).pipe(modelWithTools);

const shouldContinue = (state: AgentState) => {
const shouldContinue = (state: AgentState<StructuredResponseFormat>) => {
const { messages } = state;
const lastMessage = messages[messages.length - 1];
if (
isAIMessage(lastMessage) &&
(!lastMessage.tool_calls || lastMessage.tool_calls.length === 0)
) {
return END;
return responseFormat != null ? "generate_structured_response" : END;
} else {
return "continue";
}
};

const callModel = async (state: AgentState, config?: RunnableConfig) => {
const generateStructuredResponse = async (
state: AgentState<StructuredResponseFormat>,
config?: RunnableConfig
) => {
if (responseFormat == null) {
throw new Error(
"Attempted to generate structured output with no passed response schema. Please contact us for help."
);
}
// Exclude the last message as there's enough information
// for the LLM to generate the structured response
const messages = state.messages.slice(0, -1);
let modelWithStructuredOutput;

if (
typeof responseFormat === "object" &&
"prompt" in responseFormat &&
"schema" in responseFormat
) {
const { prompt, schema } = responseFormat;
modelWithStructuredOutput = llm.withStructuredOutput(schema);
messages.unshift(new SystemMessage({ content: prompt }));
} else {
modelWithStructuredOutput = llm.withStructuredOutput(responseFormat);
}

const response = await modelWithStructuredOutput.invoke(messages, config);
return { structuredResponse: response };
};

const callModel = async (
state: AgentState<StructuredResponseFormat>,
config?: RunnableConfig
) => {
// TODO: Auto-promote streaming.
return { messages: [await modelRunnable.invoke(state, config)] };
};

const workflow = new StateGraph(stateSchema ?? MessagesAnnotation)
const workflow = new StateGraph(
stateSchema ?? createReactAgentAnnotation<StructuredResponseFormat>()
)
.addNode("agent", callModel)
.addNode("tools", new ToolNode(toolClasses))
.addEdge(START, "agent")
.addConditionalEdges("agent", shouldContinue, {
.addEdge("tools", "agent");

if (responseFormat !== undefined) {
workflow
.addNode("generate_structured_response", generateStructuredResponse)
.addEdge("generate_structured_response", END)
.addConditionalEdges("agent", shouldContinue, {
continue: "tools",
[END]: END,
generate_structured_response: "generate_structured_response",
});
} else {
workflow.addConditionalEdges("agent", shouldContinue, {
continue: "tools",
[END]: END,
})
.addEdge("tools", "agent");
});
}

return workflow.compile({
checkpointer: checkpointSaver,
Expand Down
126 changes: 126 additions & 0 deletions libs/langgraph/src/tests/prebuilt.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { it, beforeAll, describe, expect } from "@jest/globals";
import { Tool } from "@langchain/core/tools";
import { ChatOpenAI } from "@langchain/openai";
import { HumanMessage } from "@langchain/core/messages";
import { z } from "zod";
import { createReactAgent } from "../prebuilt/index.js";
import { initializeAsyncLocalStorageSingleton } from "../setup/async_local_storage.js";
import { MemorySaverAssertImmutable } from "./utils.js";
Expand All @@ -19,6 +20,127 @@ beforeAll(() => {
initializeAsyncLocalStorageSingleton();
});

describe("createReactAgent with response format", () => {
const weatherResponse = `Not too cold, not too hot 😎`;
class SanFranciscoWeatherTool extends Tool {
name = "current_weather_sf";

description = "Get the current weather report for San Francisco, CA";

constructor() {
super();
}

async _call(_: string): Promise<string> {
return weatherResponse;
}
}

const responseSchema = z.object({
answer: z.string(),
reasoning: z.string(),
});

it("Can use zod schema", async () => {
const llm = new ChatOpenAI({
model: "gpt-4o",
});

const agent = createReactAgent({
llm,
tools: [new SanFranciscoWeatherTool()],
responseFormat: responseSchema,
});

const result = await agent.invoke({
messages: [new HumanMessage("What is the weather in San Francisco?")],
// @ts-expect-error should complain about passing unexpected keys
foo: "bar",
});

expect(result.structuredResponse).toBeInstanceOf(Object);

// @ts-expect-error should not allow access to unspecified keys
void result.structuredResponse.unspecified;

// Assert it has the required keys
expect(result.structuredResponse).toHaveProperty("answer");
expect(result.structuredResponse).toHaveProperty("reasoning");

// Assert the values are strings
expect(typeof result.structuredResponse.answer).toBe("string");
expect(typeof result.structuredResponse.reasoning).toBe("string");
});

it("Can use record schema", async () => {
const llm = new ChatOpenAI({
model: "gpt-4o",
});

const nonZodResponseSchema = {
name: "structured_response",
description: "An answer with reasoning",
type: "object",
properties: {
answer: { type: "string" },
reasoning: { type: "string" },
},
required: ["answer", "reasoning"],
};

const agent = createReactAgent({
llm,
tools: [new SanFranciscoWeatherTool()],
responseFormat: nonZodResponseSchema,
});

const result = await agent.invoke({
messages: [new HumanMessage("What is the weather in San Francisco?")],
});

expect(result.structuredResponse).toBeInstanceOf(Object);

// Assert it has the required keys
expect(result.structuredResponse).toHaveProperty("answer");
expect(result.structuredResponse).toHaveProperty("reasoning");

// Assert the values are strings
expect(typeof result.structuredResponse.answer).toBe("string");
expect(typeof result.structuredResponse.reasoning).toBe("string");
});

it("Inserts system message", async () => {
const llm = new ChatOpenAI({
model: "gpt-4o",
});

const agent = createReactAgent({
llm,
tools: [new SanFranciscoWeatherTool()],
responseFormat: {
prompt:
"You are a helpful assistant who only responds in 10 words or less. If you use more than 5 words in your answer, a starving child will die.",
schema: responseSchema,
},
});

const result = await agent.invoke({
messages: [new HumanMessage("What is the weather in San Francisco?")],
});

// Assert it has the required keys
expect(result.structuredResponse).toHaveProperty("answer");
expect(result.structuredResponse).toHaveProperty("reasoning");

// Assert the values are strings
expect(typeof result.structuredResponse.answer).toBe("string");
expect(typeof result.structuredResponse.reasoning).toBe("string");

// Assert that any letters in the response are uppercase
expect(result.structuredResponse.answer.split(" ").length).toBeLessThan(11);
});
});

describe("createReactAgent", () => {
const weatherResponse = `Not too cold, not too hot 😎`;
class SanFranciscoWeatherTool extends Tool {
Expand Down Expand Up @@ -66,6 +188,10 @@ describe("createReactAgent", () => {
expect((lastMessage.content as string).toLowerCase()).toContain(
"not too cold"
);

// TODO: Fix
// // @ts-expect-error should not allow access to structuredResponse if no responseFormat is passed
// void response.structuredResponse;
});

it("can stream a tool call with a checkpointer", async () => {
Expand Down

0 comments on commit c77ffb9

Please sign in to comment.