Skip to content

Commit

Permalink
chore(internal): cleanup event stream helpers (#950)
Browse files Browse the repository at this point in the history
* [wip]: refactor

* a solution

* Bind this

* fix formatting

---------

Co-authored-by: Young-Jin Park <[email protected]>
  • Loading branch information
RobertCraigie and yjp20 authored Jul 30, 2024
1 parent c610c81 commit 8f49956
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 395 deletions.
255 changes: 16 additions & 239 deletions src/lib/AbstractChatCompletionRunner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
type ChatCompletionCreateParams,
type ChatCompletionTool,
} from 'openai/resources/chat/completions';
import { APIUserAbortError, OpenAIError } from 'openai/error';
import { OpenAIError } from 'openai/error';
import {
type RunnableFunction,
isRunnableFunctionWithParse,
Expand All @@ -20,75 +20,36 @@ import {
ChatCompletionStreamingToolRunnerParams,
} from './ChatCompletionStreamingRunner';
import { isAssistantMessage, isFunctionMessage, isToolMessage } from './chatCompletionUtils';
import { BaseEvents, EventStream } from './EventStream';

const DEFAULT_MAX_CHAT_COMPLETIONS = 10;
export interface RunnerOptions extends Core.RequestOptions {
/** How many requests to make before canceling. Default 10. */
maxChatCompletions?: number;
}

export abstract class AbstractChatCompletionRunner<
Events extends CustomEvents<any> = AbstractChatCompletionRunnerEvents,
> {
controller: AbortController = new AbortController();

#connectedPromise: Promise<void>;
#resolveConnectedPromise: () => void = () => {};
#rejectConnectedPromise: (error: OpenAIError) => void = () => {};

#endPromise: Promise<void>;
#resolveEndPromise: () => void = () => {};
#rejectEndPromise: (error: OpenAIError) => void = () => {};

#listeners: { [Event in keyof Events]?: ListenersForEvent<Events, Event> } = {};

export class AbstractChatCompletionRunner<
EventTypes extends AbstractChatCompletionRunnerEvents,
> extends EventStream<EventTypes> {
protected _chatCompletions: ChatCompletion[] = [];
messages: ChatCompletionMessageParam[] = [];

#ended = false;
#errored = false;
#aborted = false;
#catchingPromiseCreated = false;

constructor() {
this.#connectedPromise = new Promise<void>((resolve, reject) => {
this.#resolveConnectedPromise = resolve;
this.#rejectConnectedPromise = reject;
});

this.#endPromise = new Promise<void>((resolve, reject) => {
this.#resolveEndPromise = resolve;
this.#rejectEndPromise = reject;
});

// Don't let these promises cause unhandled rejection errors.
// we will manually cause an unhandled rejection error later
// if the user hasn't registered any error listener or called
// any promise-returning method.
this.#connectedPromise.catch(() => {});
this.#endPromise.catch(() => {});
}

