From f30ef1dc4f5400355adc44ce817732e1b303731b Mon Sep 17 00:00:00 2001 From: Khoroshev Evgeniy Date: Fri, 1 Dec 2023 14:32:51 +0300 Subject: [PATCH] feat: ConversationalRetrievalChain (#85) Co-authored-by: Evgenii Khoroshev Closes #13 --- src/libs/LangChain.Core/Base/BaseChain.cs | 12 +- .../LangChain.Core/Callback/ICallbacks.cs | 5 + src/libs/LangChain.Core/Chains/Base/IChain.cs | 5 +- .../CombineDocuments/StuffDocumentsChain.cs | 18 +-- .../BaseConversationalRetrievalChain.cs | 113 ++++++++++++++ .../BaseConversationalRetrievalChainInput.cs | 63 ++++++++ .../ChatTurnTypeHelper.cs | 26 ++++ .../ConversationalRetrievalChain.cs | 68 +++++++++ .../ConversationalRetrievalChainInput.cs | 23 +++ .../StackableChains/BaseStackableChain.cs | 10 +- .../LangChain.Core/Prompts/PromptTemplate.cs | 50 +++--- .../Providers/Models/Message.cs | 3 + .../Retrievers/BaseRetriever.cs | 8 +- src/libs/LangChain.Core/Schema/ChainValues.cs | 6 + .../ConversationalRetrievalChainTests.cs | 144 ++++++++++++++++++ .../RetrievalQa/RetrievalQaChainTests.cs | 7 +- 16 files changed, 519 insertions(+), 42 deletions(-) create mode 100644 src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChain.cs create mode 100644 src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChainInput.cs create mode 100644 src/libs/LangChain.Core/Chains/ConversationalRetrieval/ChatTurnTypeHelper.cs create mode 100644 src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChain.cs create mode 100644 src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChainInput.cs create mode 100644 src/tests/LangChain.Core.UnitTests/Chains/ConversationalRetrieval/ConversationalRetrievalChainTests.cs diff --git a/src/libs/LangChain.Core/Base/BaseChain.cs b/src/libs/LangChain.Core/Base/BaseChain.cs index 0ce8fa27..2949ab99 100644 --- a/src/libs/LangChain.Core/Base/BaseChain.cs +++ b/src/libs/LangChain.Core/Base/BaseChain.cs @@ -63,8 +63,13 @@ public abstract class BaseChain(IChainInputs fields) : IChain /// Run the chain using a simple input/output. /// /// The dict input to use to execute the chain. + /// + /// Callbacks to use for this chain run. These will be called in + /// addition to callbacks passed to the chain during construction, but only + /// these runtime callbacks will propagate to calls to other objects. + /// /// A text value containing the result of the chain. - public virtual async Task Run(Dictionary input) + public virtual async Task Run(Dictionary input, ICallbacks? callbacks = null) { var keysLengthDifferent = InputKeys.Length != input.Count; @@ -73,7 +78,7 @@ public virtual async Task Run(Dictionary input) throw new ArgumentException($"Chain {ChainType()} expects {InputKeys.Length} but, received {input.Count}"); } - var returnValues = await CallAsync(new ChainValues(input)); + var returnValues = await CallAsync(new ChainValues(input), callbacks); var returnValue = returnValues.Value.FirstOrDefault(kv => kv.Key == OutputKeys[0]).Value; @@ -88,7 +93,8 @@ public virtual async Task Run(Dictionary input) /// /// /// - public async Task CallAsync(IChainValues values, + public async Task CallAsync( + IChainValues values, ICallbacks? callbacks = null, IReadOnlyList? tags = null, IReadOnlyDictionary? metadata = null) diff --git a/src/libs/LangChain.Core/Callback/ICallbacks.cs b/src/libs/LangChain.Core/Callback/ICallbacks.cs index 46e0c174..2faaa4b4 100644 --- a/src/libs/LangChain.Core/Callback/ICallbacks.cs +++ b/src/libs/LangChain.Core/Callback/ICallbacks.cs @@ -4,6 +4,11 @@ namespace LangChain.Callback; public interface ICallbacks; +public static class ManagerCallbacksExtensions +{ + public static ManagerCallbacks ToCallbacks(this ParentRunManager source) => new ManagerCallbacks(source.GetChild()); +} + public record ManagerCallbacks(CallbackManager Value) : ICallbacks; public record HandlersCallbacks(List Value) : ICallbacks; \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/Base/IChain.cs b/src/libs/LangChain.Core/Chains/Base/IChain.cs index 9e306546..6dbc058e 100644 --- a/src/libs/LangChain.Core/Chains/Base/IChain.cs +++ b/src/libs/LangChain.Core/Chains/Base/IChain.cs @@ -7,7 +7,10 @@ public interface IChain { string[] InputKeys { get; } string[] OutputKeys { get; } - + + Task Run(string input); + Task Run(Dictionary input, ICallbacks? callbacks = null); + Task CallAsync(IChainValues values, ICallbacks? callbacks = null, IReadOnlyList? tags = null, diff --git a/src/libs/LangChain.Core/Chains/CombineDocuments/StuffDocumentsChain.cs b/src/libs/LangChain.Core/Chains/CombineDocuments/StuffDocumentsChain.cs index ec15b762..5264fd9d 100644 --- a/src/libs/LangChain.Core/Chains/CombineDocuments/StuffDocumentsChain.cs +++ b/src/libs/LangChain.Core/Chains/CombineDocuments/StuffDocumentsChain.cs @@ -18,18 +18,18 @@ namespace LangChain.Chains.CombineDocuments; /// public class StuffDocumentsChain : BaseCombineDocumentsChain { - private readonly ILlmChain _llmChain; + public readonly ILlmChain LlmChain; private readonly BasePromptTemplate _documentPrompt; private readonly string _documentVariableName; - private readonly string _documentSeparator = "\n\n"; + private readonly string _documentSeparator; public StuffDocumentsChain(StuffDocumentsChainInput input) : base(input) { - _llmChain = input.LlmChain; + LlmChain = input.LlmChain; _documentPrompt = input.DocumentPrompt; _documentSeparator = input.DocumentSeparator; - var llmChainVariables = _llmChain.Prompt.InputVariables; + var llmChainVariables = LlmChain.Prompt.InputVariables; if (input.DocumentVariableName == null) { @@ -50,7 +50,7 @@ public StuffDocumentsChain(StuffDocumentsChainInput input) : base(input) } public override string[] InputKeys => - base.InputKeys.Concat(_llmChain.InputKeys.Where(k => k != _documentVariableName)).ToArray(); + base.InputKeys.Concat(LlmChain.InputKeys.Where(k => k != _documentVariableName)).ToArray(); public override string ChainType() => "stuff_documents_chain"; @@ -59,17 +59,17 @@ public StuffDocumentsChain(StuffDocumentsChainInput input) : base(input) IReadOnlyDictionary otherKeys) { var inputs = await GetInputs(docs, otherKeys); - var predict = await _llmChain.Predict(new ChainValues(inputs.Value)); + var predict = await LlmChain.Predict(new ChainValues(inputs.Value)); return (predict.ToString() ?? string.Empty, new Dictionary()); } public override async Task PromptLength(IReadOnlyList docs, IReadOnlyDictionary otherKeys) { - if (_llmChain.Llm is ISupportsCountTokens supportsCountTokens) + if (LlmChain.Llm is ISupportsCountTokens supportsCountTokens) { var inputs = await GetInputs(docs, otherKeys); - var prompt = await _llmChain.Prompt.FormatPromptValue(inputs); + var prompt = await LlmChain.Prompt.FormatPromptValue(inputs); return supportsCountTokens.CountTokens(prompt.ToString()); } @@ -84,7 +84,7 @@ private async Task GetInputs(IReadOnlyList docs, IReadOnl var inputs = new Dictionary(); foreach (var kv in otherKeys) { - if (_llmChain.Prompt.InputVariables.Contains(kv.Key)) + if (LlmChain.Prompt.InputVariables.Contains(kv.Key)) { inputs[kv.Key] = kv.Value; } diff --git a/src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChain.cs b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChain.cs new file mode 100644 index 00000000..022c3e1e --- /dev/null +++ b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChain.cs @@ -0,0 +1,113 @@ +using LangChain.Abstractions.Schema; +using LangChain.Base; +using LangChain.Callback; +using LangChain.Common; +using LangChain.Docstore; +using LangChain.Providers; +using LangChain.Schema; + +namespace LangChain.Chains.ConversationalRetrieval; + +/// +/// Chain for chatting with an index. +/// +public abstract class BaseConversationalRetrievalChain(BaseConversationalRetrievalChainInput fields) : BaseChain(fields) +{ + /// Chain input fields + private readonly BaseConversationalRetrievalChainInput _fields = fields; + + public override string[] InputKeys => new[] { "question", "chat_history" }; + + public override string[] OutputKeys + { + get + { + var outputKeys = new List { _fields.OutputKey }; + if (_fields.ReturnSourceDocuments) + { + outputKeys.Add("source_documents"); + } + + if (_fields.ReturnGeneratedQuestion) + { + outputKeys.Add("generated_question"); + } + + return outputKeys.ToArray(); + } + } + + protected override async Task CallAsync(IChainValues values, CallbackManagerForChainRun? runManager) + { + runManager ??= BaseRunManager.GetNoopManager(); + + var question = values.Value["question"].ToString(); + + var getChatHistory = _fields.GetChatHistory; + var chatHistoryStr = getChatHistory(values.Value["chat_history"] as List); + + string? newQuestion; + if (chatHistoryStr != null) + { + var callbacks = runManager.GetChild(); + newQuestion = await _fields.QuestionGenerator.Run( + new Dictionary + { + ["question"] = question, + ["chat_history"] = chatHistoryStr + }, + callbacks: new ManagerCallbacks(callbacks)); + } + else + { + newQuestion = question; + } + + var docs = await GetDocsAsync(newQuestion, values.Value); + var newInputs = new Dictionary + { + ["chat_history"] = chatHistoryStr, + ["input_documents"] = docs + }; + + if (_fields.RephraseQuestion) + { + newInputs["question"] = newQuestion; + } + + newInputs.TryAddKeyValues(values.Value); + + var answer = await _fields.CombineDocsChain.Run( + input: newInputs, + callbacks: new ManagerCallbacks(runManager.GetChild())); + + var output = new Dictionary + { + [_fields.OutputKey] = answer + }; + + if (_fields.ReturnSourceDocuments) + { + output["source_documents"] = docs; + } + + if (_fields.ReturnGeneratedQuestion) + { + output["generated_question"] = newQuestion; + } + + return new ChainValues(output); + } + + /// + /// Get docs. + /// + /// + /// + /// + /// + protected abstract Task> GetDocsAsync( + string question, + Dictionary inputs, + CallbackManagerForChainRun? runManager = null); +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChainInput.cs b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChainInput.cs new file mode 100644 index 00000000..955f18b6 --- /dev/null +++ b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChainInput.cs @@ -0,0 +1,63 @@ +using LangChain.Base; +using LangChain.Chains.CombineDocuments; +using LangChain.Chains.LLM; +using LangChain.Providers; + +namespace LangChain.Chains.ConversationalRetrieval; + +public class BaseConversationalRetrievalChainInput( + BaseCombineDocumentsChain combineDocsChain, + ILlmChain questionGenerator) + : ChainInputs +{ + /// + /// The chain used to combine any retrieved documents. + /// + public BaseCombineDocumentsChain CombineDocsChain { get; } = combineDocsChain; + + /// + /// The chain used to generate a new question for the sake of retrieval. + /// + /// This chain will take in the current question (with variable `question`) + /// and any chat history (with variable `chat_history`) and will produce + /// a new standalone question to be used later on. + /// + public ILlmChain QuestionGenerator { get; } = questionGenerator; + + /// + /// The output key to return the final answer of this chain in. + /// + public string OutputKey { get; set; } = "answer"; + + /// + /// Whether or not to pass the new generated question to the combine_docs_chain. + /// + /// If True, will pass the new generated question along. + /// If False, will only use the new generated question for retrieval and pass the + /// original question along to the . + /// + public bool RephraseQuestion { get; set; } = true; + + /// + /// Return the retrieved source documents as part of the final result. + /// + public bool ReturnSourceDocuments { get; set; } + + /// + /// Return the generated question as part of the final result. + /// + public bool ReturnGeneratedQuestion { get; set; } + + /// + /// An optional function to get a string of the chat history. + /// If None is provided, will use a default. + /// + public Func, string?> GetChatHistory { get; set; } = + ChatTurnTypeHelper.GetChatHistory; + + /// + /// If specified, the chain will return a fixed response if no docs + /// are found for the question. + /// + public string? ResponseIfNoDocsFound { get; set; } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ChatTurnTypeHelper.cs b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ChatTurnTypeHelper.cs new file mode 100644 index 00000000..157e4d82 --- /dev/null +++ b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ChatTurnTypeHelper.cs @@ -0,0 +1,26 @@ +using System.Text; +using LangChain.Providers; + +namespace LangChain.Chains.ConversationalRetrieval; + +public static class ChatTurnTypeHelper +{ + public static string GetChatHistory(IReadOnlyList chatHistory) + { + var buffer = new StringBuilder(); + + foreach (var message in chatHistory) + { + var rolePrefix = message.Role switch + { + MessageRole.Human => "Human: ", + MessageRole.Ai => "Assistant: ", + _ => $"{message.Role}: " + }; + + buffer.AppendLine($"{rolePrefix}{message.Content}"); + } + + return buffer.ToString(); + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChain.cs b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChain.cs new file mode 100644 index 00000000..c75e0203 --- /dev/null +++ b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChain.cs @@ -0,0 +1,68 @@ +using LangChain.Callback; +using LangChain.Chains.CombineDocuments; +using LangChain.Docstore; +using LangChain.Providers; + +namespace LangChain.Chains.ConversationalRetrieval; + +/// +/// Chain for having a conversation based on retrieved documents. +/// +/// This chain takes in chat history (a list of messages) and new questions, +/// and then returns an answer to that question. +/// The algorithm for this chain consists of three parts: +/// +/// 1. Use the chat history and the new question to create a "standalone question". +/// This is done so that this question can be passed into the retrieval step to fetch +/// relevant documents. If only the new question was passed in, then relevant context +/// may be lacking. If the whole conversation was passed into retrieval, there may +/// be unnecessary information there that would distract from retrieval. +/// +/// 2. This new question is passed to the retriever and relevant documents are +/// returned. +/// +/// 3. The retrieved documents are passed to an LLM along with either the new question +/// (default behavior) or the original question and chat history to generate a final +/// response. +/// +public class ConversationalRetrievalChain(ConversationalRetrievalChainInput fields) + : BaseConversationalRetrievalChain(fields) +{ + private readonly ConversationalRetrievalChainInput _fields = fields; + + public override string ChainType() => "conversational_retrieval"; + + protected override async Task> GetDocsAsync( + string question, + Dictionary inputs, + CallbackManagerForChainRun? runManager = null) + { + var docs = await _fields.Retriever.GetRelevantDocumentsAsync( + question, + callbacks: runManager?.ToCallbacks()); + + return ReduceTokensBelowLimit(docs); + } + + public List ReduceTokensBelowLimit(IEnumerable docs) + { + var docsList = docs.ToList(); + var numDocs = docsList.Count; + + if (_fields.MaxTokensLimit != null && + _fields.CombineDocsChain is StuffDocumentsChain stuffDocumentsChain && + stuffDocumentsChain.LlmChain.Llm is ISupportsCountTokens counter) + { + var tokens = docsList.Select(doc => counter.CountTokens(doc.PageContent)).ToArray(); + var tokenCount = tokens.Sum(); + + while (tokenCount > _fields.MaxTokensLimit) + { + numDocs -= 1; + tokenCount -= tokens[numDocs]; + } + } + + return docsList.Take(numDocs).ToList(); + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChainInput.cs b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChainInput.cs new file mode 100644 index 00000000..97cf99ef --- /dev/null +++ b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChainInput.cs @@ -0,0 +1,23 @@ +using LangChain.Chains.CombineDocuments; +using LangChain.Chains.LLM; +using LangChain.Retrievers; + +namespace LangChain.Chains.ConversationalRetrieval; + +public class ConversationalRetrievalChainInput( + BaseRetriever retriever, + BaseCombineDocumentsChain combineDocsChain, + ILlmChain questionGenerator) + : BaseConversationalRetrievalChainInput(combineDocsChain, questionGenerator) +{ + /// + /// Retriever to use to fetch documents. + /// + public BaseRetriever Retriever { get; } = retriever; + + /// + /// If set, enforces that the documents returned are less than this limit. + /// This is only enforced if is of type . + /// + public int? MaxTokensLimit { get; set; } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs index a69275ad..1d23cff1 100644 --- a/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs +++ b/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs @@ -82,7 +82,6 @@ public static StackChain BitwiseOr(BaseStackableChain left, BaseStackableChain r public async Task Run() { - var res = await CallAsync(new ChainValues()); return res; } @@ -92,4 +91,13 @@ public async Task Run(string resultKey) var res = await CallAsync(new ChainValues()); return res.Value[resultKey].ToString(); } + + public async Task Run( + Dictionary input, + ICallbacks? callbacks = null) + { + var res = await CallAsync(new ChainValues(input)); + + return res.Value[OutputKeys[0]].ToString(); + } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Prompts/PromptTemplate.cs b/src/libs/LangChain.Core/Prompts/PromptTemplate.cs index 31f25bee..e092dff6 100644 --- a/src/libs/LangChain.Core/Prompts/PromptTemplate.cs +++ b/src/libs/LangChain.Core/Prompts/PromptTemplate.cs @@ -186,43 +186,31 @@ public static List ParseFString(string template) { // Core logic replicated from internals of pythons built in Formatter class. // https://github.com/python/cpython/blob/135ec7cefbaffd516b77362ad2b2ad1025af462e/Objects/stringlib/unicode_format.h#L700-L706 - List chars = template.ToList(); - List nodes = new List(); - - Func nextBracket = (bracket, start) => - { - for (int i = start; i < chars.Count; i++) - { - if (bracket.Contains(chars[i])) - { - return i; - } - } - return -1; - }; + var chars = template.AsSpan(); + var nodes = new List(); int i = 0; - while (i < chars.Count) + while (i < chars.Length) { - if (chars[i] == '{' && i + 1 < chars.Count && chars[i + 1] == '{') + if (chars[i] == '{' && i + 1 < chars.Length && chars[i + 1] == '{') { nodes.Add(new LiteralNode("{")); i += 2; } - else if (chars[i] == '}' && i + 1 < chars.Count && chars[i + 1] == '}') + else if (chars[i] == '}' && i + 1 < chars.Length && chars[i + 1] == '}') { nodes.Add(new LiteralNode("}")); i += 2; } else if (chars[i] == '{') { - int j = nextBracket("}", i); + var j = GetNextBracketPosition(ref chars, "}", i); if (j < 0) { throw new Exception("Unclosed '{' in template."); } - nodes.Add(new VariableNode(new string(chars.GetRange(i + 1, j - (i + 1)).ToArray()))); + nodes.Add(new VariableNode(chars.Slice(i + 1, j - (i + 1)).ToString())); i = j + 1; } else if (chars[i] == '}') @@ -231,15 +219,31 @@ public static List ParseFString(string template) } else { - int next = nextBracket("{}", i); - string text = next < 0 ? new string(chars.GetRange(i, chars.Count - i).ToArray()) : new string(chars.GetRange(i, next - i).ToArray()); + var next = GetNextBracketPosition(ref chars, "{}", i); + var text = next < 0 + ? chars.Slice(i, chars.Length - i).ToString() + : chars.Slice(i, next - i).ToString(); + nodes.Add(new LiteralNode(text)); - i = next < 0 ? chars.Count : next; + i = next < 0 ? chars.Length : next; } } + return nodes; - } + int GetNextBracketPosition(ref ReadOnlySpan source, string bracket, int start) + { + for (var idx = start; idx < source.Length; idx++) + { + if (bracket.Contains(source[idx])) + { + return idx; + } + } + + return -1; + } + } public static string RenderTemplate(string template, TemplateFormatOptions templateFormat, Dictionary inputValues) { diff --git a/src/libs/LangChain.Core/Providers/Models/Message.cs b/src/libs/LangChain.Core/Providers/Models/Message.cs index 1e320cd7..cd8b9227 100644 --- a/src/libs/LangChain.Core/Providers/Models/Message.cs +++ b/src/libs/LangChain.Core/Providers/Models/Message.cs @@ -11,6 +11,9 @@ public readonly record struct Message( MessageRole Role, string? FunctionName = null) { + public static Message Human(string content) => new(content, MessageRole.Human); + public static Message Ai(string content) => new(content, MessageRole.Ai); + /// /// /// diff --git a/src/libs/LangChain.Core/Retrievers/BaseRetriever.cs b/src/libs/LangChain.Core/Retrievers/BaseRetriever.cs index cf99a9e1..e54fecc8 100644 --- a/src/libs/LangChain.Core/Retrievers/BaseRetriever.cs +++ b/src/libs/LangChain.Core/Retrievers/BaseRetriever.cs @@ -63,14 +63,18 @@ public virtual async Task> GetRelevantDocumentsAsync( try { var docs = await GetRelevantDocumentsCoreAsync(query, runManager); - await runManager.HandleRetrieverEndAsync(query, docs.ToList()); + var docsList = docs.ToList(); + await runManager.HandleRetrieverEndAsync(query, docsList); - return docs; + return docsList; } catch (Exception exception) { if (runManager != null) + { await runManager.HandleRetrieverErrorAsync(exception, query); + } + throw; } } diff --git a/src/libs/LangChain.Core/Schema/ChainValues.cs b/src/libs/LangChain.Core/Schema/ChainValues.cs index 2267a827..7c831cb9 100644 --- a/src/libs/LangChain.Core/Schema/ChainValues.cs +++ b/src/libs/LangChain.Core/Schema/ChainValues.cs @@ -36,4 +36,10 @@ public ChainValues() } public Dictionary Value { get; set; } + + public object this[string key] + { + get => Value[key]; + set => Value[key] = value; + } } diff --git a/src/tests/LangChain.Core.UnitTests/Chains/ConversationalRetrieval/ConversationalRetrievalChainTests.cs b/src/tests/LangChain.Core.UnitTests/Chains/ConversationalRetrieval/ConversationalRetrievalChainTests.cs new file mode 100644 index 00000000..cf0a6541 --- /dev/null +++ b/src/tests/LangChain.Core.UnitTests/Chains/ConversationalRetrieval/ConversationalRetrievalChainTests.cs @@ -0,0 +1,144 @@ +using LangChain.Callback; +using LangChain.Chains.CombineDocuments; +using LangChain.Chains.ConversationalRetrieval; +using LangChain.Chains.LLM; +using LangChain.Docstore; +using LangChain.Prompts; +using LangChain.Providers; +using LangChain.Retrievers; +using LangChain.Schema; +using Moq; + +namespace LangChain.Core.UnitTests.Chains.ConversationalRetrieval; + +[TestFixture] +public class ConversationalRetrievalChainTests +{ + [Test] + public async Task Call_Ok() + { + var combineDocumentsChainInput = new Mock().Object; + + var combineDocsChainMock = new Mock(combineDocumentsChainInput); + combineDocsChainMock.Setup(x => x + .Run(It.IsAny>(), It.IsAny())) + .Returns, ICallbacks>((input, _) => Task.FromResult("Alice")); + + var retrieverMock = new Mock(); + retrieverMock + .Setup(x => x + .GetRelevantDocumentsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny>(), + It.IsAny>())) + .Returns, Dictionary>((query, _, _, _, _, _) => + { + var docs = new List + { + new Document("first"), + new Document("second"), + new Document("third") + }.AsEnumerable(); + + return Task.FromResult(docs); + }); + + // # This controls how the standalone question is generated. + // # Should take `chat_history` and `question` as input variables. + var template = + "Combine the chat history and follow up question into a standalone question. Chat History: {chat_history}. Follow up question: {question}"; + + var prompt = PromptTemplate.FromTemplate(template); + var questionGeneratorLlmMock = new Mock(); + questionGeneratorLlmMock + .Setup(v => v.GenerateAsync(It.IsAny(), It.IsAny())) + .Returns((request, _) => + { + var chatResponse = new ChatResponse(new[] { Message.Ai("Bob's asking what is hist name") }, Usage.Empty); + return Task.FromResult(chatResponse); + }); + + var llmInput = new LlmChainInput(questionGeneratorLlmMock.Object, prompt); + var questionGeneratorChain = new LlmChain(llmInput); + + var chainInput = new ConversationalRetrievalChainInput(retrieverMock.Object, combineDocsChainMock.Object, questionGeneratorChain) + { + ReturnSourceDocuments = true, + ReturnGeneratedQuestion = true + }; + + var chain = new ConversationalRetrievalChain(chainInput); + + var input = new ChainValues + { + ["question"] = "What is my name?", + ["chat_history"] = new List + { + Message.Human("My name is Alice"), + Message.Ai("Hello Alice") + } + }; + + var result = await chain.CallAsync(input); + + result.Should().NotBeNull(); + result.Value.Should().ContainKey("answer"); + result.Value["answer"].Should().BeEquivalentTo("Alice"); + + result.Value.Should().ContainKey("source_documents"); + var resultSourceDocuments = result.Value["source_documents"] as List; + resultSourceDocuments.Should().NotBeNull(); + resultSourceDocuments.Should().HaveCount(3); + resultSourceDocuments[0].PageContent.Should().BeEquivalentTo("first"); + + result.Value.Should().ContainKey("generated_question"); + result.Value["generated_question"].Should().BeEquivalentTo("Bob's asking what is hist name"); + + questionGeneratorLlmMock + .Verify(v => v.GenerateAsync( + It.Is(request => request.Messages.Count == 1), + It.IsAny())); + } + + [Test] + public async Task ReduceTokensBelowLimit_Ok() + { + var supportsCountTokensMock = new Mock(); + supportsCountTokensMock + .Setup(v => v.CountTokens(It.IsAny())) + .Returns(input => input.Length); + + var chatModelMock = supportsCountTokensMock.As(); + + var llmWithCounterMock = new Mock(); + llmWithCounterMock + .SetupGet(v => v.Llm) + .Returns(chatModelMock.Object); + + var prompt = PromptTemplate.FromTemplate("{documents}"); + llmWithCounterMock + .SetupGet(v => v.Prompt) + .Returns(prompt); + + var combineDocsChainInput = new StuffDocumentsChainInput(llmWithCounterMock.Object); + var combineDocsChain = new StuffDocumentsChain(combineDocsChainInput); + + var retriever = new Mock().Object; + var questionGeneratorChain = new Mock().Object; + + var chainInput = new ConversationalRetrievalChainInput(retriever, combineDocsChain, questionGeneratorChain) + { + MaxTokensLimit = 500 + }; + + var chain = new ConversationalRetrievalChain(chainInput); + + var inputDocs = Enumerable.Range(1, 10).Select(_ => new Document(new string('*', 100))); + var result = chain.ReduceTokensBelowLimit(inputDocs); + + result.Should().HaveCount(5); + } +} \ No newline at end of file diff --git a/src/tests/LangChain.Core.UnitTests/Chains/RetrievalQa/RetrievalQaChainTests.cs b/src/tests/LangChain.Core.UnitTests/Chains/RetrievalQa/RetrievalQaChainTests.cs index c9a03f63..5de21dc5 100644 --- a/src/tests/LangChain.Core.UnitTests/Chains/RetrievalQa/RetrievalQaChainTests.cs +++ b/src/tests/LangChain.Core.UnitTests/Chains/RetrievalQa/RetrievalQaChainTests.cs @@ -40,7 +40,8 @@ public async Task Retrieval_Ok() x["input_documents"].As>() .Select(doc => doc.PageContent) .Intersect(new string[] { "first", "second", "third" }) - .Count() == 3)), + .Count() == 3), + It.IsAny()), Times.Once()); } @@ -76,8 +77,8 @@ private Mock CreateCombineDocumentsChainMock() var mock = new Mock(new Mock().Object); mock.Setup(x => x - .Run(It.IsAny>())) - .Returns>(input => Task.FromResult("answer")); + .Run(It.IsAny>(), It.IsAny())) + .Returns, ICallbacks>((input, _) => Task.FromResult("answer")); return mock; }