Skip to content

Commit

Permalink
feat(langgraph): functional API
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamincburns committed Jan 22, 2025
1 parent b290953 commit bb548c2
Show file tree
Hide file tree
Showing 11 changed files with 683 additions and 31 deletions.
10 changes: 9 additions & 1 deletion libs/langgraph/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export const MISSING = Symbol.for("__missing__");
export const INPUT = "__input__";
export const ERROR = "__error__";
export const CONFIG_KEY_SEND = "__pregel_send";
export const CONFIG_KEY_CALL = "__pregel_call";
export const CONFIG_KEY_READ = "__pregel_read";
export const CONFIG_KEY_CHECKPOINTER = "__pregel_checkpointer";
export const CONFIG_KEY_RESUMING = "__pregel_resuming";
Expand All @@ -18,6 +19,8 @@ export const CONFIG_KEY_CHECKPOINT_MAP = "checkpoint_map";

export const INTERRUPT = "__interrupt__";
export const RESUME = "__resume__";
export const NO_WRITES = "__no_writes__";
export const RETURN = "__return__";
export const RUNTIME_PLACEHOLDER = "__pregel_runtime_placeholder__";
export const RECURSION_LIMIT_DEFAULT = 25;

Expand Down Expand Up @@ -58,7 +61,12 @@ export interface SendInterface {

export function _isSendInterface(x: unknown): x is SendInterface {
const operation = x as SendInterface;
return typeof operation.node === "string" && operation.args !== undefined;
return (
operation !== null &&
operation !== undefined &&
typeof operation.node === "string" &&
operation.args !== undefined
);
}

/**
Expand Down
128 changes: 128 additions & 0 deletions libs/langgraph/src/func.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import { BaseCheckpointSaver } from "@langchain/langgraph-checkpoint";
import { Pregel } from "./pregel/index.js";
import { PregelNode } from "./pregel/read.js";
import { END, START } from "./graph/graph.js";
import { BaseChannel } from "./channels/base.js";
import { ChannelWrite, PASSTHROUGH } from "./pregel/write.js";
import { TAG_HIDDEN } from "./constants.js";
import { ManagedValueSpec } from "./managed/base.js";
import { EphemeralValue } from "./channels/ephemeral_value.js";
import { call, getRunnableForFunc } from "./pregel/call.js";
import { RetryPolicy } from "./pregel/utils/index.js";
import { Promisified } from "./pregel/types.js";

/**
* Options for the @see task function
*
* @experimental
*/
export type TaskOptions = {
/**
* The retry policy for the task
*/
retry?: RetryPolicy;
};

/**
* Wraps a function in a task that can be retried
*
* !!! warning "Experimental"
* This is an experimental API that is subject to change.
* Do not use for production code.
*
* @experimental
*
* @param name - The name of the task, analagous to the node name in @see StateGraph
* @param func - The function that executes this task
* @param options.retry - The retry policy for the task
* @returns A proxy function that accepts the same arguments as the original and always returns the result as a @see Promise.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function task<FuncT extends (...args: any[]) => any>(
name: string,
func: FuncT,
options?: TaskOptions
): (...args: Parameters<FuncT>) => Promisified<FuncT> {
return (...args: Parameters<FuncT>) => {
return call({ func, name, retry: options?.retry }, ...args);
};
}

/**
* Options for the @see entrypoint function
*
* @experimental
*/
export type EntrypointOptions = {
/**
* The name of the entrypoint, analagous to the node name in @see StateGraph
*
* @experimental
*/
name: string;
/**
* The checkpointer for the entrypoint
*
* @experimental
*/
checkpointer?: BaseCheckpointSaver;
};

/**
* Creates an entrypoint that returns a Pregel instance
*
* !!! warning "Experimental"
* This is an experimental API that is subject to change.
* Do not use for production code.
*
* @experimental
*
* @param options.name - The name of the entrypoint, analagous to the node name in @see StateGraph
* @param options.checkpointer - The checkpointer for the entrypoint
* @param func - The function that executes this entrypoint
* @returns A Pregel instance that can be run
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function entrypoint<FuncT extends (...args: any[]) => any>(
{ name, checkpointer }: EntrypointOptions,
func: FuncT
): Pregel<
Record<string, PregelNode>,
Record<string, BaseChannel | ManagedValueSpec>,
Record<string, unknown>,
Parameters<FuncT>,
ReturnType<FuncT>
> {
const p = new Pregel<
Record<string, PregelNode>,
Record<string, BaseChannel | ManagedValueSpec>,
Record<string, unknown>,
Parameters<FuncT>,
ReturnType<FuncT>
>({
checkpointer,
nodes: {
[name]: new PregelNode({
bound: getRunnableForFunc(name, func),
triggers: [START],
channels: [START],
writers: [
new ChannelWrite(
[{ channel: END, value: PASSTHROUGH }],
[TAG_HIDDEN]
),
],
}),
},
channels: {
[START]: new EphemeralValue(),
[END]: new EphemeralValue(),
},
inputChannels: START,
outputChannels: END,
streamChannels: [],
streamMode: "values",
});
p.name = name;
return p;
}
144 changes: 134 additions & 10 deletions libs/langgraph/src/pregel/algo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,23 @@ import {
NULL_TASK_ID,
CONFIG_KEY_SCRATCHPAD,
CONFIG_KEY_WRITES,
RETURN,
ERROR,
} from "../constants.js";
import { PregelExecutableTask, PregelTaskDescription } from "./types.js";
import {
Call,
isCall,
PregelExecutableTask,
PregelTaskDescription,
SimpleTaskPath,
TaskPath,
VariadicTaskPath,
} from "./types.js";
import { EmptyChannelError, InvalidUpdateError } from "../errors.js";
import { getNullChannelVersion } from "./utils/index.js";
import { ManagedValueMapping } from "../managed/base.js";
import { LangGraphRunnableConfig } from "./runnable_types.js";
import { getRunnableForFunc } from "./call.js";

/**
* Construct a type with a set of properties K of type T
Expand All @@ -64,7 +75,7 @@ export type WritesProtocol<C = string> = {
name: string;
writes: PendingWrite<C>[];
triggers: string[];
path?: [string, ...(string | number)[]];
path?: TaskPath;
};

export const increment = (current?: number) => {
Expand Down Expand Up @@ -171,7 +182,7 @@ export function _localWrite(
writes: [string, any][]
) {
for (const [chan, value] of writes) {
if (chan === TASKS) {
if (chan === TASKS && value != null) {
if (!_isSend(value)) {
throw new InvalidUpdateError(
`Invalid packet type, expected SendProtocol, got ${JSON.stringify(
Expand All @@ -191,7 +202,13 @@ export function _localWrite(
commit(writes);
}

const IGNORE = new Set<string | number | symbol>([PUSH, RESUME, INTERRUPT]);
const IGNORE = new Set<string | number | symbol>([
PUSH,
RESUME,
INTERRUPT,
RETURN,
ERROR,
]);

export function _applyWrites<Cc extends Record<string, BaseChannel>>(
checkpoint: Checkpoint,
Expand Down Expand Up @@ -467,7 +484,7 @@ export function _prepareSingleTask<
Nn extends StrRecord<string, PregelNode>,
Cc extends StrRecord<string, BaseChannel>
>(
taskPath: [string, string | number],
taskPath: SimpleTaskPath,
checkpoint: ReadonlyCheckpoint,
pendingWrites: [string, string, unknown][] | undefined,
processes: Nn,
Expand All @@ -482,7 +499,7 @@ export function _prepareSingleTask<
Nn extends StrRecord<string, PregelNode>,
Cc extends StrRecord<string, BaseChannel>
>(
taskPath: [string, ...(string | number)[]],
taskPath: TaskPath,
checkpoint: ReadonlyCheckpoint,
pendingWrites: [string, string, unknown][] | undefined,
processes: Nn,
Expand All @@ -497,7 +514,7 @@ export function _prepareSingleTask<
Nn extends StrRecord<string, PregelNode>,
Cc extends StrRecord<string, BaseChannel>
>(
taskPath: [string, ...(string | number)[]],
taskPath: TaskPath,
checkpoint: ReadonlyCheckpoint,
pendingWrites: [string, string, unknown][] | undefined,
processes: Nn,
Expand All @@ -516,7 +533,7 @@ export function _prepareSingleTask<
Nn extends StrRecord<string, PregelNode>,
Cc extends StrRecord<string, BaseChannel>
>(
taskPath: [string, ...(string | number)[]],
taskPath: TaskPath,
checkpoint: ReadonlyCheckpoint,
pendingWrites: [string, string, unknown][] | undefined,
processes: Nn,
Expand All @@ -533,7 +550,109 @@ export function _prepareSingleTask<
const configurable = config.configurable ?? {};
const parentNamespace = configurable.checkpoint_ns ?? "";

if (taskPath[0] === PUSH) {
if (taskPath[0] === PUSH && isCall(taskPath[taskPath.length - 1])) {
const call = taskPath[taskPath.length - 1] as Call;
const proc = getRunnableForFunc(call.name, call.func);
const triggers = [PUSH];
const checkpointNamespace =
parentNamespace === ""
? call.name
: `${parentNamespace}${CHECKPOINT_NAMESPACE_SEPARATOR}${call.name}`;
const id = uuid5(
JSON.stringify([
checkpointNamespace,
step.toString(),
call.name,
PUSH,
taskPath[1],
taskPath[2],
]),
checkpoint.id
);
const taskCheckpointNamespace = `${checkpointNamespace}${CHECKPOINT_NAMESPACE_END}${id}`;
const metadata = {
langgraph_step: step,
langgraph_node: call.name,
langgraph_triggers: triggers,
langgraph_path: taskPath.slice(0, 3),
langgraph_checkpoint_ns: taskCheckpointNamespace,
};
if (forExecution) {
const writes: [keyof Cc, unknown][] = [];
const task = {
name: call.name,
input: call.input,
proc,
writes,
config: patchConfig(
mergeConfigs(config, {
metadata,
store: extra.store ?? config.store,
}),
{
runName: call.name,
callbacks: manager?.getChild(`graph:step:${step}`),
configurable: {
[CONFIG_KEY_TASK_ID]: id,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[CONFIG_KEY_SEND]: (writes_: [string, any][]) =>
_localWrite(
step,
(items: [keyof Cc, unknown][]) => writes.push(...items),
processes,
managed,
writes_
),
[CONFIG_KEY_READ]: (
select_: Array<keyof Cc> | keyof Cc,
fresh_: boolean = false
) =>
_localRead(
step,
checkpoint,
channels,
managed,
{
name: call.name,
writes: writes as Array<[string, unknown]>,
triggers,
path: taskPath.slice(0, 3) as VariadicTaskPath,
},
select_,
fresh_
),
[CONFIG_KEY_CHECKPOINTER]:
checkpointer ?? configurable[CONFIG_KEY_CHECKPOINTER],
[CONFIG_KEY_CHECKPOINT_MAP]: {
...configurable[CONFIG_KEY_CHECKPOINT_MAP],
[parentNamespace]: checkpoint.id,
},
[CONFIG_KEY_WRITES]: [
...(pendingWrites || []),
...(configurable[CONFIG_KEY_WRITES] || []),
].filter((w) => w[0] === NULL_TASK_ID || w[0] === id),
[CONFIG_KEY_SCRATCHPAD]: {},
checkpoint_id: undefined,
checkpoint_ns: taskCheckpointNamespace,
},
}
),
triggers,
retry_policy: call.retry,
id,
path: taskPath.slice(0, 3) as VariadicTaskPath,
writers: [],
};
return task;
} else {
return {
id,
name: call.name,
interrupts: [],
path: taskPath.slice(0, 3) as VariadicTaskPath,
};
}
} else if (taskPath[0] === PUSH) {
const index =
typeof taskPath[1] === "number"
? taskPath[1]
Expand Down Expand Up @@ -654,7 +773,12 @@ export function _prepareSingleTask<
};
}
} else {
return { id: taskId, name: packet.node, interrupts: [], path: taskPath };
return {
id: taskId,
name: packet.node,
interrupts: [],
path: taskPath,
};
}
} else if (taskPath[0] === PULL) {
const name = taskPath[1].toString();
Expand Down
Loading

0 comments on commit bb548c2

Please sign in to comment.