-
-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Added Claude3 and Mistral models for AWS Bedrock (#159)
* fix: changed Titan's TextToImage to support images * ImageToText working. needs to get refactored and cleaned * feat: Added ImageToText abstractions and HuggingFace implementation. also added example to HF sample * fix: remove postgres tests from bedrock tests * feat: Added ImageToTextGenerationChain * fix: SageMaker customizable inputs and responses * fix: removed ItemGroup * fix: Fixed warnings. * fix: Fixed lastes commit. * fix: Fixed warnings. * feat: Added Claude3 and Mistral models
- Loading branch information
Showing
7 changed files
with
147 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
106 changes: 106 additions & 0 deletions
106
src/Providers/Amazon.Bedrock/src/Chat/MistralInstructChatModel.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
using System.Diagnostics; | ||
using System.Text; | ||
using System.Text.Json; | ||
using System.Text.Json.Nodes; | ||
using Amazon.BedrockRuntime.Model; | ||
using LangChain.Providers.Amazon.Bedrock.Internal; | ||
|
||
// ReSharper disable once CheckNamespace | ||
namespace LangChain.Providers.Amazon.Bedrock; | ||
|
||
public abstract class MistralInstructChatModel( | ||
BedrockProvider provider, | ||
string id) | ||
: ChatModel(id) | ||
{ | ||
public override async Task<ChatResponse> GenerateAsync( | ||
ChatRequest request, | ||
ChatSettings? settings = null, | ||
CancellationToken cancellationToken = default) | ||
{ | ||
request = request ?? throw new ArgumentNullException(nameof(request)); | ||
|
||
var watch = Stopwatch.StartNew(); | ||
var prompt = request.Messages.ToSimplePrompt(); | ||
var messages = request.Messages.ToList(); | ||
|
||
var stringBuilder = new StringBuilder(); | ||
|
||
var usedSettings = BedrockChatSettings.Calculate( | ||
requestSettings: settings, | ||
modelSettings: Settings, | ||
providerSettings: provider.ChatSettings); | ||
|
||
var bodyJson = CreateBodyJson(prompt, usedSettings); | ||
|
||
if (usedSettings.UseStreaming == true) | ||
{ | ||
var streamRequest = BedrockModelStreamRequest.Create(Id, bodyJson); | ||
var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken).ConfigureAwait(false); | ||
|
||
foreach (var payloadPart in response.Body) | ||
{ | ||
var streamEvent = (PayloadPart)payloadPart; | ||
var chunk = await JsonSerializer.DeserializeAsync<JsonObject>(streamEvent.Bytes, cancellationToken: cancellationToken) | ||
.ConfigureAwait(false); | ||
var delta = chunk?["outputText"]!.GetValue<string>(); | ||
|
||
OnPartialResponseGenerated(delta!); | ||
stringBuilder.Append(delta); | ||
|
||
var finished = chunk?["completionReason"]?.GetValue<string>(); | ||
if (finished?.ToUpperInvariant() == "FINISH") | ||
{ | ||
OnCompletedResponseGenerated(stringBuilder.ToString()); | ||
} | ||
} | ||
|
||
OnPartialResponseGenerated(Environment.NewLine); | ||
stringBuilder.Append(Environment.NewLine); | ||
|
||
var newMessage = new Message( | ||
Content: stringBuilder.ToString(), | ||
Role: MessageRole.Ai); | ||
messages.Add(newMessage); | ||
|
||
OnCompletedResponseGenerated(newMessage.Content); | ||
} | ||
else | ||
{ | ||
var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken) | ||
.ConfigureAwait(false); | ||
|
||
var generatedText = response?["outputs"]?[0]?["text"]?.GetValue<string>() ?? string.Empty; | ||
|
||
messages.Add(generatedText.AsAiMessage()); | ||
OnCompletedResponseGenerated(generatedText); | ||
} | ||
|
||
var usage = Usage.Empty with | ||
{ | ||
Time = watch.Elapsed, | ||
}; | ||
AddUsage(usage); | ||
provider.AddUsage(usage); | ||
|
||
return new ChatResponse | ||
{ | ||
Messages = messages, | ||
UsedSettings = usedSettings, | ||
Usage = usage, | ||
}; | ||
} | ||
|
||
private static JsonObject CreateBodyJson(string prompt, BedrockChatSettings usedSettings) | ||
{ | ||
var bodyJson = new JsonObject | ||
{ | ||
["prompt"] = prompt, | ||
["max_tokens"] = usedSettings.MaxTokens!.Value, | ||
["temperature"] = usedSettings.Temperature!.Value, | ||
["top_p"] = usedSettings.TopP!.Value, | ||
// ["top_k"] = usedSettings.TopK!.Value | ||
}; | ||
return bodyJson; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// ReSharper disable once CheckNamespace | ||
namespace LangChain.Providers.Amazon.Bedrock.Predefined.Mistral; | ||
|
||
/// <inheritdoc /> | ||
public class Mistral7BInstruct(BedrockProvider provider) | ||
: MistralInstructChatModel(provider, id: "mistral.mistral-7b-instruct-v0:2"); | ||
|
||
/// <inheritdoc /> | ||
public class Mistral8x7BInstruct(BedrockProvider provider) | ||
: MistralInstructChatModel(provider, id: "mistral.mixtral-8x7b-instruct-v0:1"); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,6 @@ | ||
// ReSharper disable once CheckNamespace | ||
namespace LangChain.Providers.Amazon.Bedrock.Predefined.Stability; | ||
|
||
/// <inheritdoc /> | ||
public class StableDiffusionExtraLargeV0Model(BedrockProvider provider) | ||
: StableDiffusionTextToImageModel(provider, id: "stability.stable-diffusion-xl-v0"); | ||
|
||
/// <inheritdoc /> | ||
public class StableDiffusionExtraLargeV1Model(BedrockProvider provider) | ||
: StableDiffusionTextToImageModel(provider, id: "stability.stable-diffusion-xl-v1"); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters