From 07c99d0db505243b21e3b30989f1721c7024c209 Mon Sep 17 00:00:00 2001 From: Ketan Khare Date: Fri, 5 Jan 2024 15:54:25 +0530 Subject: [PATCH] Added support for Azure OpenAI provider --- LangChain.sln | 6 + .../LangChain.Samples.Azure.csproj | 1 + examples/LangChain.Samples.Azure/Program.cs | 17 ++- src/Directory.Packages.props | 1 + .../LangChain.Providers.Azure/AzureModel.cs | 36 ----- .../AzureOpenAIConfiguration.cs | 49 +++++++ .../AzureOpenAIModel.cs | 131 ++++++++++++++++++ .../Gpt35Turbo16KModel.cs | 18 --- .../Gpt35TurboModel.cs | 18 --- .../LangChain.Providers.Azure/Gpt4Model.cs | 18 --- .../LangChain.Providers.Azure.csproj | 3 +- 11 files changed, 198 insertions(+), 100 deletions(-) delete mode 100644 src/libs/Providers/LangChain.Providers.Azure/AzureModel.cs create mode 100644 src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs create mode 100644 src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs delete mode 100644 src/libs/Providers/LangChain.Providers.Azure/Gpt35Turbo16KModel.cs delete mode 100644 src/libs/Providers/LangChain.Providers.Azure/Gpt35TurboModel.cs delete mode 100644 src/libs/Providers/LangChain.Providers.Azure/Gpt4Model.cs diff --git a/LangChain.sln b/LangChain.sln index 8785322e..7237775d 100644 --- a/LangChain.sln +++ b/LangChain.sln @@ -162,6 +162,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Google" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Google.UnitTests", "src\tests\LangChain.Providers.Google.UnitTests\LangChain.Providers.Google.UnitTests.csproj", "{DEAFA0CB-462D-4D74-B16F-68FD83FE3858}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Azure", "src\libs\Providers\LangChain.Providers.Azure\LangChain.Providers.Azure.csproj", "{18F5AAB1-1750-41BD-B623-6339CA5754D9}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -372,6 +374,10 @@ Global {DEAFA0CB-462D-4D74-B16F-68FD83FE3858}.Debug|Any CPU.Build.0 = Debug|Any CPU {DEAFA0CB-462D-4D74-B16F-68FD83FE3858}.Release|Any CPU.ActiveCfg = Release|Any CPU {DEAFA0CB-462D-4D74-B16F-68FD83FE3858}.Release|Any CPU.Build.0 = Release|Any CPU + {18F5AAB1-1750-41BD-B623-6339CA5754D9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {18F5AAB1-1750-41BD-B623-6339CA5754D9}.Debug|Any CPU.Build.0 = Debug|Any CPU + {18F5AAB1-1750-41BD-B623-6339CA5754D9}.Release|Any CPU.ActiveCfg = Release|Any CPU + {18F5AAB1-1750-41BD-B623-6339CA5754D9}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/examples/LangChain.Samples.Azure/LangChain.Samples.Azure.csproj b/examples/LangChain.Samples.Azure/LangChain.Samples.Azure.csproj index e9040a22..398fa4f4 100644 --- a/examples/LangChain.Samples.Azure/LangChain.Samples.Azure.csproj +++ b/examples/LangChain.Samples.Azure/LangChain.Samples.Azure.csproj @@ -9,6 +9,7 @@ + \ No newline at end of file diff --git a/examples/LangChain.Samples.Azure/Program.cs b/examples/LangChain.Samples.Azure/Program.cs index 49b1b108..6dcbcef2 100644 --- a/examples/LangChain.Samples.Azure/Program.cs +++ b/examples/LangChain.Samples.Azure/Program.cs @@ -1,9 +1,8 @@ -// using LangChain.Providers; -// using LangChain.Providers.Azure; -// -// using var httpClient = new HttpClient(); -// var model = new Gpt35TurboModel("apiKey", "endpoint", new HttpClient()); -// var result = await model.GenerateAsync("What is a good name for a company that sells colourful socks?"); -// -// Console.WriteLine(result); -Console.WriteLine("Not implemented"); \ No newline at end of file +using LangChain.Providers; +using LangChain.Providers.Azure; + +var model = new AzureOpenAIModel("AZURE_OPEN_AI_KEY", "ENDPOINT", "DEPLOYMENT_NAME"); + +var result = await model.GenerateAsync("What is a good name for a company that sells colourful socks?"); + +Console.WriteLine(result); diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index de0b1a6e..012c9857 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -8,6 +8,7 @@ + diff --git a/src/libs/Providers/LangChain.Providers.Azure/AzureModel.cs b/src/libs/Providers/LangChain.Providers.Azure/AzureModel.cs deleted file mode 100644 index 77e4f269..00000000 --- a/src/libs/Providers/LangChain.Providers.Azure/AzureModel.cs +++ /dev/null @@ -1,36 +0,0 @@ -using LangChain.Providers.OpenAI; - -namespace LangChain.Providers.Azure; - -/// -/// -/// -public class AzureModel : OpenAiModel -{ - #region Constructors - - /// - /// Wrapper around Azure large language models. - /// - /// - /// - /// - public AzureModel(OpenAiConfiguration configuration, HttpClient httpClient) : base(configuration, httpClient) - { - } - - /// - /// Wrapper around Azure large language models. - /// - /// - /// - /// Specify the base server address without specifying a specific point, for example "https://myaccount.openai.azure.com/" - /// - /// - public AzureModel(string apiKey, string endpoint, HttpClient httpClient, string id) : base(apiKey, httpClient, id) - { - Api.BaseUrl = endpoint; - } - - #endregion -} \ No newline at end of file diff --git a/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs b/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs new file mode 100644 index 00000000..f14cccb2 --- /dev/null +++ b/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIConfiguration.cs @@ -0,0 +1,49 @@ +namespace LangChain.Providers +{ + /// + /// Configuration options for Azure OpenAI + /// + public class AzureOpenAIConfiguration + { + /// + /// Context size + /// How much tokens model will remember. + /// Most models have 2048 + /// + public int ContextSize { get; set; } = 2048; + + /// + /// Temperature + /// controls the apparent creativity of generated completions. + /// Has a valid range of 0.0 to 2.0 + /// Defaults to 1.0 if not otherwise specified. + /// + public float Temperature { get; set; } = 0.7f; + + /// + /// Gets the maximum number of tokens to generate. Has minimum of 0. + /// + public int MaxTokens { get; set; } = 800; + + /// + /// Number of choices that should be generated per provided prompt. + /// Has a valid range of 1 to 128. + /// + public int ChoiceCount { get; set; } = 1; + + /// + /// Azure OpenAI API Key + /// + public string? ApiKey { get; set; } + + /// + /// Deployment name + /// + public string Id { get; set; } + + /// + /// Azure OpenAI Resource URI + /// + public string Endpoint { get; set; } + } +} diff --git a/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs b/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs new file mode 100644 index 00000000..005c5175 --- /dev/null +++ b/src/libs/Providers/LangChain.Providers.Azure/AzureOpenAIModel.cs @@ -0,0 +1,131 @@ +using Azure; +using Azure.AI.OpenAI; +using System.Diagnostics; + +namespace LangChain.Providers.Azure; + +/// +/// Wrapper around Azure OpenAI large language models +/// +public class AzureOpenAIModel : IChatModel +{ + /// + /// Azure OpenAI API Key + /// + public string ApiKey { get; init; } + + /// + public Usage TotalUsage { get; private set; } + + /// + /// Deployment name + /// + public string Id { get; init; } + + /// + public int ContextLength => Configurations.ContextSize; + + /// + /// Azure OpenAI Resource URI + /// + public string Endpoint { get; set; } + + private AzureOpenAIConfiguration Configurations { get; } + + #region Constructors + /// + /// Wrapper around Azure OpenAI + /// + /// API Key + /// Azure Open AI Resource URI + /// Deployment Model name + /// + public AzureOpenAIModel(string apiKey, string endpoint, string id) + { + Configurations = new AzureOpenAIConfiguration(); + Id = id ?? throw new ArgumentNullException(nameof(id)); + ApiKey = apiKey ?? throw new ArgumentNullException(nameof(apiKey)); + Endpoint = endpoint ?? throw new ArgumentNullException(nameof(endpoint)); + } + + /// + /// Wrapper around Azure OpenAI + /// + /// AzureOpenAIConfiguration + /// + /// + public AzureOpenAIModel(AzureOpenAIConfiguration configuration) + { + Configurations = configuration ?? throw new ArgumentNullException(nameof(configuration)); + ApiKey = configuration.ApiKey ?? throw new ArgumentException("ApiKey is not defined", nameof(configuration)); + Id = configuration.Id ?? throw new ArgumentException("Deployment model Id is not defined", nameof(configuration)); + Endpoint = configuration.Endpoint ?? throw new ArgumentException("Endpoint is not defined", nameof(configuration)); + } + #endregion + + #region Methods + /// + public async Task GenerateAsync(ChatRequest request, CancellationToken cancellationToken = default) + { + var messages = request.Messages.ToList(); + var watch = Stopwatch.StartNew(); + var response = await CreateChatCompleteAsync(messages, cancellationToken).ConfigureAwait(false); + + messages.Add(ToMessage(response.Value)); + + watch.Stop(); + + var usage = GetUsage(response.Value.Usage) with + { + Time = watch.Elapsed, + }; + TotalUsage += usage; + + return new ChatResponse( + Messages: messages, + Usage: usage); + } + + private async Task> CreateChatCompleteAsync(IReadOnlyCollection messages, CancellationToken cancellationToken = default) + { + var chatCompletionOptions = new ChatCompletionsOptions(Id, messages.Select(ToRequestMessage)) + { + MaxTokens = Configurations.MaxTokens, + ChoiceCount = Configurations.ChoiceCount, + Temperature = Configurations.Temperature, + }; + + var client = new OpenAIClient(new Uri(Endpoint), new AzureKeyCredential(ApiKey)); + return await client.GetChatCompletionsAsync(chatCompletionOptions, cancellationToken).ConfigureAwait(false); + } + + private static ChatRequestMessage ToRequestMessage(Message message) + { + return message.Role switch + { + MessageRole.System => new ChatRequestSystemMessage(message.Content), + MessageRole.Ai => new ChatRequestAssistantMessage(message.Content), + MessageRole.Human => new ChatRequestUserMessage(message.Content), + MessageRole.FunctionCall => throw new NotImplementedException(), + MessageRole.FunctionResult => throw new NotImplementedException(), + _ => throw new NotImplementedException() + }; + } + + private static Message ToMessage(ChatCompletions message) + { + return new Message( + Content: message.Choices[0].Message.Content, + Role: MessageRole.Ai); + } + + private static Usage GetUsage(CompletionsUsage usage) + { + return Usage.Empty with + { + InputTokens = usage.PromptTokens, + OutputTokens = usage.CompletionTokens + }; + } + #endregion +} \ No newline at end of file diff --git a/src/libs/Providers/LangChain.Providers.Azure/Gpt35Turbo16KModel.cs b/src/libs/Providers/LangChain.Providers.Azure/Gpt35Turbo16KModel.cs deleted file mode 100644 index 5ad55686..00000000 --- a/src/libs/Providers/LangChain.Providers.Azure/Gpt35Turbo16KModel.cs +++ /dev/null @@ -1,18 +0,0 @@ -namespace LangChain.Providers.Azure; - -/// -public class Gpt35Turbo16KModel : AzureModel -{ - #region Constructors - - /// - /// - /// Specify the base server address without specifying a specific point, for example "https://myaccount.openai.azure.com/" - /// - /// - public Gpt35Turbo16KModel(string apiKey, string endpoint, HttpClient httpClient) : base(apiKey, endpoint, httpClient, id: ModelIds.Gpt35Turbo_16k) - { - } - - #endregion -} \ No newline at end of file diff --git a/src/libs/Providers/LangChain.Providers.Azure/Gpt35TurboModel.cs b/src/libs/Providers/LangChain.Providers.Azure/Gpt35TurboModel.cs deleted file mode 100644 index 8af7f9e0..00000000 --- a/src/libs/Providers/LangChain.Providers.Azure/Gpt35TurboModel.cs +++ /dev/null @@ -1,18 +0,0 @@ -namespace LangChain.Providers.Azure; - -/// -public class Gpt35TurboModel : AzureModel -{ - #region Constructors - - /// - /// - /// Specify the base server address without specifying a specific point, for example "https://myaccount.openai.azure.com/" - /// - /// - public Gpt35TurboModel(string apiKey, string endpoint, HttpClient httpClient) : base(apiKey, endpoint, httpClient, id: ModelIds.Gpt35Turbo) - { - } - - #endregion -} \ No newline at end of file diff --git a/src/libs/Providers/LangChain.Providers.Azure/Gpt4Model.cs b/src/libs/Providers/LangChain.Providers.Azure/Gpt4Model.cs deleted file mode 100644 index 6f976acb..00000000 --- a/src/libs/Providers/LangChain.Providers.Azure/Gpt4Model.cs +++ /dev/null @@ -1,18 +0,0 @@ -namespace LangChain.Providers.Azure; - -/// -public class Gpt4Model : AzureModel -{ - #region Constructors - - /// - /// - /// Specify the base server address without specifying a specific point, for example "https://myaccount.openai.azure.com/" - /// - /// - public Gpt4Model(string apiKey, string endpoint, HttpClient httpClient) : base(apiKey, endpoint, httpClient, id: ModelIds.Gpt4) - { - } - - #endregion -} \ No newline at end of file diff --git a/src/libs/Providers/LangChain.Providers.Azure/LangChain.Providers.Azure.csproj b/src/libs/Providers/LangChain.Providers.Azure/LangChain.Providers.Azure.csproj index 45c65e05..36b40cf9 100644 --- a/src/libs/Providers/LangChain.Providers.Azure/LangChain.Providers.Azure.csproj +++ b/src/libs/Providers/LangChain.Providers.Azure/LangChain.Providers.Azure.csproj @@ -1,7 +1,7 @@ - net4.6.2;netstandard2.0;net6.0;net7.0 + net4.6.2;netstandard2.0;net6.0;net7.0;net8.0 @@ -15,6 +15,7 @@ +