Skip to content

Commit

Permalink
feat(azure): batch api (#839)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot authored May 10, 2024
1 parent 6e556d9 commit e279f8c
Show file tree
Hide file tree
Showing 2 changed files with 279 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ export interface AzureClientOptions extends ClientOptions {
/** API Client for interfacing with the Azure OpenAI API. */
export class AzureOpenAI extends OpenAI {
private _azureADTokenProvider: (() => Promise<string>) | undefined;
private _deployment: string | undefined;
apiVersion: string = '';
/**
* API Client for interfacing with the Azure OpenAI API.
Expand Down Expand Up @@ -412,11 +413,7 @@ export class AzureOpenAI extends OpenAI {
);
}

if (deployment) {
baseURL = `${endpoint}/openai/deployments/${deployment}`;
} else {
baseURL = `${endpoint}/openai`;
}
baseURL = `${endpoint}/openai`;
} else {
if (endpoint) {
throw new Errors.OpenAIError('baseURL and endpoint are mutually exclusive');
Expand All @@ -432,6 +429,7 @@ export class AzureOpenAI extends OpenAI {

this._azureADTokenProvider = azureADTokenProvider;
this.apiVersion = apiVersion;
this._deployment = deployment;
}

override buildRequest(options: Core.FinalRequestOptions<unknown>): {
Expand All @@ -443,7 +441,7 @@ export class AzureOpenAI extends OpenAI {
if (!Core.isObj(options.body)) {
throw new Error('Expected request body to be an object');
}
const model = options.body['model'];
const model = this._deployment || options.body['model'];
delete options.body['model'];
if (model !== undefined && !this.baseURL.includes('/deployments')) {
options.path = `/deployments/${model}${options.path}`;
Expand Down Expand Up @@ -494,6 +492,7 @@ const _deployments_endpoints = new Set([
'/audio/translations',
'/audio/speech',
'/images/generations',
'/batches',
]);

const API_KEY_SENTINEL = '<Missing Key>';
Expand Down
274 changes: 274 additions & 0 deletions tests/lib/azure.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import { Headers } from 'openai/core';
import defaultFetch, { Response, type RequestInit, type RequestInfo } from 'node-fetch';

const apiVersion = '2024-02-15-preview';
const deployment = 'deployment';
const model = 'unused model';

describe('instantiate azure client', () => {
const env = process.env;
Expand Down Expand Up @@ -275,6 +277,278 @@ describe('instantiate azure client', () => {
describe('azure request building', () => {
const client = new AzureOpenAI({ baseURL: 'https://example.com', apiKey: 'My API Key', apiVersion });

describe('model to deployment mapping', function () {
const testFetch = async (url: RequestInfo): Promise<Response> => {
return new Response(JSON.stringify({ url }), { headers: { 'content-type': 'application/json' } });
};
describe('with client-level deployment', function () {
const client = new AzureOpenAI({
endpoint: 'https://example.com',
apiKey: 'My API Key',
apiVersion,
deployment,
fetch: testFetch,
});

test('handles Batch', async () => {
expect(
await client.batches.create({
completion_window: '24h',
endpoint: '/v1/chat/completions',
input_file_id: 'file-id',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/batches?api-version=${apiVersion}`,
});
});

test('handles completions', async () => {
expect(
await client.completions.create({
model,
prompt: 'prompt',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/completions?api-version=${apiVersion}`,
});
});

test('handles chat completions', async () => {
expect(
await client.chat.completions.create({
model,
messages: [{ role: 'system', content: 'Hello' }],
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/chat/completions?api-version=${apiVersion}`,
});
});

test('handles embeddings', async () => {
expect(
await client.embeddings.create({
model,
input: 'input',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/embeddings?api-version=${apiVersion}`,
});
});

test('handles audio translations', async () => {
expect(
await client.audio.translations.create({
model,
file: { url: 'https://example.com', blob: () => 0 as any },
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/audio/translations?api-version=${apiVersion}`,
});
});

test('handles audio transcriptions', async () => {
expect(
await client.audio.transcriptions.create({
model,
file: { url: 'https://example.com', blob: () => 0 as any },
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/audio/transcriptions?api-version=${apiVersion}`,
});
});

test('handles text to speech', async () => {
expect(
await (
await client.audio.speech.create({
model,
input: '',
voice: 'alloy',
})
).json(),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/audio/speech?api-version=${apiVersion}`,
});
});

test('handles image generation', async () => {
expect(
await client.images.generate({
model,
prompt: 'prompt',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/images/generations?api-version=${apiVersion}`,
});
});

test('handles assistants', async () => {
expect(
await client.beta.assistants.create({
model,
}),
).toStrictEqual({
url: `https://example.com/openai/assistants?api-version=${apiVersion}`,
});
});

test('handles files', async () => {
expect(
await client.files.create({
file: { url: 'https://example.com', blob: () => 0 as any },
purpose: 'assistants',
}),
).toStrictEqual({
url: `https://example.com/openai/files?api-version=${apiVersion}`,
});
});

test('handles fine tuning', async () => {
expect(
await client.fineTuning.jobs.create({
model,
training_file: '',
}),
).toStrictEqual({
url: `https://example.com/openai/fine_tuning/jobs?api-version=${apiVersion}`,
});
});
});

describe('with no client-level deployment', function () {
const client = new AzureOpenAI({
endpoint: 'https://example.com',
apiKey: 'My API Key',
apiVersion,
fetch: testFetch,
});

test('Batch is not handled', async () => {
expect(
await client.batches.create({
completion_window: '24h',
endpoint: '/v1/chat/completions',
input_file_id: 'file-id',
}),
).toStrictEqual({
url: `https://example.com/openai/batches?api-version=${apiVersion}`,
});
});

test('handles completions', async () => {
expect(
await client.completions.create({
model: deployment,
prompt: 'prompt',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/completions?api-version=${apiVersion}`,
});
});

test('handles chat completions', async () => {
expect(
await client.chat.completions.create({
model: deployment,
messages: [{ role: 'system', content: 'Hello' }],
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/chat/completions?api-version=${apiVersion}`,
});
});

test('handles embeddings', async () => {
expect(
await client.embeddings.create({
model: deployment,
input: 'input',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/embeddings?api-version=${apiVersion}`,
});
});

test('Audio translations is not handled', async () => {
expect(
await client.audio.translations.create({
model: deployment,
file: { url: 'https://example.com', blob: () => 0 as any },
}),
).toStrictEqual({
url: `https://example.com/openai/audio/translations?api-version=${apiVersion}`,
});
});

test('Audio transcriptions is not handled', async () => {
expect(
await client.audio.transcriptions.create({
model: deployment,
file: { url: 'https://example.com', blob: () => 0 as any },
}),
).toStrictEqual({
url: `https://example.com/openai/audio/transcriptions?api-version=${apiVersion}`,
});
});

test('handles text to speech', async () => {
expect(
await (
await client.audio.speech.create({
model: deployment,
input: '',
voice: 'alloy',
})
).json(),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/audio/speech?api-version=${apiVersion}`,
});
});

test('handles image generation', async () => {
expect(
await client.images.generate({
model: deployment,
prompt: 'prompt',
}),
).toStrictEqual({
url: `https://example.com/openai/deployments/${deployment}/images/generations?api-version=${apiVersion}`,
});
});

test('handles assistants', async () => {
expect(
await client.beta.assistants.create({
model,
}),
).toStrictEqual({
url: `https://example.com/openai/assistants?api-version=${apiVersion}`,
});
});

test('handles files', async () => {
expect(
await client.files.create({
file: { url: 'https://example.com', blob: () => 0 as any },
purpose: 'assistants',
}),
).toStrictEqual({
url: `https://example.com/openai/files?api-version=${apiVersion}`,
});
});

test('handles fine tuning', async () => {
expect(
await client.fineTuning.jobs.create({
model,
training_file: '',
}),
).toStrictEqual({
url: `https://example.com/openai/fine_tuning/jobs?api-version=${apiVersion}`,
});
});
});
});

describe('Content-Length', () => {
test('handles multi-byte characters', () => {
const { req } = client.buildRequest({ path: '/foo', method: 'post', body: { value: '—' } });
Expand Down

0 comments on commit e279f8c

Please sign in to comment.