Skip to content

Commit

Permalink
feat: Added Claude3 and Mistral models for AWS Bedrock (#159)
Browse files Browse the repository at this point in the history
* fix: changed Titan's TextToImage to support images

* ImageToText working.  needs to get refactored and cleaned

* feat: Added ImageToText abstractions and HuggingFace implementation.  also added example to HF sample

* fix: remove postgres tests from bedrock tests

* feat: Added ImageToTextGenerationChain

* fix: SageMaker customizable inputs and responses

* fix: removed ItemGroup

* fix: Fixed warnings.

* fix: Fixed lastes commit.

* fix: Fixed warnings.

* feat: Added Claude3 and Mistral models
  • Loading branch information
curlyfro authored Mar 4, 2024
1 parent f894138 commit 48b1e3d
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<PackageVersion Include="Anthropic" Version="0.3.1" />
<PackageVersion Include="Anyscale" Version="1.0.2" />
<PackageVersion Include="Aspose.PDF" Version="24.2.0" />
<PackageVersion Include="AWSSDK.BedrockRuntime" Version="3.7.301.41" />
<PackageVersion Include="AWSSDK.BedrockRuntime" Version="3.7.301.43" />
<PackageVersion Include="AWSSDK.Kendra" Version="3.7.300.52" />
<PackageVersion Include="AWSSDK.SageMakerRuntime" Version="3.7.301.37" />
<PackageVersion Include="Azure.AI.OpenAI" Version="1.0.0-beta.13" />
Expand Down
27 changes: 19 additions & 8 deletions src/Providers/Amazon.Bedrock/src/Chat/AnthropicClaudeChatModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ public override async Task<ChatResponse> GenerateAsync(
{
var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken).ConfigureAwait(false);

var generatedText = response?["completion"]?.GetValue<string>() ?? "";
var generatedText = response?["content"]?[0]?["text"]?.GetValue<string>() ?? "";

messages.Add(generatedText.AsAiMessage());
OnCompletedResponseGenerated(generatedText);
}

var usage = Usage.Empty with
{
Time = watch.Elapsed,
Expand All @@ -94,12 +94,23 @@ private static JsonObject CreateBodyJson(string prompt, BedrockChatSettings used
{
var bodyJson = new JsonObject
{
["prompt"] = prompt,
["max_tokens_to_sample"] = usedSettings.MaxTokens!.Value,
["temperature"] = usedSettings.Temperature!.Value,
["top_p"] = usedSettings.TopP!.Value,
["top_k"] = usedSettings.TopK!.Value,
["stop_sequences"] = new JsonArray("\n\nHuman:")
["anthropic_version"] = "bedrock-2023-05-31",
["max_tokens"] = usedSettings.MaxTokens!.Value,
["messages"] = new JsonArray
{
new JsonObject
{
["role"] = "user",
["content"] = new JsonArray
{
new JsonObject
{
["type"] = "text",
["text"] = prompt,
}
}
}
}
};
return bodyJson;
}
Expand Down
106 changes: 106 additions & 0 deletions src/Providers/Amazon.Bedrock/src/Chat/MistralInstructChatModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using System.Diagnostics;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using Amazon.BedrockRuntime.Model;
using LangChain.Providers.Amazon.Bedrock.Internal;

// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.Bedrock;

public abstract class MistralInstructChatModel(
BedrockProvider provider,
string id)
: ChatModel(id)
{
public override async Task<ChatResponse> GenerateAsync(
ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

var watch = Stopwatch.StartNew();
var prompt = request.Messages.ToSimplePrompt();
var messages = request.Messages.ToList();

var stringBuilder = new StringBuilder();

var usedSettings = BedrockChatSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.ChatSettings);

var bodyJson = CreateBodyJson(prompt, usedSettings);

if (usedSettings.UseStreaming == true)
{
var streamRequest = BedrockModelStreamRequest.Create(Id, bodyJson);
var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken).ConfigureAwait(false);

foreach (var payloadPart in response.Body)
{
var streamEvent = (PayloadPart)payloadPart;
var chunk = await JsonSerializer.DeserializeAsync<JsonObject>(streamEvent.Bytes, cancellationToken: cancellationToken)
.ConfigureAwait(false);
var delta = chunk?["outputText"]!.GetValue<string>();

OnPartialResponseGenerated(delta!);
stringBuilder.Append(delta);

var finished = chunk?["completionReason"]?.GetValue<string>();
if (finished?.ToUpperInvariant() == "FINISH")
{
OnCompletedResponseGenerated(stringBuilder.ToString());
}
}

OnPartialResponseGenerated(Environment.NewLine);
stringBuilder.Append(Environment.NewLine);

var newMessage = new Message(
Content: stringBuilder.ToString(),
Role: MessageRole.Ai);
messages.Add(newMessage);

OnCompletedResponseGenerated(newMessage.Content);
}
else
{
var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken)
.ConfigureAwait(false);

var generatedText = response?["outputs"]?[0]?["text"]?.GetValue<string>() ?? string.Empty;

messages.Add(generatedText.AsAiMessage());
OnCompletedResponseGenerated(generatedText);
}

var usage = Usage.Empty with
{
Time = watch.Elapsed,
};
AddUsage(usage);
provider.AddUsage(usage);

return new ChatResponse
{
Messages = messages,
UsedSettings = usedSettings,
Usage = usage,
};
}

private static JsonObject CreateBodyJson(string prompt, BedrockChatSettings usedSettings)
{
var bodyJson = new JsonObject
{
["prompt"] = prompt,
["max_tokens"] = usedSettings.MaxTokens!.Value,
["temperature"] = usedSettings.Temperature!.Value,
["top_p"] = usedSettings.TopP!.Value,
// ["top_k"] = usedSettings.TopK!.Value
};
return bodyJson;
}
}
6 changes: 5 additions & 1 deletion src/Providers/Amazon.Bedrock/src/Predefined/Anthropic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,8 @@ public class ClaudeV2Model(BedrockProvider provider)

