From 9887348c42e1f35abf2f1abc9d928ed73e0dbf4d Mon Sep 17 00:00:00 2001
From: h3l
Date: Fri, 29 Dec 2023 07:17:37 +0800
Subject: [PATCH] support callback handler to llm chain (#461)

* support callback handler to llm chain

* fix golint

---------

Co-authored-by: lvzhong
---
 agents/conversational.go |  6 +++++-
 agents/mrkl.go           |  6 +++++-
 chains/llm.go            | 18 +++++++++++-------
 3 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/agents/conversational.go b/agents/conversational.go
index 9daa779d2..1a493b7d1 100644
--- a/agents/conversational.go
+++ b/agents/conversational.go
@@ -45,7 +45,11 @@ func NewConversationalAgent(llm llms.LanguageModel, tools []tools.Tool, opts ...
 	}
 
 	return &ConversationalAgent{
-		Chain:            chains.NewLLMChain(llm, options.getConversationalPrompt(tools)),
+		Chain: chains.NewLLMChain(
+			llm,
+			options.getConversationalPrompt(tools),
+			chains.WithCallback(options.callbacksHandler),
+		),
 		Tools:            tools,
 		OutputKey:        options.outputKey,
 		CallbacksHandler: options.callbacksHandler,
diff --git a/agents/mrkl.go b/agents/mrkl.go
index 574e8d209..8bdcb4bc8 100644
--- a/agents/mrkl.go
+++ b/agents/mrkl.go
@@ -48,7 +48,11 @@ func NewOneShotAgent(llm llms.LanguageModel, tools []tools.Tool, opts ...Creatio
 	}
 
 	return &OneShotZeroAgent{
-		Chain:            chains.NewLLMChain(llm, options.getMrklPrompt(tools)),
+		Chain: chains.NewLLMChain(
+			llm,
+			options.getMrklPrompt(tools),
+			chains.WithCallback(options.callbacksHandler),
+		),
 		Tools:            tools,
 		OutputKey:        options.outputKey,
 		CallbacksHandler: options.callbacksHandler,
diff --git a/chains/llm.go b/chains/llm.go
index b70c90514..74efae9c4 100644
--- a/chains/llm.go
+++ b/chains/llm.go
@@ -29,14 +29,18 @@ var (
 )
 
 // NewLLMChain creates a new LLMChain with an LLM and a prompt.
-func NewLLMChain(llm llms.LanguageModel, prompt prompts.FormatPrompter) *LLMChain {
+func NewLLMChain(llm llms.LanguageModel, prompt prompts.FormatPrompter, opts ...ChainCallOption) *LLMChain {
+	opt := &chainCallOption{}
+	for _, o := range opts {
+		o(opt)
+	}
 	chain := &LLMChain{
-		Prompt:       prompt,
-		LLM:          llm,
-		OutputParser: outputparser.NewSimple(),
-		Memory:       memory.NewSimple(),
-
-		OutputKey: _llmChainDefaultOutputKey,
+		Prompt:           prompt,
+		LLM:              llm,
+		OutputParser:     outputparser.NewSimple(),
+		Memory:           memory.NewSimple(),
+		OutputKey:        _llmChainDefaultOutputKey,
+		CallbacksHandler: opt.CallbackHandler,
 	}
 
 	return chain
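
Usage sketch (illustrative only, not part of the patch): with the variadic ChainCallOption parameter added to NewLLMChain above, a caller can attach a callbacks handler at construction time instead of assigning LLMChain.CallbacksHandler afterwards. The snippet below assumes the tmc/langchaingo import paths, an OPENAI_API_KEY in the environment, and uses callbacks.LogHandler purely as an example; any callbacks.Handler implementation works.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/callbacks"
	"github.com/tmc/langchaingo/chains"
	"github.com/tmc/langchaingo/llms/openai"
	"github.com/tmc/langchaingo/prompts"
)

func main() {
	// Assumes OPENAI_API_KEY is set; any llms.LanguageModel works here.
	llm, err := openai.New()
	if err != nil {
		log.Fatal(err)
	}

	prompt := prompts.NewPromptTemplate(
		"Write a one-line summary of: {{.text}}",
		[]string{"text"},
	)

	// With this patch, the handler is passed as an option to NewLLMChain
	// via chains.WithCallback rather than being set on the struct by hand.
	chain := chains.NewLLMChain(llm, prompt, chains.WithCallback(callbacks.LogHandler{}))

	out, err := chains.Run(context.Background(), chain, "Go makes concurrency pleasant.")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}

The agent-side hunks mirror this: a callbacks handler supplied when creating a ConversationalAgent or OneShotZeroAgent is now forwarded to the inner LLMChain through chains.WithCallback, so chain-level callbacks fire with the same handler as the agent's own.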