-
Notifications
You must be signed in to change notification settings - Fork 119
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #86 from Rajaniraiyn/openai-classifier
add OpenAI Classifier
- Loading branch information
Showing
4 changed files
with
381 additions
and
0 deletions.
There are no files selected for viewing
64 changes: 64 additions & 0 deletions
64
docs/src/content/docs/classifiers/built-in/openai-classifier.mdx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
--- | ||
title: OpenAI Classifier | ||
description: How to configure the OpenAI classifier | ||
--- | ||
|
||
The OpenAI Classifier is a built-in classifier for the Multi-Agent Orchestrator that leverages OpenAI's language models for intent classification. It provides robust classification capabilities using OpenAI's state-of-the-art models like GPT-4o. | ||
|
||
The OpenAI Classifier extends the abstract `Classifier` class and uses the OpenAI API client to process requests and classify user intents. | ||
|
||
## Features | ||
|
||
- Utilizes OpenAI's advanced models (e.g., GPT-4) for intent classification | ||
- Configurable model selection and inference parameters | ||
- Supports custom system prompts and variables | ||
- Handles conversation history for context-aware classification | ||
|
||
### Basic Usage | ||
|
||
To use the OpenAIClassifier, you need to create an instance with your OpenAI API key and pass it to the Multi-Agent Orchestrator: | ||
|
||
import { Tabs, TabItem } from '@astrojs/starlight/components'; | ||
|
||
<Tabs syncKey="runtime"> | ||
<TabItem label="TypeScript" icon="seti:typescript" color="blue"> | ||
```typescript | ||
import { OpenAIClassifier } from "multi-agent-orchestrator"; | ||
import { MultiAgentOrchestrator } from "multi-agent-orchestrator"; | ||
|
||
const openaiClassifier = new OpenAIClassifier({ | ||
apiKey: 'your-openai-api-key' | ||
}); | ||
|
||
const orchestrator = new MultiAgentOrchestrator({ classifier: openaiClassifier }); | ||
``` | ||
</TabItem> | ||
</Tabs> | ||
|
||
### Custom Configuration | ||
|
||
You can customize the OpenAIClassifier by providing additional options: | ||
|
||
<Tabs syncKey="runtime"> | ||
<TabItem label="TypeScript" icon="seti:typescript" color="blue"> | ||
```typescript | ||
const customOpenAIClassifier = new OpenAIClassifier({ | ||
apiKey: 'your-openai-api-key', | ||
modelId: 'gpt-4o', | ||
inferenceConfig: { | ||
maxTokens: 500, | ||
temperature: 0.7, | ||
topP: 0.9, | ||
stopSequences: [''] | ||
} | ||
}); | ||
const orchestrator = new MultiAgentOrchestrator({ classifier: customOpenAIClassifier }); | ||
``` | ||
</TabItem> | ||
</Tabs> | ||
|
||
The OpenAIClassifier accepts the following configuration options: | ||
|
||
- `api_key` (required): Your OpenAI API key. | ||
- `model_id` (optional): The ID of the OpenAI model to use. Defaults to GPT-4 Turbo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import OpenAI from "openai"; | ||
import { | ||
ConversationMessage, | ||
OPENAI_MODEL_ID_GPT_O_MINI | ||
} from "../types"; | ||
import { isClassifierToolInput } from "../utils/helpers"; | ||
import { Logger } from "../utils/logger"; | ||
import { Classifier, ClassifierResult } from "./classifier"; | ||
|
||
export interface OpenAIClassifierOptions { | ||
// Optional: The ID of the OpenAI model to use for classification | ||
// If not provided, a default model may be used | ||
modelId?: string; | ||
|
||
// Optional: Configuration for the inference process | ||
inferenceConfig?: { | ||
// Maximum number of tokens to generate in the response | ||
maxTokens?: number; | ||
|
||
// Controls randomness in output generation | ||
temperature?: number; | ||
|
||
// Controls diversity of output via nucleus sampling | ||
topP?: number; | ||
|
||
// Array of sequences that will stop the model from generating further tokens | ||
stopSequences?: string[]; | ||
}; | ||
|
||
// The API key for authenticating with OpenAI's services | ||
apiKey: string; | ||
} | ||
|
||
export class OpenAIClassifier extends Classifier { | ||
private client: OpenAI; | ||
protected inferenceConfig: { | ||
maxTokens?: number; | ||
temperature?: number; | ||
topP?: number; | ||
stopSequences?: string[]; | ||
}; | ||
|
||
private tools: OpenAI.ChatCompletionTool[] = [ | ||
{ | ||
type: "function", | ||
function: { | ||
name: 'analyzePrompt', | ||
description: 'Analyze the user input and provide structured output', | ||
parameters: { | ||
type: 'object', | ||
properties: { | ||
userinput: { | ||
type: 'string', | ||
description: 'The original user input', | ||
}, | ||
selected_agent: { | ||
type: 'string', | ||
description: 'The name of the selected agent', | ||
}, | ||
confidence: { | ||
type: 'number', | ||
description: 'Confidence level between 0 and 1', | ||
}, | ||
}, | ||
required: ['userinput', 'selected_agent', 'confidence'], | ||
}, | ||
}, | ||
}, | ||
]; | ||
|
||
constructor(options: OpenAIClassifierOptions) { | ||
super(); | ||
|
||
if (!options.apiKey) { | ||
throw new Error("OpenAI API key is required"); | ||
} | ||
this.client = new OpenAI({ apiKey: options.apiKey }); | ||
this.modelId = options.modelId || OPENAI_MODEL_ID_GPT_O_MINI; | ||
|
||
const defaultMaxTokens = 1000; | ||
this.inferenceConfig = { | ||
maxTokens: options.inferenceConfig?.maxTokens ?? defaultMaxTokens, | ||
temperature: options.inferenceConfig?.temperature, | ||
topP: options.inferenceConfig?.topP, | ||
stopSequences: options.inferenceConfig?.stopSequences, | ||
}; | ||
} | ||
|
||
async processRequest( | ||
inputText: string, | ||
chatHistory: ConversationMessage[] | ||
): Promise<ClassifierResult> { | ||
const messages: OpenAI.ChatCompletionMessageParam[] = [ | ||
{ | ||
role: 'system', | ||
content: this.systemPrompt | ||
}, | ||
{ | ||
role: 'user', | ||
content: inputText | ||
} | ||
]; | ||
|
||
try { | ||
const response = await this.client.chat.completions.create({ | ||
model: this.modelId, | ||
messages: messages, | ||
max_tokens: this.inferenceConfig.maxTokens, | ||
temperature: this.inferenceConfig.temperature, | ||
top_p: this.inferenceConfig.topP, | ||
tools: this.tools, | ||
tool_choice: { type: "function", function: { name: "analyzePrompt" } } | ||
}); | ||
|
||
const toolCall = response.choices[0]?.message?.tool_calls?.[0]; | ||
|
||
if (!toolCall || toolCall.function.name !== "analyzePrompt") { | ||
throw new Error("No valid tool call found in the response"); | ||
} | ||
|
||
const toolInput = JSON.parse(toolCall.function.arguments); | ||
|
||
if (!isClassifierToolInput(toolInput)) { | ||
throw new Error("Tool input does not match expected structure"); | ||
} | ||
|
||
const intentClassifierResult: ClassifierResult = { | ||
selectedAgent: this.getAgentById(toolInput.selected_agent), | ||
confidence: parseFloat(toolInput.confidence), | ||
}; | ||
return intentClassifierResult; | ||
|
||
} catch (error) { | ||
Logger.logger.error("Error processing request:", error); | ||
throw error; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
import { OpenAIClassifier, OpenAIClassifierOptions } from '../../src/classifiers/openAIClassifier'; | ||
import OpenAI from 'openai'; | ||
import { ConversationMessage, OPENAI_MODEL_ID_GPT_O_MINI, ParticipantRole } from "../../src/types"; | ||
import { MockAgent } from '../mock/mockAgent'; | ||
|
||
// Mock the OpenAI module | ||
jest.mock('openai'); | ||
|
||
describe('OpenAIClassifier', () => { | ||
let classifier: OpenAIClassifier; | ||
let mockCreateCompletion: jest.Mock; | ||
|
||
const defaultOptions: OpenAIClassifierOptions = { | ||
apiKey: 'test-api-key', | ||
}; | ||
|
||
beforeEach(() => { | ||
// Create a mock for the create method | ||
mockCreateCompletion = jest.fn(); | ||
|
||
// Mock the OpenAI constructor and chat.completions.create method | ||
(OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(() => ({ | ||
chat: { | ||
completions: { | ||
create: mockCreateCompletion, | ||
}, | ||
}, | ||
} as unknown as OpenAI)); | ||
|
||
classifier = new OpenAIClassifier(defaultOptions); | ||
}); | ||
|
||
afterEach(() => { | ||
jest.clearAllMocks(); | ||
}); | ||
|
||
describe('constructor', () => { | ||
it('should create an instance with default options', () => { | ||
expect(classifier).toBeInstanceOf(OpenAIClassifier); | ||
expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-api-key' }); | ||
}); | ||
|
||
it('should use custom model ID if provided', () => { | ||
const customOptions: OpenAIClassifierOptions = { | ||
...defaultOptions, | ||
modelId: 'custom-model-id', | ||
}; | ||
const customClassifier = new OpenAIClassifier(customOptions); | ||
expect(customClassifier['modelId']).toBe('custom-model-id'); | ||
}); | ||
|
||
it('should use default model ID if not provided', () => { | ||
expect(classifier['modelId']).toBe(OPENAI_MODEL_ID_GPT_O_MINI); | ||
}); | ||
|
||
it('should set inference config with custom values', () => { | ||
const customOptions: OpenAIClassifierOptions = { | ||
...defaultOptions, | ||
inferenceConfig: { | ||
maxTokens: 500, | ||
temperature: 0.7, | ||
topP: 0.9, | ||
stopSequences: ['STOP'], | ||
}, | ||
}; | ||
const customClassifier = new OpenAIClassifier(customOptions); | ||
expect(customClassifier['inferenceConfig']).toEqual(customOptions.inferenceConfig); | ||
}); | ||
|
||
it('should throw an error if API key is not provided', () => { | ||
expect(() => new OpenAIClassifier({ apiKey: '' })).toThrow('OpenAI API key is required'); | ||
}); | ||
}); | ||
|
||
describe('processRequest', () => { | ||
const inputText = 'Hello, how are you?'; | ||
const chatHistory: ConversationMessage[] = []; | ||
|
||
it('should process request successfully', async () => { | ||
const mockResponse = { | ||
choices: [{ | ||
message: { | ||
tool_calls: [{ | ||
function: { | ||
name: 'analyzePrompt', | ||
arguments: JSON.stringify({ | ||
userinput: inputText, | ||
selected_agent: 'test-agent', | ||
confidence: 0.95, | ||
}), | ||
}, | ||
}], | ||
}, | ||
}], | ||
}; | ||
|
||
const mockAgent = { | ||
'test-agent': new MockAgent({ | ||
name: "test-agent", | ||
description: 'A tech support agent', | ||
}) | ||
}; | ||
|
||
classifier.setAgents(mockAgent); | ||
|
||
mockCreateCompletion.mockResolvedValue(mockResponse); | ||
|
||
const result = await classifier.processRequest(inputText, chatHistory); | ||
|
||
expect(mockCreateCompletion).toHaveBeenCalledWith({ | ||
model: OPENAI_MODEL_ID_GPT_O_MINI, | ||
max_tokens: 1000, | ||
messages: [ | ||
{ | ||
role: 'system', | ||
content: classifier['systemPrompt'], // Use the actual system prompt | ||
}, | ||
{ | ||
role: 'user', | ||
content: inputText | ||
} | ||
], | ||
temperature: undefined, | ||
top_p: undefined, | ||
tools: classifier['tools'], // Use the actual tools array | ||
tool_choice: { type: "function", function: { name: "analyzePrompt" } } | ||
}); | ||
|
||
expect(result).toEqual({ | ||
selectedAgent: expect.any(MockAgent), | ||
confidence: 0.95, | ||
}); | ||
}); | ||
|
||
it('should throw an error if no tool calls are found in the response', async () => { | ||
const mockResponse = { | ||
choices: [{ | ||
message: {} | ||
}], | ||
}; | ||
|
||
mockCreateCompletion.mockResolvedValue(mockResponse); | ||
|
||
await expect(classifier.processRequest(inputText, chatHistory)) | ||
.rejects.toThrow('No valid tool call found in the response'); | ||
}); | ||
|
||
it('should throw an error if tool input does not match expected structure', async () => { | ||
const mockResponse = { | ||
choices: [{ | ||
message: { | ||
tool_calls: [{ | ||
function: { | ||
name: 'analyzePrompt', | ||
arguments: JSON.stringify({ | ||
invalidKey: 'invalidValue', | ||
}), | ||
}, | ||
}], | ||
}, | ||
}], | ||
}; | ||
|
||
mockCreateCompletion.mockResolvedValue(mockResponse); | ||
|
||
await expect(classifier.processRequest(inputText, chatHistory)) | ||
.rejects.toThrow('Tool input does not match expected structure'); | ||
}); | ||
|
||
it('should throw an error if API request fails', async () => { | ||
const errorMessage = 'API request failed'; | ||
mockCreateCompletion.mockRejectedValue(new Error(errorMessage)); | ||
|
||
await expect(classifier.processRequest(inputText, chatHistory)) | ||
.rejects.toThrow(errorMessage); | ||
}); | ||
}); | ||
}); |