Skip to content

Commit

Permalink
fix(tmc#353)(agents): add ChainCallOptions to Plan
Browse files Browse the repository at this point in the history
  • Loading branch information
disk0Dancer committed May 23, 2024
1 parent 4c509c0 commit dae44c3
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 6 deletions.
3 changes: 2 additions & 1 deletion agents/agents.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package agents
import (
"context"

"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/tools"
)
Expand All @@ -11,7 +12,7 @@ import (
type Agent interface {
// Plan Given an input and previous steps decide what to do next. Returns
// either actions or a finish.
Plan(ctx context.Context, intermediateSteps []schema.AgentStep, inputs map[string]string) ([]schema.AgentAction, *schema.AgentFinish, error) //nolint:lll
Plan(ctx context.Context, intermediateSteps []schema.AgentStep, inputs map[string]string, options ...chains.ChainCallOption) ([]schema.AgentAction, *schema.AgentFinish, error) //nolint:lll
GetInputKeys() []string
GetOutputKeys() []string
GetTools() []tools.Tool
Expand Down
7 changes: 5 additions & 2 deletions agents/conversational.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func (a *ConversationalAgent) Plan(
ctx context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
options ...chains.ChainCallOption,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
fullInputs := make(map[string]any, len(inputs))
for key, value := range inputs {
Expand All @@ -80,12 +81,14 @@ func (a *ConversationalAgent) Plan(
}
}

options = append(options, chains.WithStopWords([]string{"\nObservation:", "\n\tObservation:"}))
options = append(options, chains.WithStreamingFunc(stream))

output, err := chains.Predict(
ctx,
a.Chain,
fullInputs,
chains.WithStopWords([]string{"\nObservation:", "\n\tObservation:"}),
chains.WithStreamingFunc(stream),
options...,
)
if err != nil {
return nil, nil, err
Expand Down
3 changes: 2 additions & 1 deletion agents/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ func (e *Executor) doIteration( // nolint
steps []schema.AgentStep,
nameToTool map[string]tools.Tool,
inputs map[string]string,
options ...chains.ChainCallOption,
) ([]schema.AgentStep, map[string]any, error) {
actions, finish, err := e.Agent.Plan(ctx, steps, inputs)
actions, finish, err := e.Agent.Plan(ctx, steps, inputs, options...)
if errors.Is(err, ErrUnableToParseOutput) && e.ErrorHandler != nil {
formattedObservation := err.Error()
if e.ErrorHandler.Formatter != nil {
Expand Down
1 change: 1 addition & 0 deletions agents/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func (a *testAgent) Plan(
_ context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
_ ...chains.ChainCallOption,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
a.recordedIntermediateSteps = intermediateSteps
a.recordedInputs = inputs
Expand Down
7 changes: 5 additions & 2 deletions agents/mrkl.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func (a *OneShotZeroAgent) Plan(
ctx context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
options ...chains.ChainCallOption,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
fullInputs := make(map[string]any, len(inputs))
for key, value := range inputs {
Expand All @@ -82,12 +83,14 @@ func (a *OneShotZeroAgent) Plan(
}
}

options = append(options, chains.WithStopWords([]string{"\nObservation:", "\n\tObservation:"}))
options = append(options, chains.WithStreamingFunc(stream))

output, err := chains.Predict(
ctx,
a.Chain,
fullInputs,
chains.WithStopWords([]string{"\nObservation:", "\n\tObservation:"}),
chains.WithStreamingFunc(stream),
options...,
)
if err != nil {
return nil, nil, err
Expand Down
3 changes: 3 additions & 0 deletions agents/openai_functions_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/prompts"
"github.com/tmc/langchaingo/schema"
Expand Down Expand Up @@ -67,10 +68,12 @@ func (o *OpenAIFunctionsAgent) functions() []llms.FunctionDefinition {
}

// Plan decides what action to take or returns the final result of the input.
// TODO(fix): add {options ...chains.ChainCallOption} to llm request.
func (o *OpenAIFunctionsAgent) Plan(
ctx context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
_ ...chains.ChainCallOption,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
fullInputs := make(map[string]any, len(inputs))
for key, value := range inputs {
Expand Down

0 comments on commit dae44c3

Please sign in to comment.