From 48b1e3d571a92e4e9a8ba591932b6281917aa11f Mon Sep 17 00:00:00 2001 From: Ty Augustine Date: Mon, 4 Mar 2024 14:08:00 -0500 Subject: [PATCH] feat: Added Claude3 and Mistral models for AWS Bedrock (#159) * fix: changed Titan's TextToImage to support images * ImageToText working. needs to get refactored and cleaned * feat: Added ImageToText abstractions and HuggingFace implementation. also added example to HF sample * fix: remove postgres tests from bedrock tests * feat: Added ImageToTextGenerationChain * fix: SageMaker customizable inputs and responses * fix: removed ItemGroup * fix: Fixed warnings. * fix: Fixed lastes commit. * fix: Fixed warnings. * feat: Added Claude3 and Mistral models --- src/Directory.Packages.props | 2 +- .../src/Chat/AnthropicClaudeChatModel.cs | 27 +++-- .../src/Chat/MistralInstructChatModel.cs | 106 ++++++++++++++++++ .../src/Predefined/Anthropic.cs | 6 +- .../Amazon.Bedrock/src/Predefined/Mistral.cs | 10 ++ .../src/Predefined/Stability.cs | 4 - .../Amazon.Bedrock/test/BedrockTests.cs | 10 +- 7 files changed, 147 insertions(+), 18 deletions(-) create mode 100644 src/Providers/Amazon.Bedrock/src/Chat/MistralInstructChatModel.cs create mode 100644 src/Providers/Amazon.Bedrock/src/Predefined/Mistral.cs diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index e6e674a9..f5509918 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -7,7 +7,7 @@ - + diff --git a/src/Providers/Amazon.Bedrock/src/Chat/AnthropicClaudeChatModel.cs b/src/Providers/Amazon.Bedrock/src/Chat/AnthropicClaudeChatModel.cs index 6c8d1098..559d55cc 100644 --- a/src/Providers/Amazon.Bedrock/src/Chat/AnthropicClaudeChatModel.cs +++ b/src/Providers/Amazon.Bedrock/src/Chat/AnthropicClaudeChatModel.cs @@ -69,12 +69,12 @@ public override async Task GenerateAsync( { var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken).ConfigureAwait(false); - var generatedText = response?["completion"]?.GetValue() ?? ""; + var generatedText = response?["content"]?[0]?["text"]?.GetValue() ?? ""; messages.Add(generatedText.AsAiMessage()); OnCompletedResponseGenerated(generatedText); } - + var usage = Usage.Empty with { Time = watch.Elapsed, @@ -94,12 +94,23 @@ private static JsonObject CreateBodyJson(string prompt, BedrockChatSettings used { var bodyJson = new JsonObject { - ["prompt"] = prompt, - ["max_tokens_to_sample"] = usedSettings.MaxTokens!.Value, - ["temperature"] = usedSettings.Temperature!.Value, - ["top_p"] = usedSettings.TopP!.Value, - ["top_k"] = usedSettings.TopK!.Value, - ["stop_sequences"] = new JsonArray("\n\nHuman:") + ["anthropic_version"] = "bedrock-2023-05-31", + ["max_tokens"] = usedSettings.MaxTokens!.Value, + ["messages"] = new JsonArray + { + new JsonObject + { + ["role"] = "user", + ["content"] = new JsonArray + { + new JsonObject + { + ["type"] = "text", + ["text"] = prompt, + } + } + } + } }; return bodyJson; } diff --git a/src/Providers/Amazon.Bedrock/src/Chat/MistralInstructChatModel.cs b/src/Providers/Amazon.Bedrock/src/Chat/MistralInstructChatModel.cs new file mode 100644 index 00000000..f096ba7e --- /dev/null +++ b/src/Providers/Amazon.Bedrock/src/Chat/MistralInstructChatModel.cs @@ -0,0 +1,106 @@ +using System.Diagnostics; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; +using LangChain.Providers.Amazon.Bedrock.Internal; + +// ReSharper disable once CheckNamespace +namespace LangChain.Providers.Amazon.Bedrock; + +public abstract class MistralInstructChatModel( + BedrockProvider provider, + string id) + : ChatModel(id) +{ + public override async Task GenerateAsync( + ChatRequest request, + ChatSettings? settings = null, + CancellationToken cancellationToken = default) + { + request = request ?? throw new ArgumentNullException(nameof(request)); + + var watch = Stopwatch.StartNew(); + var prompt = request.Messages.ToSimplePrompt(); + var messages = request.Messages.ToList(); + + var stringBuilder = new StringBuilder(); + + var usedSettings = BedrockChatSettings.Calculate( + requestSettings: settings, + modelSettings: Settings, + providerSettings: provider.ChatSettings); + + var bodyJson = CreateBodyJson(prompt, usedSettings); + + if (usedSettings.UseStreaming == true) + { + var streamRequest = BedrockModelStreamRequest.Create(Id, bodyJson); + var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken).ConfigureAwait(false); + + foreach (var payloadPart in response.Body) + { + var streamEvent = (PayloadPart)payloadPart; + var chunk = await JsonSerializer.DeserializeAsync(streamEvent.Bytes, cancellationToken: cancellationToken) + .ConfigureAwait(false); + var delta = chunk?["outputText"]!.GetValue(); + + OnPartialResponseGenerated(delta!); + stringBuilder.Append(delta); + + var finished = chunk?["completionReason"]?.GetValue(); + if (finished?.ToUpperInvariant() == "FINISH") + { + OnCompletedResponseGenerated(stringBuilder.ToString()); + } + } + + OnPartialResponseGenerated(Environment.NewLine); + stringBuilder.Append(Environment.NewLine); + + var newMessage = new Message( + Content: stringBuilder.ToString(), + Role: MessageRole.Ai); + messages.Add(newMessage); + + OnCompletedResponseGenerated(newMessage.Content); + } + else + { + var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken) + .ConfigureAwait(false); + + var generatedText = response?["outputs"]?[0]?["text"]?.GetValue() ?? string.Empty; + + messages.Add(generatedText.AsAiMessage()); + OnCompletedResponseGenerated(generatedText); + } + + var usage = Usage.Empty with + { + Time = watch.Elapsed, + }; + AddUsage(usage); + provider.AddUsage(usage); + + return new ChatResponse + { + Messages = messages, + UsedSettings = usedSettings, + Usage = usage, + }; + } + + private static JsonObject CreateBodyJson(string prompt, BedrockChatSettings usedSettings) + { + var bodyJson = new JsonObject + { + ["prompt"] = prompt, + ["max_tokens"] = usedSettings.MaxTokens!.Value, + ["temperature"] = usedSettings.Temperature!.Value, + ["top_p"] = usedSettings.TopP!.Value, + // ["top_k"] = usedSettings.TopK!.Value + }; + return bodyJson; + } +} \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/Predefined/Anthropic.cs b/src/Providers/Amazon.Bedrock/src/Predefined/Anthropic.cs index ca3ff828..d646dc5b 100644 --- a/src/Providers/Amazon.Bedrock/src/Predefined/Anthropic.cs +++ b/src/Providers/Amazon.Bedrock/src/Predefined/Anthropic.cs @@ -15,4 +15,8 @@ public class ClaudeV2Model(BedrockProvider provider) /// public class ClaudeV21Model(BedrockProvider provider) - : AnthropicClaudeChatModel(provider, id: "anthropic.claude-v2:1"); \ No newline at end of file + : AnthropicClaudeChatModel(provider, id: "anthropic.claude-v2:1"); + +/// +public class Claude3SonnetModel(BedrockProvider provider) + : AnthropicClaudeChatModel(provider, id: "anthropic.claude-3-sonnet-20240229-v1:0"); \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/Predefined/Mistral.cs b/src/Providers/Amazon.Bedrock/src/Predefined/Mistral.cs new file mode 100644 index 00000000..5afe718a --- /dev/null +++ b/src/Providers/Amazon.Bedrock/src/Predefined/Mistral.cs @@ -0,0 +1,10 @@ +// ReSharper disable once CheckNamespace +namespace LangChain.Providers.Amazon.Bedrock.Predefined.Mistral; + +/// +public class Mistral7BInstruct(BedrockProvider provider) + : MistralInstructChatModel(provider, id: "mistral.mistral-7b-instruct-v0:2"); + +/// +public class Mistral8x7BInstruct(BedrockProvider provider) + : MistralInstructChatModel(provider, id: "mistral.mixtral-8x7b-instruct-v0:1"); \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/Predefined/Stability.cs b/src/Providers/Amazon.Bedrock/src/Predefined/Stability.cs index 3ec07d9e..65ee6083 100644 --- a/src/Providers/Amazon.Bedrock/src/Predefined/Stability.cs +++ b/src/Providers/Amazon.Bedrock/src/Predefined/Stability.cs @@ -1,10 +1,6 @@ // ReSharper disable once CheckNamespace namespace LangChain.Providers.Amazon.Bedrock.Predefined.Stability; -/// -public class StableDiffusionExtraLargeV0Model(BedrockProvider provider) - : StableDiffusionTextToImageModel(provider, id: "stability.stable-diffusion-xl-v0"); - /// public class StableDiffusionExtraLargeV1Model(BedrockProvider provider) : StableDiffusionTextToImageModel(provider, id: "stability.stable-diffusion-xl-v1"); \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/test/BedrockTests.cs b/src/Providers/Amazon.Bedrock/test/BedrockTests.cs index a8a9634f..c4b98a2f 100644 --- a/src/Providers/Amazon.Bedrock/test/BedrockTests.cs +++ b/src/Providers/Amazon.Bedrock/test/BedrockTests.cs @@ -1,4 +1,5 @@ using System.Diagnostics; +using Amazon; using LangChain.Chains.LLM; using LangChain.Chains.Sequentials; using LangChain.Databases; @@ -11,6 +12,7 @@ using LangChain.Providers.Amazon.Bedrock.Predefined.Anthropic; using LangChain.Providers.Amazon.Bedrock.Predefined.Cohere; using LangChain.Providers.Amazon.Bedrock.Predefined.Meta; +using LangChain.Providers.Amazon.Bedrock.Predefined.Mistral; using LangChain.Providers.Amazon.Bedrock.Predefined.Stability; using LangChain.Schema; using LangChain.Sources; @@ -25,11 +27,11 @@ public class BedrockTests [Test] public async Task Chains() { - var provider = new BedrockProvider(); + var provider = new BedrockProvider(RegionEndpoint.USWest2); //var llm = new Jurassic2MidModel(provider); - var llm = new ClaudeV21Model(provider); - //var modelId = "amazon.titan-text-express-v1"; - // var modelId = "cohere.command-light-text-v14"; + //var llm = new ClaudeV21Model(provider); + //var llm = new Mistral7BInstruct(provider); + var llm = new Claude3SonnetModel(provider); var template = "What is a good name for a company that makes {product}?"; var prompt = new PromptTemplate(new PromptTemplateInput(template, new List(1) { "product" }));