patch(langgraph): Refactor pregel loop to use new PregelRunner class, ported from Python LangGraph. (#791)
benjamincburns authored Jan 21, 2025
1 parent e13199d commit b290953
Showing 5 changed files with 365 additions and 222 deletions.
10 changes: 4 additions & 6 deletions libs/langgraph/src/errors.ts
@@ -95,26 +95,24 @@ export class ParentCommand extends GraphBubbleUp {
}
}

export function isParentCommand(e?: Error): e is ParentCommand {
export function isParentCommand(e?: unknown): e is ParentCommand {
return (
e !== undefined &&
(e as ParentCommand).name === ParentCommand.unminifiable_name
);
}

export function isGraphBubbleUp(e?: Error): e is GraphBubbleUp {
export function isGraphBubbleUp(e?: unknown): e is GraphBubbleUp {
return e !== undefined && (e as GraphBubbleUp).is_bubble_up === true;
}

export function isGraphInterrupt(
e?: GraphInterrupt | Error
): e is GraphInterrupt {
export function isGraphInterrupt(e?: unknown): e is GraphInterrupt {
return (
e !== undefined &&
[
GraphInterrupt.unminifiable_name,
NodeInterrupt.unminifiable_name,
].includes(e.name)
].includes((e as Error).name)
);
}

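The guards above are widened from "e?: Error" to "e?: unknown" so they can be called directly on a catch binding, which TypeScript types as unknown. The block below is a minimal usage sketch, not code from this commit: only the two guard imports come from the diff above, and runNode is a hypothetical stand-in for any task body.

import { isGraphBubbleUp, isGraphInterrupt } from "../errors.js";

// Hypothetical helper showing why the guards now accept `unknown`.
async function runTaskSafely(runNode: () => Promise<void>): Promise<void> {
  try {
    await runNode();
  } catch (e) {
    // `e` is `unknown` here; the old `e?: Error` signatures forced a cast
    // at every catch site before a guard could even be called.
    if (isGraphInterrupt(e)) {
      // Narrowed to GraphInterrupt: treat the pause as expected control flow.
      console.log(`paused with ${e.interrupts.length} interrupt(s)`);
      return;
    }
    if (isGraphBubbleUp(e)) {
      // Other bubble-up signals (e.g. ParentCommand) propagate to the caller.
      throw e;
    }
    // Ordinary errors are logged and rethrown unchanged.
    console.error("task failed:", e);
    throw e;
  }
}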
196 changes: 74 additions & 122 deletions libs/langgraph/src/pregel/index.ts
@@ -11,7 +11,6 @@ import {
RunnableLike,
} from "@langchain/core/runnables";
import { IterableReadableStream } from "@langchain/core/utils/stream";
import type { CallbackManagerForChainRun } from "@langchain/core/callbacks/manager";
import {
All,
BaseCheckpointSaver,
@@ -54,7 +53,6 @@ import {
Command,
NULL_TASK_ID,
INPUT,
RESUME,
PUSH,
} from "../constants.js";
import {
@@ -71,8 +69,6 @@ import {
GraphRecursionError,
GraphValueError,
InvalidUpdateError,
isGraphBubbleUp,
isGraphInterrupt,
} from "../errors.js";
import {
_prepareNextTasks,
@@ -89,7 +85,6 @@ import {
} from "./utils/index.js";
import { findSubgraphPregel } from "./utils/subgraph.js";
import { PregelLoop, IterableReadableWritableStream } from "./loop.js";
import { executeTasksWithRetry } from "./retry.js";
import {
ChannelKeyPlaceholder,
isConfiguredManagedValue,
Expand All @@ -102,6 +97,7 @@ import { gatherIterator, patchConfigurable } from "../utils.js";
import { ensureLangGraphConfig } from "./utils/config.js";
import { LangGraphRunnableConfig } from "./runnable_types.js";
import { StreamMessagesHandler } from "./messages.js";
import { PregelRunner } from "./runner.js";

type WriteValue = Runnable | RunnableFunc<unknown, unknown> | unknown;

@@ -1156,115 +1152,6 @@ export class Pregel<
};
}

async _runLoop(params: {
loop: PregelLoop;
interruptAfter: string[] | "*";
interruptBefore: string[] | "*";
runManager?: CallbackManagerForChainRun;
debug: boolean;
config: LangGraphRunnableConfig;
}) {
const { loop, interruptAfter, interruptBefore, runManager, debug, config } =
params;
let tickError;
try {
while (
await loop.tick({
inputKeys: this.inputChannels as string | string[],
interruptAfter,
interruptBefore,
manager: runManager,
})
) {
if (debug) {
printStepCheckpoint(
loop.checkpointMetadata.step,
loop.channels,
this.streamChannelsList as string[]
);
}
if (debug) {
printStepTasks(loop.step, Object.values(loop.tasks));
}
// execute tasks, and wait for one to fail or all to finish.
// each task is independent from all other concurrent tasks
// yield updates/debug output as each task finishes
const taskStream = executeTasksWithRetry(
Object.values(loop.tasks).filter((task) => task.writes.length === 0),
{
stepTimeout: this.stepTimeout,
signal: config.signal,
retryPolicy: this.retryPolicy,
}
);
let graphInterrupt;
for await (const { task, error } of taskStream) {
if (error !== undefined) {
if (isGraphBubbleUp(error)) {
if (loop.isNested) {
throw error;
}
if (isGraphInterrupt(error)) {
graphInterrupt = error;
if (error.interrupts.length) {
const interrupts: PendingWrite<string>[] =
error.interrupts.map((interrupt) => [INTERRUPT, interrupt]);
const resumes = task.writes.filter((w) => w[0] === RESUME);
if (resumes.length) {
interrupts.push(...resumes);
}
loop.putWrites(task.id, interrupts);
}
}
} else {
loop.putWrites(task.id, [
[ERROR, { message: error.message, name: error.name }],
]);
throw error;
}
} else {
loop.putWrites(task.id, task.writes);
}
}

if (debug) {
printStepWrites(
loop.step,
Object.values(loop.tasks)
.map((task) => task.writes)
.flat(),
this.streamChannelsList as string[]
);
}
if (graphInterrupt !== undefined) {
throw graphInterrupt;
}
}
if (loop.status === "out_of_steps") {
throw new GraphRecursionError(
[
`Recursion limit of ${config.recursionLimit} reached`,
"without hitting a stop condition. You can increase the",
`limit by setting the "recursionLimit" config key.`,
].join(" "),
{
lc_error_code: "GRAPH_RECURSION_LIMIT",
}
);
}
} catch (e) {
tickError = e as Error;
const suppress = await loop.finishAndHandleError(tickError);
if (!suppress) {
throw e;
}
} finally {
if (tickError === undefined) {
await loop.finishAndHandleError();
}
}
}

override async *_streamIterator(
input: PregelInputType | Command,
options?: Partial<PregelOptions<Nn, Cc>>
@@ -1365,21 +1252,23 @@ export class Pregel<
streamKeys: this.streamChannelsAsIs as string | string[],
store,
stream,
interruptAfter,
interruptBefore,
manager: runManager,
debug: this.debug,
});

const runner = new PregelRunner({
loop,
});

if (options?.subgraphs) {
loop.config.configurable = {
...loop.config.configurable,
[CONFIG_KEY_STREAM]: loop.stream,
};
}
await this._runLoop({
loop,
interruptAfter,
interruptBefore,
runManager,
debug,
config,
});
await this._runLoop({ loop, runner, debug, config });
} catch (e) {
loopError = e;
} finally {
@@ -1472,4 +1361,67 @@ }
}
return chunks as OutputType;
}

private async _runLoop(params: {
loop: PregelLoop;
runner: PregelRunner;
config: RunnableConfig;
debug: boolean;
}) {
const { loop, runner, debug, config } = params;
let tickError;
try {
while (
await loop.tick({
inputKeys: this.inputChannels as string | string[],
})
) {
if (debug) {
printStepCheckpoint(
loop.checkpointMetadata.step,
loop.channels,
this.streamChannelsList as string[]
);
}
if (debug) {
printStepTasks(loop.step, Object.values(loop.tasks));
}
await runner.tick({
timeout: this.stepTimeout,
retryPolicy: this.retryPolicy,
onStepWrite: (step, writes) => {
if (debug) {
printStepWrites(
step,
writes,
this.streamChannelsList as string[]
);
}
},
});
}
if (loop.status === "out_of_steps") {
throw new GraphRecursionError(
[
`Recursion limit of ${config.recursionLimit} reached`,
"without hitting a stop condition. You can increase the",
`limit by setting the "recursionLimit" config key.`,
].join(" "),
{
lc_error_code: "GRAPH_RECURSION_LIMIT",
}
);
}
} catch (e) {
tickError = e as Error;
const suppress = await loop.finishAndHandleError(tickError);
if (!suppress) {
throw e;
}
} finally {
if (tickError === undefined) {
await loop.finishAndHandleError();
}
}
}
}
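Only two of the five changed files are rendered on this page; in particular, the diff for the new ./runner.ts module (the PregelRunner implementation imported above) is not shown. The sketch below is therefore only an illustration of what PregelRunner.tick plausibly does, reconstructed from what this page does show: the call sites added above (new PregelRunner({ loop }) and runner.tick({ timeout, retryPolicy, onStepWrite })) and the inline task-execution block deleted from the old _runLoop. The class name PregelRunnerSketch, the RetryPolicy import path, and every internal detail are assumptions, not the committed implementation.

import { executeTasksWithRetry } from "./retry.js";
import type { PregelLoop } from "./loop.js";
import type { RetryPolicy } from "./utils/index.js"; // import path is an assumption
import { ERROR, INTERRUPT, RESUME } from "../constants.js";
import { isGraphBubbleUp, isGraphInterrupt } from "../errors.js";

// Illustrative reconstruction only -- not the runner.ts added by this commit.
export class PregelRunnerSketch {
  constructor(private readonly fields: { loop: PregelLoop }) {}

  async tick(options: {
    timeout?: number;
    retryPolicy?: RetryPolicy;
    onStepWrite?: (step: number, writes: [string, unknown][]) => void;
  }): Promise<void> {
    const { loop } = this.fields;
    const { timeout, retryPolicy, onStepWrite } = options;
    // Same filter as the removed inline block: only tasks without writes yet.
    const pending = Object.values(loop.tasks).filter(
      (task) => task.writes.length === 0
    );
    let graphInterrupt;
    for await (const { task, error } of executeTasksWithRetry(pending, {
      stepTimeout: timeout,
      retryPolicy,
    })) {
      if (error === undefined) {
        // Success: commit the task's writes to the loop.
        loop.putWrites(task.id, task.writes);
      } else if (isGraphBubbleUp(error)) {
        if (loop.isNested) {
          throw error;
        }
        if (isGraphInterrupt(error)) {
          // Record interrupts (plus any resume writes) so the checkpoint
          // reflects where execution paused.
          graphInterrupt = error;
          if (error.interrupts.length) {
            const interrupts: [string, unknown][] = error.interrupts.map(
              (interrupt) => [INTERRUPT, interrupt]
            );
            const resumes = task.writes.filter((w) => w[0] === RESUME);
            loop.putWrites(task.id, [...interrupts, ...resumes]);
          }
        }
      } else {
        // Persist the failure against the task, then surface it.
        loop.putWrites(task.id, [
          [ERROR, { message: error.message, name: error.name }],
        ]);
        throw error;
      }
    }
    onStepWrite?.(
      loop.step,
      Object.values(loop.tasks).flatMap((task) => task.writes)
    );
    if (graphInterrupt !== undefined) {
      throw graphInterrupt;
    }
  }
}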