protected _run(executor: () => Promise<any>) {
// Unfortunately if we call `executor()` immediately we get runtime errors about
// references to `this` before the `super()` constructor call returns.
setTimeout(() => {
executor().then(() => {
this._emitFinal();
this._emit('end');
}, this.#handleError);
}, 0);
}

protected _addChatCompletion(chatCompletion: ChatCompletion): ChatCompletion {
protected _addChatCompletion(
this: AbstractChatCompletionRunner<AbstractChatCompletionRunnerEvents>,
chatCompletion: ChatCompletion,
): ChatCompletion {
this._chatCompletions.push(chatCompletion);
this._emit('chatCompletion', chatCompletion);
const message = chatCompletion.choices[0]?.message;
if (message) this._addMessage(message as ChatCompletionMessageParam);
return chatCompletion;
}

protected _addMessage(message: ChatCompletionMessageParam, emit = true) {
protected _addMessage(
this: AbstractChatCompletionRunner<AbstractChatCompletionRunnerEvents>,
message: ChatCompletionMessageParam,
emit = true,
) {
if (!('content' in message)) message.content = null;

this.messages.push(message);
Expand All @@ -110,99 +71,6 @@ export abstract class AbstractChatCompletionRunner<
}
}

protected _connected() {
if (this.ended) return;
this.#resolveConnectedPromise();
this._emit('connect');
}

get ended(): boolean {
return this.#ended;
}

get errored(): boolean {
return this.#errored;
}

get aborted(): boolean {
return this.#aborted;
}

abort() {
this.controller.abort();
}

/**
* Adds the listener function to the end of the listeners array for the event.
* No checks are made to see if the listener has already been added. Multiple calls passing
* the same combination of event and listener will result in the listener being added, and
* called, multiple times.
* @returns this ChatCompletionStream, so that calls can be chained
*/
on<Event extends keyof Events>(event: Event, listener: ListenerForEvent<Events, Event>): this {
const listeners: ListenersForEvent<Events, Event> =
this.#listeners[event] || (this.#listeners[event] = []);
listeners.push({ listener });
return this;
}

/**
* Removes the specified listener from the listener array for the event.
* off() will remove, at most, one instance of a listener from the listener array. If any single
* listener has been added multiple times to the listener array for the specified event, then
* off() must be called multiple times to remove each instance.
* @returns this ChatCompletionStream, so that calls can be chained
*/
off<Event extends keyof Events>(event: Event, listener: ListenerForEvent<Events, Event>): this {
const listeners = this.#listeners[event];
if (!listeners) return this;
const index = listeners.findIndex((l) => l.listener === listener);
if (index >= 0) listeners.splice(index, 1);
return this;
}

/**
* Adds a one-time listener function for the event. The next time the event is triggered,
* this listener is removed and then invoked.
* @returns this ChatCompletionStream, so that calls can be chained
*/
once<Event extends keyof Events>(event: Event, listener: ListenerForEvent<Events, Event>): this {
const listeners: ListenersForEvent<Events, Event> =
this.#listeners[event] || (this.#listeners[event] = []);
listeners.push({ listener, once: true });
return this;
}

/**
* This is similar to `.once()`, but returns a Promise that resolves the next time
* the event is triggered, instead of calling a listener callback.
* @returns a Promise that resolves the next time given event is triggered,
* or rejects if an error is emitted. (If you request the 'error' event,
* returns a promise that resolves with the error).
*
* Example:
*
* const message = await stream.emitted('message') // rejects if the stream errors
*/
emitted<Event extends keyof Events>(
event: Event,
): Promise<
EventParameters<Events, Event> extends [infer Param] ? Param
: EventParameters<Events, Event> extends [] ? void
: EventParameters<Events, Event>
> {
return new Promise((resolve, reject) => {
this.#catchingPromiseCreated = true;
if (event !== 'error') this.once('error', reject);
this.once(event, resolve as any);
});
}

async done(): Promise<void> {
this.#catchingPromiseCreated = true;
await this.#endPromise;
}

/**
* @returns a promise that resolves with the final ChatCompletion, or rejects
* if an error occurred or the stream ended prematurely without producing a ChatCompletion.
Expand Down Expand Up @@ -327,75 +195,7 @@ export abstract class AbstractChatCompletionRunner<
return [...this._chatCompletions];
}

#handleError = (error: unknown) => {
this.#errored = true;
if (error instanceof Error && error.name === 'AbortError') {
error = new APIUserAbortError();
}
if (error instanceof APIUserAbortError) {
this.#aborted = true;
return this._emit('abort', error);
}
if (error instanceof OpenAIError) {
return this._emit('error', error);
}
if (error instanceof Error) {
const openAIError: OpenAIError = new OpenAIError(error.message);
// @ts-ignore
openAIError.cause = error;
return this._emit('error', openAIError);
}
return this._emit('error', new OpenAIError(String(error)));
};

protected _emit<Event extends keyof Events>(event: Event, ...args: EventParameters<Events, Event>) {
// make sure we don't emit any events after end
if (this.#ended) {
return;
}

if (event === 'end') {
this.#ended = true;
this.#resolveEndPromise();
}

const listeners: ListenersForEvent<Events, Event> | undefined = this.#listeners[event];
if (listeners) {
this.#listeners[event] = listeners.filter((l) => !l.once) as any;
listeners.forEach(({ listener }: any) => listener(...args));
}

if (event === 'abort') {
const error = args[0] as APIUserAbortError;
if (!this.#catchingPromiseCreated && !listeners?.length) {
Promise.reject(error);
}
this.#rejectConnectedPromise(error);
this.#rejectEndPromise(error);
this._emit('end');
return;
}

if (event === 'error') {
// NOTE: _emit('error', error) should only be called from #handleError().

const error = args[0] as OpenAIError;
if (!this.#catchingPromiseCreated && !listeners?.length) {
// Trigger an unhandled rejection if the user hasn't registered any error handlers.
// If you are seeing stack traces here, make sure to handle errors via either:
// - runner.on('error', () => ...)
// - await runner.done()
// - await runner.finalChatCompletion()
// - etc.
Promise.reject(error);
}
this.#rejectConnectedPromise(error);
this.#rejectEndPromise(error);
this._emit('end');
}
}

protected _emitFinal() {
protected override _emitFinal(this: AbstractChatCompletionRunner<AbstractChatCompletionRunnerEvents>) {
const completion = this._chatCompletions[this._chatCompletions.length - 1];
if (completion) this._emit('finalChatCompletion', completion);
const finalMessage = this.#getFinalMessage();
Expand Down Expand Up @@ -650,27 +450,7 @@ export abstract class AbstractChatCompletionRunner<
}
}

type CustomEvents<Event extends string> = {
[k in Event]: k extends keyof AbstractChatCompletionRunnerEvents ? AbstractChatCompletionRunnerEvents[k]
: (...args: any[]) => void;
};

type ListenerForEvent<Events extends CustomEvents<any>, Event extends keyof Events> = Event extends (
keyof AbstractChatCompletionRunnerEvents
) ?
AbstractChatCompletionRunnerEvents[Event]
: Events[Event];

type ListenersForEvent<Events extends CustomEvents<any>, Event extends keyof Events> = Array<{
listener: ListenerForEvent<Events, Event>;
once?: boolean;
}>;
type EventParameters<Events extends CustomEvents<any>, Event extends keyof Events> = Parameters<
ListenerForEvent<Events, Event>
>;

export interface AbstractChatCompletionRunnerEvents {
connect: () => void;
export interface AbstractChatCompletionRunnerEvents extends BaseEvents {
functionCall: (functionCall: ChatCompletionMessage.FunctionCall) => void;
message: (message: ChatCompletionMessageParam) => void;
chatCompletion: (completion: ChatCompletion) => void;
Expand All @@ -680,8 +460,5 @@ export interface AbstractChatCompletionRunnerEvents {
finalFunctionCall: (functionCall: ChatCompletionMessage.FunctionCall) => void;
functionCallResult: (content: string) => void;
finalFunctionCallResult: (content: string) => void;
error: (error: OpenAIError) => void;
abort: (error: APIUserAbortError) => void;
end: () => void;
totalUsage: (usage: CompletionUsage) => void;
}
Loading

0 comments on commit 8f49956

Please sign in to comment.