/// <inheritdoc />
public class ClaudeV21Model(BedrockProvider provider)
: AnthropicClaudeChatModel(provider, id: "anthropic.claude-v2:1");
: AnthropicClaudeChatModel(provider, id: "anthropic.claude-v2:1");

/// <inheritdoc />
public class Claude3SonnetModel(BedrockProvider provider)
: AnthropicClaudeChatModel(provider, id: "anthropic.claude-3-sonnet-20240229-v1:0");
10 changes: 10 additions & 0 deletions src/Providers/Amazon.Bedrock/src/Predefined/Mistral.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.Bedrock.Predefined.Mistral;

/// <inheritdoc />
public class Mistral7BInstruct(BedrockProvider provider)
: MistralInstructChatModel(provider, id: "mistral.mistral-7b-instruct-v0:2");

/// <inheritdoc />
public class Mistral8x7BInstruct(BedrockProvider provider)
: MistralInstructChatModel(provider, id: "mistral.mixtral-8x7b-instruct-v0:1");
4 changes: 0 additions & 4 deletions src/Providers/Amazon.Bedrock/src/Predefined/Stability.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.Bedrock.Predefined.Stability;

/// <inheritdoc />
public class StableDiffusionExtraLargeV0Model(BedrockProvider provider)
: StableDiffusionTextToImageModel(provider, id: "stability.stable-diffusion-xl-v0");

/// <inheritdoc />
public class StableDiffusionExtraLargeV1Model(BedrockProvider provider)
: StableDiffusionTextToImageModel(provider, id: "stability.stable-diffusion-xl-v1");
10 changes: 6 additions & 4 deletions src/Providers/Amazon.Bedrock/test/BedrockTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Diagnostics;
using Amazon;
using LangChain.Chains.LLM;
using LangChain.Chains.Sequentials;
using LangChain.Databases;
Expand All @@ -11,6 +12,7 @@
using LangChain.Providers.Amazon.Bedrock.Predefined.Anthropic;
using LangChain.Providers.Amazon.Bedrock.Predefined.Cohere;
using LangChain.Providers.Amazon.Bedrock.Predefined.Meta;
using LangChain.Providers.Amazon.Bedrock.Predefined.Mistral;
using LangChain.Providers.Amazon.Bedrock.Predefined.Stability;
using LangChain.Schema;
using LangChain.Sources;
Expand All @@ -25,11 +27,11 @@ public class BedrockTests
[Test]
public async Task Chains()
{
var provider = new BedrockProvider();
var provider = new BedrockProvider(RegionEndpoint.USWest2);
//var llm = new Jurassic2MidModel(provider);
var llm = new ClaudeV21Model(provider);
//var modelId = "amazon.titan-text-express-v1";
// var modelId = "cohere.command-light-text-v14";
//var llm = new ClaudeV21Model(provider);
//var llm = new Mistral7BInstruct(provider);
var llm = new Claude3SonnetModel(provider);

var template = "What is a good name for a company that makes {product}?";
var prompt = new PromptTemplate(new PromptTemplateInput(template, new List<string>(1) { "product" }));
Expand Down

0 comments on commit 48b1e3d

Please sign in to comment.