Skip to content

Commit

Permalink
Merge pull request #86 from Rajaniraiyn/openai-classifier
Browse files Browse the repository at this point in the history
add OpenAI Classifier
  • Loading branch information
brnaba-aws authored Nov 22, 2024
2 parents 0f6a325 + ce29f84 commit 951b05c
Show file tree
Hide file tree
Showing 4 changed files with 381 additions and 0 deletions.
64 changes: 64 additions & 0 deletions docs/src/content/docs/classifiers/built-in/openai-classifier.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
---
title: OpenAI Classifier
description: How to configure the OpenAI classifier
---

The OpenAI Classifier is a built-in classifier for the Multi-Agent Orchestrator that leverages OpenAI's language models for intent classification. It provides robust classification capabilities using OpenAI's state-of-the-art models like GPT-4o.

The OpenAI Classifier extends the abstract `Classifier` class and uses the OpenAI API client to process requests and classify user intents.

## Features

- Utilizes OpenAI's advanced models (e.g., GPT-4) for intent classification
- Configurable model selection and inference parameters
- Supports custom system prompts and variables
- Handles conversation history for context-aware classification

### Basic Usage

To use the OpenAIClassifier, you need to create an instance with your OpenAI API key and pass it to the Multi-Agent Orchestrator:

import { Tabs, TabItem } from '@astrojs/starlight/components';

<Tabs syncKey="runtime">
<TabItem label="TypeScript" icon="seti:typescript" color="blue">
```typescript
import { OpenAIClassifier } from "multi-agent-orchestrator";
import { MultiAgentOrchestrator } from "multi-agent-orchestrator";

const openaiClassifier = new OpenAIClassifier({
apiKey: 'your-openai-api-key'
});

const orchestrator = new MultiAgentOrchestrator({ classifier: openaiClassifier });
```
</TabItem>
</Tabs>

### Custom Configuration

You can customize the OpenAIClassifier by providing additional options:

<Tabs syncKey="runtime">
<TabItem label="TypeScript" icon="seti:typescript" color="blue">
```typescript
const customOpenAIClassifier = new OpenAIClassifier({
apiKey: 'your-openai-api-key',
modelId: 'gpt-4o',
inferenceConfig: {
maxTokens: 500,
temperature: 0.7,
topP: 0.9,
stopSequences: ['']
}
});
const orchestrator = new MultiAgentOrchestrator({ classifier: customOpenAIClassifier });
```
</TabItem>
</Tabs>

The OpenAIClassifier accepts the following configuration options:

- `api_key` (required): Your OpenAI API key.
- `model_id` (optional): The ID of the OpenAI model to use. Defaults to GPT-4 Turbo
138 changes: 138 additions & 0 deletions typescript/src/classifiers/openAIClassifier.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import OpenAI from "openai";
import {
ConversationMessage,
OPENAI_MODEL_ID_GPT_O_MINI
} from "../types";
import { isClassifierToolInput } from "../utils/helpers";
import { Logger } from "../utils/logger";
import { Classifier, ClassifierResult } from "./classifier";

export interface OpenAIClassifierOptions {
// Optional: The ID of the OpenAI model to use for classification
// If not provided, a default model may be used
modelId?: string;

// Optional: Configuration for the inference process
inferenceConfig?: {
// Maximum number of tokens to generate in the response
maxTokens?: number;

// Controls randomness in output generation
temperature?: number;

// Controls diversity of output via nucleus sampling
topP?: number;

// Array of sequences that will stop the model from generating further tokens
stopSequences?: string[];
};

// The API key for authenticating with OpenAI's services
apiKey: string;
}

export class OpenAIClassifier extends Classifier {
private client: OpenAI;
protected inferenceConfig: {
maxTokens?: number;
temperature?: number;
topP?: number;
stopSequences?: string[];
};

private tools: OpenAI.ChatCompletionTool[] = [
{
type: "function",
function: {
name: 'analyzePrompt',
description: 'Analyze the user input and provide structured output',
parameters: {
type: 'object',
properties: {
userinput: {
type: 'string',
description: 'The original user input',
},
selected_agent: {
type: 'string',
description: 'The name of the selected agent',
},
confidence: {
type: 'number',
description: 'Confidence level between 0 and 1',
},
},
required: ['userinput', 'selected_agent', 'confidence'],
},
},
},
];

constructor(options: OpenAIClassifierOptions) {
super();

if (!options.apiKey) {
throw new Error("OpenAI API key is required");
}
this.client = new OpenAI({ apiKey: options.apiKey });
this.modelId = options.modelId || OPENAI_MODEL_ID_GPT_O_MINI;

const defaultMaxTokens = 1000;
this.inferenceConfig = {
maxTokens: options.inferenceConfig?.maxTokens ?? defaultMaxTokens,
temperature: options.inferenceConfig?.temperature,
topP: options.inferenceConfig?.topP,
stopSequences: options.inferenceConfig?.stopSequences,
};
}

async processRequest(
inputText: string,
chatHistory: ConversationMessage[]
): Promise<ClassifierResult> {
const messages: OpenAI.ChatCompletionMessageParam[] = [
{
role: 'system',
content: this.systemPrompt
},
{
role: 'user',
content: inputText
}
];

try {
const response = await this.client.chat.completions.create({
model: this.modelId,
messages: messages,
max_tokens: this.inferenceConfig.maxTokens,
temperature: this.inferenceConfig.temperature,
top_p: this.inferenceConfig.topP,
tools: this.tools,
tool_choice: { type: "function", function: { name: "analyzePrompt" } }
});

const toolCall = response.choices[0]?.message?.tool_calls?.[0];

if (!toolCall || toolCall.function.name !== "analyzePrompt") {
throw new Error("No valid tool call found in the response");
}

const toolInput = JSON.parse(toolCall.function.arguments);

if (!isClassifierToolInput(toolInput)) {
throw new Error("Tool input does not match expected structure");
}

const intentClassifierResult: ClassifierResult = {
selectedAgent: this.getAgentById(toolInput.selected_agent),
confidence: parseFloat(toolInput.confidence),
};
return intentClassifierResult;

} catch (error) {
Logger.logger.error("Error processing request:", error);
throw error;
}
}
}
1 change: 1 addition & 0 deletions typescript/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export { AgentResponse } from './agents/agent';

