Skip to content

Commit

Permalink
Merge pull request #2867 from rockwotj/aws
Browse files Browse the repository at this point in the history
aws/ai: support bedrock embeddings
  • Loading branch information
Jeffail authored Sep 13, 2024
2 parents df6e57e + 1a82ef1 commit bb94bf7
Show file tree
Hide file tree
Showing 4 changed files with 378 additions and 36 deletions.
199 changes: 199 additions & 0 deletions docs/modules/components/pages/processors/aws_bedrock_embeddings.adoc
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
= aws_bedrock_embeddings
:type: processor
:status: experimental
:categories: ["AI"]



////
THIS FILE IS AUTOGENERATED!

To make changes, edit the corresponding source file under:

https://github.com/redpanda-data/connect/tree/main/internal/impl/<provider>.

And:

https://github.com/redpanda-data/connect/tree/main/cmd/tools/docs_gen/templates/plugin.adoc.tmpl
////
// © 2024 Redpanda Data Inc.
component_type_dropdown::[]
Computes vector embeddings on text, using the AWS Bedrock API.
Introduced in version 4.37.0.
[tabs]
======
Common::
+
--
```yml
# Common config fields, showing default values
label: ""
aws_bedrock_embeddings:
model: amazon.titan-embed-text-v1 # No default (required)
text: "" # No default (optional)
```
--
Advanced::
+
--
```yml
# All config fields, showing default values
label: ""
aws_bedrock_embeddings:
region: ""
endpoint: ""
credentials:
profile: ""
id: ""
secret: ""
token: ""
from_ec2_role: false
role: ""
role_external_id: ""
model: amazon.titan-embed-text-v1 # No default (required)
text: "" # No default (optional)
```
--
======
This processor sends text to your chosen large language model (LLM) and computes vector embeddings, using the AWS Bedrock API.
For more information, see the https://docs.aws.amazon.com/bedrock/latest/userguide[AWS Bedrock documentation^].
== Fields
=== `region`
The AWS region to target.
*Type*: `string`
*Default*: `""`
=== `endpoint`
Allows you to specify a custom endpoint for the AWS API.
*Type*: `string`
*Default*: `""`
=== `credentials`
Optional manual configuration of AWS credentials to use. More information can be found in xref:guides:cloud/aws.adoc[].
*Type*: `object`
=== `credentials.profile`
A profile from `~/.aws/credentials` to use.
*Type*: `string`
*Default*: `""`
=== `credentials.id`
The ID of credentials to use.
*Type*: `string`
*Default*: `""`
=== `credentials.secret`
The secret for the credentials being used.
[CAUTION]
====
This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info].
====
*Type*: `string`
*Default*: `""`
=== `credentials.token`
The token for the credentials being used, required when using short term credentials.
*Type*: `string`
*Default*: `""`
=== `credentials.from_ec2_role`
Use the credentials of a host EC2 machine configured to assume https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_use_switch-role-ec2.html[an IAM role associated with the instance^].
*Type*: `bool`
*Default*: `false`
Requires version 4.2.0 or newer
=== `credentials.role`
A role ARN to assume.
*Type*: `string`
*Default*: `""`
=== `credentials.role_external_id`
An external ID to provide when assuming a role.
*Type*: `string`
*Default*: `""`
=== `model`
The model ID to use. For a full list see the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html[AWS Bedrock documentation^].
*Type*: `string`
```yml
# Examples
model: amazon.titan-embed-text-v1
model: amazon.titan-embed-text-v2:0
model: cohere.embed-english-v3
model: cohere.embed-multilingual-v3
```
=== `text`
The text you want to compute vector embeddings for. By default, the processor submits the entire payload as a string.
*Type*: `string`
Original file line number Diff line number Diff line change
Expand Up @@ -23,111 +23,111 @@ import (
)

