Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for Azure OpenAI provider #93

Merged
merged 1 commit into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions LangChain.sln
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Google"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Google.UnitTests", "src\tests\LangChain.Providers.Google.UnitTests\LangChain.Providers.Google.UnitTests.csproj", "{DEAFA0CB-462D-4D74-B16F-68FD83FE3858}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Azure", "src\libs\Providers\LangChain.Providers.Azure\LangChain.Providers.Azure.csproj", "{18F5AAB1-1750-41BD-B623-6339CA5754D9}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -372,6 +374,10 @@ Global
{DEAFA0CB-462D-4D74-B16F-68FD83FE3858}.Debug|Any CPU.Build.0 = Debug|Any CPU
{DEAFA0CB-462D-4D74-B16F-68FD83FE3858}.Release|Any CPU.ActiveCfg = Release|Any CPU
{DEAFA0CB-462D-4D74-B16F-68FD83FE3858}.Release|Any CPU.Build.0 = Release|Any CPU
{18F5AAB1-1750-41BD-B623-6339CA5754D9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{18F5AAB1-1750-41BD-B623-6339CA5754D9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{18F5AAB1-1750-41BD-B623-6339CA5754D9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{18F5AAB1-1750-41BD-B623-6339CA5754D9}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

<ItemGroup>
<ProjectReference Include="..\..\src\libs\LangChain\LangChain.csproj" />
<ProjectReference Include="..\..\src\libs\Providers\LangChain.Providers.Azure\LangChain.Providers.Azure.csproj" />
</ItemGroup>

</Project>
17 changes: 8 additions & 9 deletions examples/LangChain.Samples.Azure/Program.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
// using LangChain.Providers;
// using LangChain.Providers.Azure;
//
// using var httpClient = new HttpClient();
// var model = new Gpt35TurboModel("apiKey", "endpoint", new HttpClient());
// var result = await model.GenerateAsync("What is a good name for a company that sells colourful socks?");
//
// Console.WriteLine(result);
Console.WriteLine("Not implemented");
using LangChain.Providers;
using LangChain.Providers.Azure;

var model = new AzureOpenAIModel("AZURE_OPEN_AI_KEY", "ENDPOINT", "DEPLOYMENT_NAME");

var result = await model.GenerateAsync("What is a good name for a company that sells colourful socks?");

Console.WriteLine(result);
1 change: 1 addition & 0 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
<PackageVersion Include="Anyscale" Version="1.0.2" />
<PackageVersion Include="Aspose.PDF" Version="23.11.0" />
<PackageVersion Include="AWSSDK.Kendra" Version="3.7.300.5" />
<PackageVersion Include="Azure.AI.OpenAI" Version="1.0.0-beta.12" />
<PackageVersion Include="Docker.DotNet" Version="3.125.15" />
<PackageVersion Include="DotNet.ReproducibleBuilds" Version="1.1.1" />
<PackageVersion Include="FluentAssertions" Version="6.12.0" />
Expand Down
36 changes: 0 additions & 36 deletions src/libs/Providers/LangChain.Providers.Azure/AzureModel.cs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
namespace LangChain.Providers
{
/// <summary>
/// Configuration options for Azure OpenAI
/// </summary>
public class AzureOpenAIConfiguration
{
/// <summary>
/// Context size
/// How much tokens model will remember.
/// Most models have 2048
/// </summary>
public int ContextSize { get; set; } = 2048;

/// <summary>
/// Temperature
/// controls the apparent creativity of generated completions.
/// Has a valid range of 0.0 to 2.0
/// Defaults to 1.0 if not otherwise specified.
/// </summary>
public float Temperature { get; set; } = 0.7f;

/// <summary>
/// Gets the maximum number of tokens to generate. Has minimum of 0.
/// </summary>
public int MaxTokens { get; set; } = 800;

/// <summary>
/// Number of choices that should be generated per provided prompt.
/// Has a valid range of 1 to 128.
/// </summary>
public int ChoiceCount { get; set; } = 1;

/// <summary>
/// Azure OpenAI API Key
/// </summary>
public string? ApiKey { get; set; }

/// <summary>
/// Deployment name
/// </summary>
public string Id { get; set; }

Check warning on line 42 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Non-nullable property 'Id' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 42 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Non-nullable property 'Id' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 42 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Non-nullable property 'Id' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 42 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Non-nullable property 'Id' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 42 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Non-nullable property 'Id' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

/// <summary>
/// Azure OpenAI Resource URI
/// </summary>
public string Endpoint { get; set; }

Check warning on line 47 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Non-nullable property 'Endpoint' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 47 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Non-nullable property 'Endpoint' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 47 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Non-nullable property 'Endpoint' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 47 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Non-nullable property 'Endpoint' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.

Check warning on line 47 in src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Non-nullable property 'Endpoint' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.
}
}
131 changes: 131 additions & 0 deletions src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
using Azure;
using Azure.AI.OpenAI;
using System.Diagnostics;

namespace LangChain.Providers.Azure;

/// <summary>
/// Wrapper around Azure OpenAI large language models
/// </summary>
public class AzureOpenAIModel : IChatModel
{
/// <summary>
/// Azure OpenAI API Key
/// </summary>
public string ApiKey { get; init; }

/// <inheritdoc/>
public Usage TotalUsage { get; private set; }

/// <summary>
/// Deployment name
/// </summary>
public string Id { get; init; }

/// <inheritdoc/>
public int ContextLength => Configurations.ContextSize;

/// <summary>
/// Azure OpenAI Resource URI
/// </summary>
public string Endpoint { get; set; }

private AzureOpenAIConfiguration Configurations { get; }

#region Constructors
/// <summary>
/// Wrapper around Azure OpenAI
/// </summary>
/// <param name="apiKey">API Key</param>
/// <param name="endpoint">Azure Open AI Resource URI</param>
/// <param name="id">Deployment Model name</param>
/// <exception cref="ArgumentNullException"></exception>
public AzureOpenAIModel(string apiKey, string endpoint, string id)
{
Configurations = new AzureOpenAIConfiguration();
Id = id ?? throw new ArgumentNullException(nameof(id));
ApiKey = apiKey ?? throw new ArgumentNullException(nameof(apiKey));
Endpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint));
}

/// <summary>
/// Wrapper around Azure OpenAI
/// </summary>
/// <param name="configuration">AzureOpenAIConfiguration</param>
/// <exception cref="ArgumentNullException"></exception>
/// <exception cref="ArgumentException"></exception>
public AzureOpenAIModel(AzureOpenAIConfiguration configuration)
{
Configurations = configuration ?? throw new ArgumentNullException(nameof(configuration));
ApiKey = configuration.ApiKey ?? throw new ArgumentException("ApiKey is not defined", nameof(configuration));
Id = configuration.Id ?? throw new ArgumentException("Deployment model Id is not defined", nameof(configuration));
Endpoint = configuration.Endpoint ?? throw new ArgumentException("Endpoint is not defined", nameof(configuration));
}
#endregion

#region Methods
/// <inheritdoc/>
public async Task<ChatResponse> GenerateAsync(ChatRequest request, CancellationToken cancellationToken = default)
{
var messages = request.Messages.ToList();
var watch = Stopwatch.StartNew();
var response = await CreateChatCompleteAsync(messages, cancellationToken).ConfigureAwait(false);

messages.Add(ToMessage(response.Value));

watch.Stop();

var usage = GetUsage(response.Value.Usage) with
{
Time = watch.Elapsed,
};
TotalUsage += usage;

return new ChatResponse(
Messages: messages,
Usage: usage);
}

private async Task<Response<ChatCompletions>> CreateChatCompleteAsync(IReadOnlyCollection<Message> messages, CancellationToken cancellationToken = default)
{
var chatCompletionOptions = new ChatCompletionsOptions(Id, messages.Select(ToRequestMessage))
{
MaxTokens = Configurations.MaxTokens,
ChoiceCount = Configurations.ChoiceCount,
Temperature = Configurations.Temperature,
};

var client = new OpenAIClient(new Uri(Endpoint), new AzureKeyCredential(ApiKey));
return await client.GetChatCompletionsAsync(chatCompletionOptions, cancellationToken).ConfigureAwait(false);
}

private static ChatRequestMessage ToRequestMessage(Message message)
{
return message.Role switch
{
MessageRole.System => new ChatRequestSystemMessage(message.Content),
MessageRole.Ai => new ChatRequestAssistantMessage(message.Content),
MessageRole.Human => new ChatRequestUserMessage(message.Content),
MessageRole.FunctionCall => throw new NotImplementedException(),
MessageRole.FunctionResult => throw new NotImplementedException(),
_ => throw new NotImplementedException()
};
}

private static Message ToMessage(ChatCompletions message)
{
return new Message(
Content: message.Choices[0].Message.Content,
Role: MessageRole.Ai);
}

private static Usage GetUsage(CompletionsUsage usage)
{
return Usage.Empty with
{
InputTokens = usage.PromptTokens,
OutputTokens = usage.CompletionTokens
};
}
#endregion
}
18 changes: 0 additions & 18 deletions src/libs/Providers/LangChain.Providers.Azure/Gpt35Turbo16KModel.cs

This file was deleted.

18 changes: 0 additions & 18 deletions src/libs/Providers/LangChain.Providers.Azure/Gpt35TurboModel.cs

This file was deleted.

18 changes: 0 additions & 18 deletions src/libs/Providers/LangChain.Providers.Azure/Gpt4Model.cs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks>net4.6.2;netstandard2.0;net6.0;net7.0</TargetFrameworks>
<TargetFrameworks>net4.6.2;netstandard2.0;net6.0;net7.0;net8.0</TargetFrameworks>
</PropertyGroup>

<ItemGroup Label="Usings">
Expand All @@ -15,6 +15,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" />
<PackageReference Include="tryAGI.OpenAI" />
</ItemGroup>

Expand Down