export { BedrockClassifier, BedrockClassifierOptions } from './classifiers/bedrockClassifier';
export { AnthropicClassifier, AnthropicClassifierOptions } from './classifiers/anthropicClassifier';
export { OpenAIClassifier, OpenAIClassifierOptions } from "./classifiers/openAIClassifier"

export { Retriever } from './retrievers/retriever';
export { AmazonKnowledgeBasesRetriever, AmazonKnowledgeBasesRetrieverOptions } from './retrievers/AmazonKBRetriever';
Expand Down
178 changes: 178 additions & 0 deletions typescript/tests/classifiers/OpenAIClassifier.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import { OpenAIClassifier, OpenAIClassifierOptions } from '../../src/classifiers/openAIClassifier';
import OpenAI from 'openai';
import { ConversationMessage, OPENAI_MODEL_ID_GPT_O_MINI, ParticipantRole } from "../../src/types";
import { MockAgent } from '../mock/mockAgent';

// Mock the OpenAI module
jest.mock('openai');

describe('OpenAIClassifier', () => {
let classifier: OpenAIClassifier;
let mockCreateCompletion: jest.Mock;

const defaultOptions: OpenAIClassifierOptions = {
apiKey: 'test-api-key',
};

beforeEach(() => {
// Create a mock for the create method
mockCreateCompletion = jest.fn();

// Mock the OpenAI constructor and chat.completions.create method
(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(() => ({
chat: {
completions: {
create: mockCreateCompletion,
},
},
} as unknown as OpenAI));

classifier = new OpenAIClassifier(defaultOptions);
});

afterEach(() => {
jest.clearAllMocks();
});

describe('constructor', () => {
it('should create an instance with default options', () => {
expect(classifier).toBeInstanceOf(OpenAIClassifier);
expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-api-key' });
});

it('should use custom model ID if provided', () => {
const customOptions: OpenAIClassifierOptions = {
...defaultOptions,
modelId: 'custom-model-id',
};
const customClassifier = new OpenAIClassifier(customOptions);
expect(customClassifier['modelId']).toBe('custom-model-id');
});

it('should use default model ID if not provided', () => {
expect(classifier['modelId']).toBe(OPENAI_MODEL_ID_GPT_O_MINI);
});

it('should set inference config with custom values', () => {
const customOptions: OpenAIClassifierOptions = {
...defaultOptions,
inferenceConfig: {
maxTokens: 500,
temperature: 0.7,
topP: 0.9,
stopSequences: ['STOP'],
},
};
const customClassifier = new OpenAIClassifier(customOptions);
expect(customClassifier['inferenceConfig']).toEqual(customOptions.inferenceConfig);
});

it('should throw an error if API key is not provided', () => {
expect(() => new OpenAIClassifier({ apiKey: '' })).toThrow('OpenAI API key is required');
});
});

describe('processRequest', () => {
const inputText = 'Hello, how are you?';
const chatHistory: ConversationMessage[] = [];

it('should process request successfully', async () => {
const mockResponse = {
choices: [{
message: {
tool_calls: [{
function: {
name: 'analyzePrompt',
arguments: JSON.stringify({
userinput: inputText,
selected_agent: 'test-agent',
confidence: 0.95,
}),
},
}],
},
}],
};

const mockAgent = {
'test-agent': new MockAgent({
name: "test-agent",
description: 'A tech support agent',
})
};

classifier.setAgents(mockAgent);

mockCreateCompletion.mockResolvedValue(mockResponse);

const result = await classifier.processRequest(inputText, chatHistory);

expect(mockCreateCompletion).toHaveBeenCalledWith({
model: OPENAI_MODEL_ID_GPT_O_MINI,
max_tokens: 1000,
messages: [
{
role: 'system',
content: classifier['systemPrompt'], // Use the actual system prompt
},
{
role: 'user',
content: inputText
}
],
temperature: undefined,
top_p: undefined,
tools: classifier['tools'], // Use the actual tools array
tool_choice: { type: "function", function: { name: "analyzePrompt" } }
});

expect(result).toEqual({
selectedAgent: expect.any(MockAgent),
confidence: 0.95,
});
});

it('should throw an error if no tool calls are found in the response', async () => {
const mockResponse = {
choices: [{
message: {}
}],
};

mockCreateCompletion.mockResolvedValue(mockResponse);

await expect(classifier.processRequest(inputText, chatHistory))
.rejects.toThrow('No valid tool call found in the response');
});

it('should throw an error if tool input does not match expected structure', async () => {
const mockResponse = {
choices: [{
message: {
tool_calls: [{
function: {
name: 'analyzePrompt',
arguments: JSON.stringify({
invalidKey: 'invalidValue',
}),
},
}],
},
}],
};

mockCreateCompletion.mockResolvedValue(mockResponse);

await expect(classifier.processRequest(inputText, chatHistory))
.rejects.toThrow('Tool input does not match expected structure');
});

it('should throw an error if API request fails', async () => {
const errorMessage = 'API request failed';
mockCreateCompletion.mockRejectedValue(new Error(errorMessage));

await expect(classifier.processRequest(inputText, chatHistory))
.rejects.toThrow(errorMessage);
});
});
});

0 comments on commit 951b05c

Please sign in to comment.