// Config field names for the aws_bedrock_chat processor.
const (
	bedcpFieldModel        = "model"
	bedcpFieldUserPrompt   = "prompt"
	bedcpFieldSystemPrompt = "system_prompt"
	bedcpFieldMaxTokens    = "max_tokens"
	// FIX: the "temperature" and "stop" values were swapped, so the
	// temperature lookup read the `stop` config field and vice versa.
	bedcpFieldTemp = "temperature"
	bedcpFieldStop = "stop"
	bedcpFieldTopP = "top_p"
)

// init registers the aws_bedrock_chat processor with the service registry;
// registration failure is a programming error, so it panics.
func init() {
	if err := service.RegisterProcessor("aws_bedrock_chat", newBedrockChatConfigSpec(), newBedrockChatProcessor); err != nil {
		panic(err)
	}
}

func newBedrockConfigSpec() *service.ConfigSpec {
func newBedrockChatConfigSpec() *service.ConfigSpec {
return service.NewConfigSpec().
Summary("Generates responses to messages in a chat conversation, using the AWS Bedrock API.").
Description(`This processor sends prompts to your chosen large language model (LLM) and generates text from the responses, using the AWS Bedrock API.
For more information, see the https://docs.aws.amazon.com/bedrock/latest/userguide[AWS Bedrock documentation^].`).
Categories("AI").
Version("4.34.0").
Fields(config.SessionFields()...).
Field(service.NewStringField(bedpFieldModel).
Field(service.NewStringField(bedcpFieldModel).
Examples("amazon.titan-text-express-v1", "anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-text-v14", "meta.llama3-1-70b-instruct-v1:0", "mistral.mistral-large-2402-v1:0").
Description("The model ID to use. For a full list see the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html[AWS Bedrock documentation^].")).
Field(service.NewStringField(bedpFieldUserPrompt).
Field(service.NewStringField(bedcpFieldUserPrompt).
Description("The prompt you want to generate a response for. By default, the processor submits the entire payload as a string.").
Optional()).
Field(service.NewStringField(bedpFieldSystemPrompt).
Field(service.NewStringField(bedcpFieldSystemPrompt).
Optional().
Description("The system prompt to submit to the AWS Bedrock LLM.")).
Field(service.NewIntField(bedpFieldMaxTokens).
Field(service.NewIntField(bedcpFieldMaxTokens).
Optional().
Description("The maximum number of tokens to allow in the generated response.").
LintRule(`root = this < 1 { ["field must be greater than or equal to 1"] }`)).
Field(service.NewFloatField(bedpFieldTemp).
Field(service.NewFloatField(bedcpFieldTemp).
Optional().
Description("The likelihood of the model selecting higher-probability options while generating a response. A lower value makes the model omre likely to choose higher-probability options, while a higher value makes the model more likely to choose lower-probability options.").
LintRule(`root = if this < 0 || this > 1 { ["field must be between 0.0-1.0"] }`)).
Field(service.NewStringListField(bedpFieldStop).
Field(service.NewStringListField(bedcpFieldStop).
Optional().
Advanced().
Description("A list of stop sequences. A stop sequence is a sequence of characters that causes the model to stop generating the response.")).
Field(service.NewFloatField(bedpFieldTopP).
Field(service.NewFloatField(bedcpFieldTopP).
Optional().
Advanced().
Description("The percentage of most-likely candidates that the model considers for the next token. For example, if you choose a value of 0.8, the model selects from the top 80% of the probability distribution of tokens that could be next in the sequence. ").
LintRule(`root = if this < 0 || this > 1 { ["field must be between 0.0-1.0"] }`))
}

func newBedrockProcessor(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) {
func newBedrockChatProcessor(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) {
aconf, err := aws.GetSession(context.Background(), conf)
if err != nil {
return nil, err
}
client := bedrockruntime.NewFromConfig(aconf)
model, err := conf.FieldString(bedpFieldModel)
model, err := conf.FieldString(bedcpFieldModel)
if err != nil {
return nil, err
}
p := &bedrockProcessor{
p := &bedrockChatProcessor{
client: client,
model: model,
}
if conf.Contains(bedpFieldUserPrompt) {
pf, err := conf.FieldInterpolatedString(bedpFieldUserPrompt)
if conf.Contains(bedcpFieldUserPrompt) {
pf, err := conf.FieldInterpolatedString(bedcpFieldUserPrompt)
if err != nil {
return nil, err
}
p.userPrompt = pf
}
if conf.Contains(bedpFieldSystemPrompt) {
pf, err := conf.FieldInterpolatedString(bedpFieldSystemPrompt)
if conf.Contains(bedcpFieldSystemPrompt) {
pf, err := conf.FieldInterpolatedString(bedcpFieldSystemPrompt)
if err != nil {
return nil, err
}
p.systemPrompt = pf
}
if conf.Contains(bedpFieldMaxTokens) {
v, err := conf.FieldInt(bedpFieldMaxTokens)
if conf.Contains(bedcpFieldMaxTokens) {
v, err := conf.FieldInt(bedcpFieldMaxTokens)
if err != nil {
return nil, err
}
mt := int32(v)
p.maxTokens = &mt
}
if conf.Contains(bedpFieldTemp) {
v, err := conf.FieldFloat(bedpFieldTemp)
if conf.Contains(bedcpFieldTemp) {
v, err := conf.FieldFloat(bedcpFieldTemp)
if err != nil {
return nil, err
}
t := float32(v)
p.temp = &t
}
if conf.Contains(bedpFieldStop) {
stop, err := conf.FieldStringList(bedpFieldStop)
if conf.Contains(bedcpFieldStop) {
stop, err := conf.FieldStringList(bedcpFieldStop)
if err != nil {
return nil, err
}
p.stop = stop
}
if conf.Contains(bedpFieldTopP) {
v, err := conf.FieldFloat(bedpFieldTopP)
if conf.Contains(bedcpFieldTopP) {
v, err := conf.FieldFloat(bedcpFieldTopP)
if err != nil {
return nil, err
}
Expand All @@ -137,7 +137,7 @@ func newBedrockProcessor(conf *service.ParsedConfig, mgr *service.Resources) (se
return p, nil
}

type bedrockProcessor struct {
type bedrockChatProcessor struct {
client *bedrockruntime.Client
model string

Expand All @@ -149,7 +149,7 @@ type bedrockProcessor struct {
topP *float32
}

func (b *bedrockProcessor) Process(ctx context.Context, msg *service.Message) (service.MessageBatch, error) {
func (b *bedrockChatProcessor) Process(ctx context.Context, msg *service.Message) (service.MessageBatch, error) {
prompt, err := b.computePrompt(msg)
if err != nil {
return nil, err
Expand All @@ -176,7 +176,7 @@ func (b *bedrockProcessor) Process(ctx context.Context, msg *service.Message) (s
if b.systemPrompt != nil {
prompt, err := b.systemPrompt.TryString(msg)
if err != nil {
return nil, fmt.Errorf("unable to interpolate `%s`: %w", bedpFieldSystemPrompt, err)
return nil, fmt.Errorf("unable to interpolate `%s`: %w", bedcpFieldSystemPrompt, err)
}
input.System = []bedrocktypes.SystemContentBlock{
&bedrocktypes.SystemContentBlockMemberText{Value: prompt},
Expand Down Expand Up @@ -204,7 +204,7 @@ func (b *bedrockProcessor) Process(ctx context.Context, msg *service.Message) (s
return service.MessageBatch{out}, nil
}

func (b *bedrockProcessor) computePrompt(msg *service.Message) (string, error) {
func (b *bedrockChatProcessor) computePrompt(msg *service.Message) (string, error) {
if b.userPrompt != nil {
return b.userPrompt.TryString(msg)
}
Expand All @@ -218,6 +218,6 @@ func (b *bedrockProcessor) computePrompt(msg *service.Message) (string, error) {
return string(buf), nil
}

func (b *bedrockProcessor) Close(ctx context.Context) error {
func (b *bedrockChatProcessor) Close(ctx context.Context) error {
return nil
}
Loading

0 comments on commit bb94bf7

Please sign in to comment.