From b290953bfeb1d5b44671951c2c3eed560c17d981 Mon Sep 17 00:00:00 2001 From: Ben Burns <803016+benjamincburns@users.noreply.github.com> Date: Wed, 22 Jan 2025 07:24:03 +1300 Subject: [PATCH] patch(langgraph): Refactor pregel loop to use new `PregelRunner` class, ported from python LangGraph. (#791) --- libs/langgraph/src/errors.ts | 10 +- libs/langgraph/src/pregel/index.ts | 196 ++++++++++--------------- libs/langgraph/src/pregel/loop.ts | 53 ++++--- libs/langgraph/src/pregel/retry.ts | 112 +++++---------- libs/langgraph/src/pregel/runner.ts | 216 ++++++++++++++++++++++++++++ 5 files changed, 365 insertions(+), 222 deletions(-) create mode 100644 libs/langgraph/src/pregel/runner.ts diff --git a/libs/langgraph/src/errors.ts b/libs/langgraph/src/errors.ts index ca471bc4..9d4bbf1b 100644 --- a/libs/langgraph/src/errors.ts +++ b/libs/langgraph/src/errors.ts @@ -95,26 +95,24 @@ export class ParentCommand extends GraphBubbleUp { } } -export function isParentCommand(e?: Error): e is ParentCommand { +export function isParentCommand(e?: unknown): e is ParentCommand { return ( e !== undefined && (e as ParentCommand).name === ParentCommand.unminifiable_name ); } -export function isGraphBubbleUp(e?: Error): e is GraphBubbleUp { +export function isGraphBubbleUp(e?: unknown): e is GraphBubbleUp { return e !== undefined && (e as GraphBubbleUp).is_bubble_up === true; } -export function isGraphInterrupt( - e?: GraphInterrupt | Error -): e is GraphInterrupt { +export function isGraphInterrupt(e?: unknown): e is GraphInterrupt { return ( e !== undefined && [ GraphInterrupt.unminifiable_name, NodeInterrupt.unminifiable_name, - ].includes(e.name) + ].includes((e as Error).name) ); } diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index b349ec7d..2c39ba53 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -11,7 +11,6 @@ import { RunnableLike, } from "@langchain/core/runnables"; import { IterableReadableStream } from "@langchain/core/utils/stream"; -import type { CallbackManagerForChainRun } from "@langchain/core/callbacks/manager"; import { All, BaseCheckpointSaver, @@ -54,7 +53,6 @@ import { Command, NULL_TASK_ID, INPUT, - RESUME, PUSH, } from "../constants.js"; import { @@ -71,8 +69,6 @@ import { GraphRecursionError, GraphValueError, InvalidUpdateError, - isGraphBubbleUp, - isGraphInterrupt, } from "../errors.js"; import { _prepareNextTasks, @@ -89,7 +85,6 @@ import { } from "./utils/index.js"; import { findSubgraphPregel } from "./utils/subgraph.js"; import { PregelLoop, IterableReadableWritableStream } from "./loop.js"; -import { executeTasksWithRetry } from "./retry.js"; import { ChannelKeyPlaceholder, isConfiguredManagedValue, @@ -102,6 +97,7 @@ import { gatherIterator, patchConfigurable } from "../utils.js"; import { ensureLangGraphConfig } from "./utils/config.js"; import { LangGraphRunnableConfig } from "./runnable_types.js"; import { StreamMessagesHandler } from "./messages.js"; +import { PregelRunner } from "./runner.js"; type WriteValue = Runnable | RunnableFunc | unknown; @@ -1156,115 +1152,6 @@ export class Pregel< }; } - async _runLoop(params: { - loop: PregelLoop; - interruptAfter: string[] | "*"; - interruptBefore: string[] | "*"; - runManager?: CallbackManagerForChainRun; - debug: boolean; - config: LangGraphRunnableConfig; - }) { - const { loop, interruptAfter, interruptBefore, runManager, debug, config } = - params; - let tickError; - try { - while ( - await loop.tick({ - inputKeys: this.inputChannels as string | string[], - interruptAfter, - interruptBefore, - manager: runManager, - }) - ) { - if (debug) { - printStepCheckpoint( - loop.checkpointMetadata.step, - loop.channels, - this.streamChannelsList as string[] - ); - } - if (debug) { - printStepTasks(loop.step, Object.values(loop.tasks)); - } - // execute tasks, and wait for one to fail or all to finish. - // each task is independent from all other concurrent tasks - // yield updates/debug output as each task finishes - const taskStream = executeTasksWithRetry( - Object.values(loop.tasks).filter((task) => task.writes.length === 0), - { - stepTimeout: this.stepTimeout, - signal: config.signal, - retryPolicy: this.retryPolicy, - } - ); - let graphInterrupt; - for await (const { task, error } of taskStream) { - if (error !== undefined) { - if (isGraphBubbleUp(error)) { - if (loop.isNested) { - throw error; - } - if (isGraphInterrupt(error)) { - graphInterrupt = error; - if (error.interrupts.length) { - const interrupts: PendingWrite[] = - error.interrupts.map((interrupt) => [INTERRUPT, interrupt]); - const resumes = task.writes.filter((w) => w[0] === RESUME); - if (resumes.length) { - interrupts.push(...resumes); - } - loop.putWrites(task.id, interrupts); - } - } - } else { - loop.putWrites(task.id, [ - [ERROR, { message: error.message, name: error.name }], - ]); - throw error; - } - } else { - loop.putWrites(task.id, task.writes); - } - } - - if (debug) { - printStepWrites( - loop.step, - Object.values(loop.tasks) - .map((task) => task.writes) - .flat(), - this.streamChannelsList as string[] - ); - } - if (graphInterrupt !== undefined) { - throw graphInterrupt; - } - } - if (loop.status === "out_of_steps") { - throw new GraphRecursionError( - [ - `Recursion limit of ${config.recursionLimit} reached`, - "without hitting a stop condition. You can increase the", - `limit by setting the "recursionLimit" config key.`, - ].join(" "), - { - lc_error_code: "GRAPH_RECURSION_LIMIT", - } - ); - } - } catch (e) { - tickError = e as Error; - const suppress = await loop.finishAndHandleError(tickError); - if (!suppress) { - throw e; - } - } finally { - if (tickError === undefined) { - await loop.finishAndHandleError(); - } - } - } - override async *_streamIterator( input: PregelInputType | Command, options?: Partial> @@ -1365,21 +1252,23 @@ export class Pregel< streamKeys: this.streamChannelsAsIs as string | string[], store, stream, + interruptAfter, + interruptBefore, + manager: runManager, + debug: this.debug, + }); + + const runner = new PregelRunner({ + loop, }); + if (options?.subgraphs) { loop.config.configurable = { ...loop.config.configurable, [CONFIG_KEY_STREAM]: loop.stream, }; } - await this._runLoop({ - loop, - interruptAfter, - interruptBefore, - runManager, - debug, - config, - }); + await this._runLoop({ loop, runner, debug, config }); } catch (e) { loopError = e; } finally { @@ -1472,4 +1361,67 @@ export class Pregel< } return chunks as OutputType; } + + private async _runLoop(params: { + loop: PregelLoop; + runner: PregelRunner; + config: RunnableConfig; + debug: boolean; + }) { + const { loop, runner, debug, config } = params; + let tickError; + try { + while ( + await loop.tick({ + inputKeys: this.inputChannels as string | string[], + }) + ) { + if (debug) { + printStepCheckpoint( + loop.checkpointMetadata.step, + loop.channels, + this.streamChannelsList as string[] + ); + } + if (debug) { + printStepTasks(loop.step, Object.values(loop.tasks)); + } + await runner.tick({ + timeout: this.stepTimeout, + retryPolicy: this.retryPolicy, + onStepWrite: (step, writes) => { + if (debug) { + printStepWrites( + step, + writes, + this.streamChannelsList as string[] + ); + } + }, + }); + } + if (loop.status === "out_of_steps") { + throw new GraphRecursionError( + [ + `Recursion limit of ${config.recursionLimit} reached`, + "without hitting a stop condition. You can increase the", + `limit by setting the "recursionLimit" config key.`, + ].join(" "), + { + lc_error_code: "GRAPH_RECURSION_LIMIT", + } + ); + } + } catch (e) { + tickError = e as Error; + const suppress = await loop.finishAndHandleError(tickError); + if (!suppress) { + throw e; + } + } finally { + if (tickError === undefined) { + await loop.finishAndHandleError(); + } + } + } } diff --git a/libs/langgraph/src/pregel/loop.ts b/libs/langgraph/src/pregel/loop.ts index 7dce6215..ce062763 100644 --- a/libs/langgraph/src/pregel/loop.ts +++ b/libs/langgraph/src/pregel/loop.ts @@ -93,6 +93,10 @@ export type PregelLoopInitializeParams = { stream: IterableReadableWritableStream; store?: BaseStore; checkSubgraphs?: boolean; + interruptAfter: string[] | All; + interruptBefore: string[] | All; + manager?: CallbackManagerForChainRun; + debug: boolean; }; type PregelLoopParams = { @@ -115,9 +119,13 @@ type PregelLoopParams = { checkpointNamespace: string[]; skipDoneTasks: boolean; isNested: boolean; + manager?: CallbackManagerForChainRun; stream: IterableReadableWritableStream; store?: AsyncBatchedStore; prevCheckpointConfig: RunnableConfig | undefined; + interruptAfter: string[] | All; + interruptBefore: string[] | All; + debug: boolean; }; export class IterableReadableWritableStream extends IterableReadableStream { @@ -254,6 +262,16 @@ export class PregelLoop { store?: AsyncBatchedStore; + manager?: CallbackManagerForChainRun; + + interruptAfter: string[] | All; + + interruptBefore: string[] | All; + + toInterrupt: PregelExecutableTask[] = []; + + debug: boolean = false; + constructor(params: PregelLoopParams) { this.input = params.input; this.checkpointer = params.checkpointer; @@ -277,6 +295,7 @@ export class PregelLoop { this.config = params.config; this.checkpointConfig = params.checkpointConfig; this.isNested = params.isNested; + this.manager = params.manager; this.outputKeys = params.outputKeys; this.streamKeys = params.streamKeys; this.nodes = params.nodes; @@ -285,6 +304,9 @@ export class PregelLoop { this.stream = params.stream; this.checkpointNamespace = params.checkpointNamespace; this.prevCheckpointConfig = params.prevCheckpointConfig; + this.interruptAfter = params.interruptAfter; + this.interruptBefore = params.interruptBefore; + this.debug = params.debug; } static async initialize(params: PregelLoopInitializeParams) { @@ -401,6 +423,7 @@ export class PregelLoop { channels, managed: params.managed, isNested, + manager: params.manager, skipDoneTasks, step, stop, @@ -411,6 +434,9 @@ export class PregelLoop { nodes: params.nodes, stream, store, + interruptAfter: params.interruptAfter, + interruptBefore: params.interruptBefore, + debug: params.debug, }); } @@ -532,21 +558,11 @@ export class PregelLoop { * Returns true if more iterations are needed. * @param params */ - async tick(params: { - inputKeys?: string | string[]; - interruptAfter: string[] | All; - interruptBefore: string[] | All; - manager?: CallbackManagerForChainRun; - }): Promise { + async tick(params: { inputKeys?: string | string[] }): Promise { if (this.store && !this.store.isRunning) { this.store?.start(); } - const { - inputKeys = [], - interruptAfter = [], - interruptBefore = [], - manager, - } = params; + const { inputKeys = [] } = params; if (this.status !== "pending") { throw new Error( `Cannot tick when status is no longer "pending". Current status: "${this.status}"` @@ -592,7 +608,7 @@ export class PregelLoop { if ( shouldInterrupt( this.checkpoint, - interruptAfter, + this.interruptAfter, Object.values(this.tasks) ) ) { @@ -619,7 +635,7 @@ export class PregelLoop { step: this.step, checkpointer: this.checkpointer, isResuming: this.input === INPUT_RESUMING, - manager, + manager: this.manager, store: this.store, } ); @@ -669,19 +685,14 @@ export class PregelLoop { } // if all tasks have finished, re-tick if (Object.values(this.tasks).every((task) => task.writes.length > 0)) { - return this.tick({ - inputKeys, - interruptAfter, - interruptBefore, - manager, - }); + return this.tick({ inputKeys }); } // Before execution, check if we should interrupt if ( shouldInterrupt( this.checkpoint, - interruptBefore, + this.interruptBefore, Object.values(this.tasks) ) ) { diff --git a/libs/langgraph/src/pregel/retry.ts b/libs/langgraph/src/pregel/retry.ts index 12fbc08f..1d16bba4 100644 --- a/libs/langgraph/src/pregel/retry.ts +++ b/libs/langgraph/src/pregel/retry.ts @@ -1,11 +1,15 @@ -import { CHECKPOINT_NAMESPACE_SEPARATOR, Command } from "../constants.js"; +import { + CHECKPOINT_NAMESPACE_SEPARATOR, + Command, + CONFIG_KEY_RESUMING, +} from "../constants.js"; import { getSubgraphsSeenSet, isGraphBubbleUp, isParentCommand, } from "../errors.js"; import { PregelExecutableTask } from "./types.js"; -import type { RetryPolicy } from "./utils/index.js"; +import { patchConfigurable, type RetryPolicy } from "./utils/index.js"; export const DEFAULT_INITIAL_INTERVAL = 500; export const DEFAULT_BACKOFF_FACTOR = 2; @@ -56,59 +60,19 @@ export type SettledPregelTask = { error: Error; }; -export async function* executeTasksWithRetry( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - tasks: PregelExecutableTask[], - options?: { - stepTimeout?: number; - signal?: AbortSignal; - retryPolicy?: RetryPolicy; - } -): AsyncGenerator { - const { stepTimeout, retryPolicy } = options ?? {}; - let signal = options?.signal; - // Start tasks - const executingTasksMap = Object.fromEntries( - tasks.map((pregelTask) => { - return [pregelTask.id, _runWithRetry(pregelTask, retryPolicy)]; - }) - ); - if (stepTimeout && signal) { - if ("any" in AbortSignal) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - signal = (AbortSignal as any).any([ - signal, - AbortSignal.timeout(stepTimeout), - ]); - } - } else if (stepTimeout) { - signal = AbortSignal.timeout(stepTimeout); - } - - // Abort if signal is aborted - signal?.throwIfAborted(); - - let listener: () => void; - const signalPromise = new Promise((_resolve, reject) => { - listener = () => reject(new Error("Abort")); - signal?.addEventListener("abort", listener); - }).finally(() => signal?.removeEventListener("abort", listener)); - - while (Object.keys(executingTasksMap).length > 0) { - const settledTask = await Promise.race([ - ...Object.values(executingTasksMap), - signalPromise, - ]); - yield settledTask; - delete executingTasksMap[settledTask.task.id]; - } -} - -async function _runWithRetry( +export async function _runWithRetry< + N extends PropertyKey, + C extends PropertyKey +>( // eslint-disable-next-line @typescript-eslint/no-explicit-any - pregelTask: PregelExecutableTask, - retryPolicy?: RetryPolicy -) { + pregelTask: PregelExecutableTask, + retryPolicy?: RetryPolicy, + configurable?: Record +): Promise<{ + task: PregelExecutableTask; + result: unknown; + error: Error | undefined; +}> { const resolvedRetryPolicy = pregelTask.retry_policy ?? retryPolicy; let interval = resolvedRetryPolicy !== undefined @@ -117,30 +81,29 @@ async function _runWithRetry( let attempts = 0; let error; let result; + + let { config } = pregelTask; + if (configurable) { + config = patchConfigurable(config, configurable); + } // eslint-disable-next-line no-constant-condition while (true) { - // Modify writes in place to clear any previous retries - while (pregelTask.writes.length > 0) { - pregelTask.writes.pop(); - } + // Clear any writes from previous attempts + pregelTask.writes.splice(0, pregelTask.writes.length); error = undefined; try { - result = await pregelTask.proc.invoke( - pregelTask.input, - pregelTask.config - ); + result = await pregelTask.proc.invoke(pregelTask.input, config); break; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (e: any) { + } catch (e: unknown) { error = e; - error.pregelTaskId = pregelTask.id; + (error as { pregelTaskId: string }).pregelTaskId = pregelTask.id; if (isParentCommand(error)) { - const ns: string = pregelTask.config?.configurable?.checkpoint_ns; + const ns: string = config?.configurable?.checkpoint_ns; const cmd = error.command; if (cmd.graph === ns) { // this command is for the current graph, handle it for (const writer of pregelTask.writers) { - await writer.invoke(cmd, pregelTask.config); + await writer.invoke(cmd, config); } break; } else if (cmd.graph === Command.PARENT) { @@ -184,18 +147,21 @@ async function _runWithRetry( await new Promise((resolve) => setTimeout(resolve, intervalWithJitter)); // log the retry const errorName = - error.name ?? + (error as Error).name ?? // eslint-disable-next-line @typescript-eslint/no-explicit-any - (error.constructor as any).unminifiable_name ?? - error.constructor.name; + ((error as Error).constructor as any).unminifiable_name ?? + (error as Error).constructor.name; console.log( - `Retrying task "${pregelTask.name}" after ${interval.toFixed( + `Retrying task "${String(pregelTask.name)}" after ${interval.toFixed( 2 )}ms (attempt ${attempts}) after ${errorName}: ${error}` ); + + // signal subgraphs to resume (if available) + config = patchConfigurable(config, { [CONFIG_KEY_RESUMING]: true }); } finally { // Clear checkpoint_ns seen (for subgraph detection) - const checkpointNs = pregelTask.config?.configurable?.checkpoint_ns; + const checkpointNs = config?.configurable?.checkpoint_ns; if (checkpointNs) { getSubgraphsSeenSet().delete(checkpointNs); } @@ -204,6 +170,6 @@ async function _runWithRetry( return { task: pregelTask, result, - error, + error: error as Error | undefined, }; } diff --git a/libs/langgraph/src/pregel/runner.ts b/libs/langgraph/src/pregel/runner.ts new file mode 100644 index 00000000..a4c256c3 --- /dev/null +++ b/libs/langgraph/src/pregel/runner.ts @@ -0,0 +1,216 @@ +import { PendingWrite } from "@langchain/langgraph-checkpoint"; +import { PregelExecutableTask } from "./types.js"; +import { RetryPolicy } from "./utils/index.js"; +import { CONFIG_KEY_SEND, ERROR, INTERRUPT, RESUME } from "../constants.js"; +import { + GraphInterrupt, + isGraphBubbleUp, + isGraphInterrupt, +} from "../errors.js"; +import { _runWithRetry, SettledPregelTask } from "./retry.js"; +import { PregelLoop } from "./loop.js"; + +/** + * Options for the @see PregelRunner#tick method. + */ +export type TickOptions = { + /** + * The deadline before which all tasks must be completed. + */ + timeout?: number; + + /** + * An optional @see AbortSignal to cancel processing of tasks. + */ + signal?: AbortSignal; + + /** + * The @see RetryPolicy to use for the tick. + */ + retryPolicy?: RetryPolicy; + + /** + * An optional callback to be called after all task writes are completed. + */ + onStepWrite?: (step: number, writes: PendingWrite[]) => void; +}; + +/** + * Responsible for handling task execution on each tick of the @see PregelLoop. + */ +export class PregelRunner { + loop: PregelLoop; + + /** + * Construct a new PregelRunner, which executes tasks from the provided PregelLoop. + * @param loop - The PregelLoop that produces tasks for this runner to execute. + */ + constructor({ loop }: { loop: PregelLoop }) { + this.loop = loop; + } + + /** + * Execute tasks from the current step of the PregelLoop. + * + * Note: this method does NOT call @see PregelLoop#tick. That must be handled externally. + * @param options - Options for the execution. + */ + async tick(options: TickOptions = {}): Promise { + const tasks = Object.values(this.loop.tasks); + + const { timeout, signal, retryPolicy, onStepWrite } = options; + + // Start task execution + const pendingTasks = tasks.filter((t) => t.writes.length === 0); + const taskStream = this._executeTasksWithRetry(pendingTasks, { + stepTimeout: timeout, + signal, + retryPolicy, + }); + + let graphInterrupt: GraphInterrupt | undefined; + + for await (const { task, error } of taskStream) { + graphInterrupt = this._commit(task, error) ?? graphInterrupt; + } + + onStepWrite?.( + this.loop.step, + Object.values(this.loop.tasks) + .map((task) => task.writes) + .flat() + ); + + if (graphInterrupt) { + throw graphInterrupt; + } + } + + /** + * Concurrently executes tasks with the requested retry policy, yielding a @see SettledPregelTask for each task as it completes. + * @param tasks - The tasks to execute. + * @param options - Options for the execution. + */ + private async *_executeTasksWithRetry( + tasks: PregelExecutableTask[], + options?: { + stepTimeout?: number; + signal?: AbortSignal; + retryPolicy?: RetryPolicy; + } + ): AsyncGenerator { + const { stepTimeout, retryPolicy } = options ?? {}; + let signal = options?.signal; + + const executingTasksMap: Record< + string, + Promise<{ + task: PregelExecutableTask; + result?: unknown; + error?: Error; + }> + > = {}; + + const writer = ( + task: PregelExecutableTask, + writes: Array<[string, unknown]> + ): Array | undefined> => { + // placeholder function - will have logic added when functional API is implemented + return task.config?.configurable?.[CONFIG_KEY_SEND]?.(writes) ?? []; + }; + + // Start tasks + Object.assign( + executingTasksMap, + Object.fromEntries( + tasks.map((pregelTask) => { + return [ + pregelTask.id, + _runWithRetry(pregelTask, retryPolicy, { + [CONFIG_KEY_SEND]: writer?.bind(this, pregelTask), + }).catch((error) => { + return { task: pregelTask, error }; + }), + ]; + }) + ) + ); + + if (stepTimeout && signal) { + if ("any" in AbortSignal) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + signal = (AbortSignal as any).any([ + signal, + AbortSignal.timeout(stepTimeout), + ]); + } + } else if (stepTimeout) { + signal = AbortSignal.timeout(stepTimeout); + } + + // Abort if signal is aborted + signal?.throwIfAborted(); + + let listener: () => void; + const signalPromise = new Promise((_resolve, reject) => { + listener = () => reject(new Error("Abort")); + signal?.addEventListener("abort", listener); + }).finally(() => signal?.removeEventListener("abort", listener)); + + while (Object.keys(executingTasksMap).length > 0) { + const settledTask = await Promise.race([ + ...Object.values(executingTasksMap), + signalPromise, + ]); + yield settledTask as SettledPregelTask; + delete executingTasksMap[settledTask.task.id]; + } + } + + /** + * Determines what writes to apply based on whether the task completed successfully, and what type of error occurred. + * + * Throws an error if the error is a @see GraphBubbleUp error and @see PregelLoop#isNested is true. + * + * Note that in the case of a @see GraphBubbleUp error that is not a @see GraphInterrupt, like a @see Command, this method does not apply any writes. + * + * @param task - The task to commit. + * @param error - The error that occurred, if any. + * @returns The @see GraphInterrupt that occurred, if the user's code threw one. + */ + private _commit( + task: PregelExecutableTask, + error?: Error + ): GraphInterrupt | undefined { + let graphInterrupt; + if (error !== undefined) { + if (isGraphBubbleUp(error)) { + if (this.loop.isNested) { + throw error; + } + if (isGraphInterrupt(error)) { + graphInterrupt = error; + if (error.interrupts.length) { + const interrupts: PendingWrite[] = error.interrupts.map( + (interrupt) => [INTERRUPT, interrupt] + ); + const resumes = task.writes.filter((w) => w[0] === RESUME); + if (resumes.length) { + interrupts.push(...resumes); + } + this.loop.putWrites(task.id, interrupts); + } + } + } else { + this.loop.putWrites(task.id, [ + [ERROR, { message: error.message, name: error.name }], + ]); + throw error; + } + } else { + // Save task writes to checkpointer + this.loop.putWrites(task.id, task.writes); + } + return graphInterrupt; + } +}