diff --git a/src/libs/LangChain.Core/Base/BaseChain.cs b/src/libs/LangChain.Core/Base/BaseChain.cs
index 0ce8fa27..2949ab99 100644
--- a/src/libs/LangChain.Core/Base/BaseChain.cs
+++ b/src/libs/LangChain.Core/Base/BaseChain.cs
@@ -63,8 +63,13 @@ public abstract class BaseChain(IChainInputs fields) : IChain
/// Run the chain using a simple input/output.
///
/// The dict input to use to execute the chain.
+ ///
+ /// Callbacks to use for this chain run. These will be called in
+ /// addition to callbacks passed to the chain during construction, but only
+ /// these runtime callbacks will propagate to calls to other objects.
+ ///
/// A text value containing the result of the chain.
- public virtual async Task Run(Dictionary input)
+ public virtual async Task Run(Dictionary input, ICallbacks? callbacks = null)
{
var keysLengthDifferent = InputKeys.Length != input.Count;
@@ -73,7 +78,7 @@ public virtual async Task Run(Dictionary input)
throw new ArgumentException($"Chain {ChainType()} expects {InputKeys.Length} but, received {input.Count}");
}
- var returnValues = await CallAsync(new ChainValues(input));
+ var returnValues = await CallAsync(new ChainValues(input), callbacks);
var returnValue = returnValues.Value.FirstOrDefault(kv => kv.Key == OutputKeys[0]).Value;
@@ -88,7 +93,8 @@ public virtual async Task Run(Dictionary input)
///
///
///
- public async Task CallAsync(IChainValues values,
+ public async Task CallAsync(
+ IChainValues values,
ICallbacks? callbacks = null,
IReadOnlyList? tags = null,
IReadOnlyDictionary? metadata = null)
diff --git a/src/libs/LangChain.Core/Callback/ICallbacks.cs b/src/libs/LangChain.Core/Callback/ICallbacks.cs
index 46e0c174..2faaa4b4 100644
--- a/src/libs/LangChain.Core/Callback/ICallbacks.cs
+++ b/src/libs/LangChain.Core/Callback/ICallbacks.cs
@@ -4,6 +4,11 @@ namespace LangChain.Callback;
public interface ICallbacks;
+public static class ManagerCallbacksExtensions
+{
+ public static ManagerCallbacks ToCallbacks(this ParentRunManager source) => new ManagerCallbacks(source.GetChild());
+}
+
public record ManagerCallbacks(CallbackManager Value) : ICallbacks;
public record HandlersCallbacks(List Value) : ICallbacks;
\ No newline at end of file
diff --git a/src/libs/LangChain.Core/Chains/Base/IChain.cs b/src/libs/LangChain.Core/Chains/Base/IChain.cs
index 9e306546..6dbc058e 100644
--- a/src/libs/LangChain.Core/Chains/Base/IChain.cs
+++ b/src/libs/LangChain.Core/Chains/Base/IChain.cs
@@ -7,7 +7,10 @@ public interface IChain
{
string[] InputKeys { get; }
string[] OutputKeys { get; }
-
+
+ Task Run(string input);
+ Task Run(Dictionary input, ICallbacks? callbacks = null);
+
Task CallAsync(IChainValues values,
ICallbacks? callbacks = null,
IReadOnlyList? tags = null,
diff --git a/src/libs/LangChain.Core/Chains/CombineDocuments/StuffDocumentsChain.cs b/src/libs/LangChain.Core/Chains/CombineDocuments/StuffDocumentsChain.cs
index ec15b762..5264fd9d 100644
--- a/src/libs/LangChain.Core/Chains/CombineDocuments/StuffDocumentsChain.cs
+++ b/src/libs/LangChain.Core/Chains/CombineDocuments/StuffDocumentsChain.cs
@@ -18,18 +18,18 @@ namespace LangChain.Chains.CombineDocuments;
///
public class StuffDocumentsChain : BaseCombineDocumentsChain
{
- private readonly ILlmChain _llmChain;
+ public readonly ILlmChain LlmChain;
private readonly BasePromptTemplate _documentPrompt;
private readonly string _documentVariableName;
- private readonly string _documentSeparator = "\n\n";
+ private readonly string _documentSeparator;
public StuffDocumentsChain(StuffDocumentsChainInput input) : base(input)
{
- _llmChain = input.LlmChain;
+ LlmChain = input.LlmChain;
_documentPrompt = input.DocumentPrompt;
_documentSeparator = input.DocumentSeparator;
- var llmChainVariables = _llmChain.Prompt.InputVariables;
+ var llmChainVariables = LlmChain.Prompt.InputVariables;
if (input.DocumentVariableName == null)
{
@@ -50,7 +50,7 @@ public StuffDocumentsChain(StuffDocumentsChainInput input) : base(input)
}
public override string[] InputKeys =>
- base.InputKeys.Concat(_llmChain.InputKeys.Where(k => k != _documentVariableName)).ToArray();
+ base.InputKeys.Concat(LlmChain.InputKeys.Where(k => k != _documentVariableName)).ToArray();
public override string ChainType() => "stuff_documents_chain";
@@ -59,17 +59,17 @@ public StuffDocumentsChain(StuffDocumentsChainInput input) : base(input)
IReadOnlyDictionary otherKeys)
{
var inputs = await GetInputs(docs, otherKeys);
- var predict = await _llmChain.Predict(new ChainValues(inputs.Value));
+ var predict = await LlmChain.Predict(new ChainValues(inputs.Value));
return (predict.ToString() ?? string.Empty, new Dictionary());
}
public override async Task PromptLength(IReadOnlyList docs, IReadOnlyDictionary otherKeys)
{
- if (_llmChain.Llm is ISupportsCountTokens supportsCountTokens)
+ if (LlmChain.Llm is ISupportsCountTokens supportsCountTokens)
{
var inputs = await GetInputs(docs, otherKeys);
- var prompt = await _llmChain.Prompt.FormatPromptValue(inputs);
+ var prompt = await LlmChain.Prompt.FormatPromptValue(inputs);
return supportsCountTokens.CountTokens(prompt.ToString());
}
@@ -84,7 +84,7 @@ private async Task GetInputs(IReadOnlyList docs, IReadOnl
var inputs = new Dictionary();
foreach (var kv in otherKeys)
{
- if (_llmChain.Prompt.InputVariables.Contains(kv.Key))
+ if (LlmChain.Prompt.InputVariables.Contains(kv.Key))
{
inputs[kv.Key] = kv.Value;
}
diff --git a/src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChain.cs b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChain.cs
new file mode 100644
index 00000000..022c3e1e
--- /dev/null
+++ b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChain.cs
@@ -0,0 +1,113 @@
+using LangChain.Abstractions.Schema;
+using LangChain.Base;
+using LangChain.Callback;
+using LangChain.Common;
+using LangChain.Docstore;
+using LangChain.Providers;
+using LangChain.Schema;
+
+namespace LangChain.Chains.ConversationalRetrieval;
+
+///
+/// Chain for chatting with an index.
+///
+public abstract class BaseConversationalRetrievalChain(BaseConversationalRetrievalChainInput fields) : BaseChain(fields)
+{
+ /// Chain input fields
+ private readonly BaseConversationalRetrievalChainInput _fields = fields;
+
+ public override string[] InputKeys => new[] { "question", "chat_history" };
+
+ public override string[] OutputKeys
+ {
+ get
+ {
+ var outputKeys = new List { _fields.OutputKey };
+ if (_fields.ReturnSourceDocuments)
+ {
+ outputKeys.Add("source_documents");
+ }
+
+ if (_fields.ReturnGeneratedQuestion)
+ {
+ outputKeys.Add("generated_question");
+ }
+
+ return outputKeys.ToArray();
+ }
+ }
+
+ protected override async Task CallAsync(IChainValues values, CallbackManagerForChainRun? runManager)
+ {
+ runManager ??= BaseRunManager.GetNoopManager();
+
+ var question = values.Value["question"].ToString();
+
+ var getChatHistory = _fields.GetChatHistory;
+ var chatHistoryStr = getChatHistory(values.Value["chat_history"] as List);
+
+ string? newQuestion;
+ if (chatHistoryStr != null)
+ {
+ var callbacks = runManager.GetChild();
+ newQuestion = await _fields.QuestionGenerator.Run(
+ new Dictionary
+ {
+ ["question"] = question,
+ ["chat_history"] = chatHistoryStr
+ },
+ callbacks: new ManagerCallbacks(callbacks));
+ }
+ else
+ {
+ newQuestion = question;
+ }
+
+ var docs = await GetDocsAsync(newQuestion, values.Value);
+ var newInputs = new Dictionary
+ {
+ ["chat_history"] = chatHistoryStr,
+ ["input_documents"] = docs
+ };
+
+ if (_fields.RephraseQuestion)
+ {
+ newInputs["question"] = newQuestion;
+ }
+
+ newInputs.TryAddKeyValues(values.Value);
+
+ var answer = await _fields.CombineDocsChain.Run(
+ input: newInputs,
+ callbacks: new ManagerCallbacks(runManager.GetChild()));
+
+ var output = new Dictionary
+ {
+ [_fields.OutputKey] = answer
+ };
+
+ if (_fields.ReturnSourceDocuments)
+ {
+ output["source_documents"] = docs;
+ }
+
+ if (_fields.ReturnGeneratedQuestion)
+ {
+ output["generated_question"] = newQuestion;
+ }
+
+ return new ChainValues(output);
+ }
+
+ ///
+ /// Get docs.
+ ///
+ ///
+ ///
+ ///
+ ///
+ protected abstract Task> GetDocsAsync(
+ string question,
+ Dictionary inputs,
+ CallbackManagerForChainRun? runManager = null);
+}
\ No newline at end of file
diff --git a/src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChainInput.cs b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChainInput.cs
new file mode 100644
index 00000000..955f18b6
--- /dev/null
+++ b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChainInput.cs
@@ -0,0 +1,63 @@
+using LangChain.Base;
+using LangChain.Chains.CombineDocuments;
+using LangChain.Chains.LLM;
+using LangChain.Providers;
+
+namespace LangChain.Chains.ConversationalRetrieval;
+
+public class BaseConversationalRetrievalChainInput(
+ BaseCombineDocumentsChain combineDocsChain,
+ ILlmChain questionGenerator)
+ : ChainInputs
+{
+ ///
+ /// The chain used to combine any retrieved documents.
+ ///
+ public BaseCombineDocumentsChain CombineDocsChain { get; } = combineDocsChain;
+
+ ///
+ /// The chain used to generate a new question for the sake of retrieval.
+ ///
+ /// This chain will take in the current question (with variable `question`)
+ /// and any chat history (with variable `chat_history`) and will produce
+ /// a new standalone question to be used later on.
+ ///
+ public ILlmChain QuestionGenerator { get; } = questionGenerator;
+
+ ///
+ /// The output key to return the final answer of this chain in.
+ ///
+ public string OutputKey { get; set; } = "answer";
+
+ ///
+ /// Whether or not to pass the new generated question to the combine_docs_chain.
+ ///
+ /// If True, will pass the new generated question along.
+ /// If False, will only use the new generated question for retrieval and pass the
+ /// original question along to the combine docs chain.
+ ///
+ public bool RephraseQuestion { get; set; } = true;
+
+ ///
+ /// Return the retrieved source documents as part of the final result.
+ ///
+ public bool ReturnSourceDocuments { get; set; }
+
+ ///
+ /// Return the generated question as part of the final result.
+ ///
+ public bool ReturnGeneratedQuestion { get; set; }
+
+ ///
+ /// An optional function to get a string of the chat history.
+ /// If None is provided, will use a default.
+ ///
+ public Func, string?> GetChatHistory { get; set; } =
+ ChatTurnTypeHelper.GetChatHistory;
+
+ ///
+ /// If specified, the chain will return a fixed response if no docs
+ /// are found for the question.
+ ///
+ public string? ResponseIfNoDocsFound { get; set; }
+}
\ No newline at end of file
diff --git a/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ChatTurnTypeHelper.cs b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ChatTurnTypeHelper.cs
new file mode 100644
index 00000000..157e4d82
--- /dev/null
+++ b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ChatTurnTypeHelper.cs
@@ -0,0 +1,26 @@
+using System.Text;
+using LangChain.Providers;
+
+namespace LangChain.Chains.ConversationalRetrieval;
+
+public static class ChatTurnTypeHelper
+{
+ public static string GetChatHistory(IReadOnlyList chatHistory)
+ {
+ var buffer = new StringBuilder();
+
+ foreach (var message in chatHistory)
+ {
+ var rolePrefix = message.Role switch
+ {
+ MessageRole.Human => "Human: ",
+ MessageRole.Ai => "Assistant: ",
+ _ => $"{message.Role}: "
+ };
+
+ buffer.AppendLine($"{rolePrefix}{message.Content}");
+ }
+
+ return buffer.ToString();
+ }
+}
\ No newline at end of file
diff --git a/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChain.cs b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChain.cs
new file mode 100644
index 00000000..c75e0203
--- /dev/null
+++ b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChain.cs
@@ -0,0 +1,68 @@
+using LangChain.Callback;
+using LangChain.Chains.CombineDocuments;
+using LangChain.Docstore;
+using LangChain.Providers;
+
+namespace LangChain.Chains.ConversationalRetrieval;
+
+///
+/// Chain for having a conversation based on retrieved documents.
+///
+/// This chain takes in chat history (a list of messages) and new questions,
+/// and then returns an answer to that question.
+/// The algorithm for this chain consists of three parts:
+///
+/// 1. Use the chat history and the new question to create a "standalone question".
+/// This is done so that this question can be passed into the retrieval step to fetch
+/// relevant documents. If only the new question was passed in, then relevant context
+/// may be lacking. If the whole conversation was passed into retrieval, there may
+/// be unnecessary information there that would distract from retrieval.
+///
+/// 2. This new question is passed to the retriever and relevant documents are
+/// returned.
+///
+/// 3. The retrieved documents are passed to an LLM along with either the new question
+/// (default behavior) or the original question and chat history to generate a final
+/// response.
+///
+public class ConversationalRetrievalChain(ConversationalRetrievalChainInput fields)
+ : BaseConversationalRetrievalChain(fields)
+{
+ private readonly ConversationalRetrievalChainInput _fields = fields;
+
+ public override string ChainType() => "conversational_retrieval";
+
+ protected override async Task> GetDocsAsync(
+ string question,
+ Dictionary inputs,
+ CallbackManagerForChainRun? runManager = null)
+ {
+ var docs = await _fields.Retriever.GetRelevantDocumentsAsync(
+ question,
+ callbacks: runManager?.ToCallbacks());
+
+ return ReduceTokensBelowLimit(docs);
+ }
+
+ public List ReduceTokensBelowLimit(IEnumerable docs)
+ {
+ var docsList = docs.ToList();
+ var numDocs = docsList.Count;
+
+ if (_fields.MaxTokensLimit != null &&
+ _fields.CombineDocsChain is StuffDocumentsChain stuffDocumentsChain &&
+ stuffDocumentsChain.LlmChain.Llm is ISupportsCountTokens counter)
+ {
+ var tokens = docsList.Select(doc => counter.CountTokens(doc.PageContent)).ToArray();
+ var tokenCount = tokens.Sum();
+
+ while (tokenCount > _fields.MaxTokensLimit)
+ {
+ numDocs -= 1;
+ tokenCount -= tokens[numDocs];
+ }
+ }
+
+ return docsList.Take(numDocs).ToList();
+ }
+}
\ No newline at end of file
diff --git a/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChainInput.cs b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChainInput.cs
new file mode 100644
index 00000000..97cf99ef
--- /dev/null
+++ b/src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChainInput.cs
@@ -0,0 +1,23 @@
+using LangChain.Chains.CombineDocuments;
+using LangChain.Chains.LLM;
+using LangChain.Retrievers;
+
+namespace LangChain.Chains.ConversationalRetrieval;
+
+public class ConversationalRetrievalChainInput(
+ BaseRetriever retriever,
+ BaseCombineDocumentsChain combineDocsChain,
+ ILlmChain questionGenerator)
+ : BaseConversationalRetrievalChainInput(combineDocsChain, questionGenerator)
+{
+ ///
+ /// Retriever to use to fetch documents.
+ ///
+ public BaseRetriever Retriever { get; } = retriever;
+
+ ///
+ /// If set, enforces that the total token count of the returned documents stays below this limit.
+ /// This is only enforced if CombineDocsChain is of type StuffDocumentsChain.
+ ///
+ public int? MaxTokensLimit { get; set; }
+}
\ No newline at end of file
diff --git a/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs
index a69275ad..1d23cff1 100644
--- a/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs
+++ b/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs
@@ -82,7 +82,6 @@ public static StackChain BitwiseOr(BaseStackableChain left, BaseStackableChain r
public async Task Run()
{
-
var res = await CallAsync(new ChainValues());
return res;
}
@@ -92,4 +91,13 @@ public async Task Run(string resultKey)
var res = await CallAsync(new ChainValues());
return res.Value[resultKey].ToString();
}
+
+ public async Task Run(
+ Dictionary input,
+ ICallbacks? callbacks = null)
+ {
+ var res = await CallAsync(new ChainValues(input));
+
+ return res.Value[OutputKeys[0]].ToString();
+ }
}
\ No newline at end of file
diff --git a/src/libs/LangChain.Core/Prompts/PromptTemplate.cs b/src/libs/LangChain.Core/Prompts/PromptTemplate.cs
index 31f25bee..e092dff6 100644
--- a/src/libs/LangChain.Core/Prompts/PromptTemplate.cs
+++ b/src/libs/LangChain.Core/Prompts/PromptTemplate.cs
@@ -186,43 +186,31 @@ public static List ParseFString(string template)
{
// Core logic replicated from internals of pythons built in Formatter class.
// https://github.com/python/cpython/blob/135ec7cefbaffd516b77362ad2b2ad1025af462e/Objects/stringlib/unicode_format.h#L700-L706
- List chars = template.ToList();
- List nodes = new List();
-
- Func nextBracket = (bracket, start) =>
- {
- for (int i = start; i < chars.Count; i++)
- {
- if (bracket.Contains(chars[i]))
- {
- return i;
- }
- }
- return -1;
- };
+ var chars = template.AsSpan();
+ var nodes = new List();
int i = 0;
- while (i < chars.Count)
+ while (i < chars.Length)
{
- if (chars[i] == '{' && i + 1 < chars.Count && chars[i + 1] == '{')
+ if (chars[i] == '{' && i + 1 < chars.Length && chars[i + 1] == '{')
{
nodes.Add(new LiteralNode("{"));
i += 2;
}
- else if (chars[i] == '}' && i + 1 < chars.Count && chars[i + 1] == '}')
+ else if (chars[i] == '}' && i + 1 < chars.Length && chars[i + 1] == '}')
{
nodes.Add(new LiteralNode("}"));
i += 2;
}
else if (chars[i] == '{')
{
- int j = nextBracket("}", i);
+ var j = GetNextBracketPosition(ref chars, "}", i);
if (j < 0)
{
throw new Exception("Unclosed '{' in template.");
}
- nodes.Add(new VariableNode(new string(chars.GetRange(i + 1, j - (i + 1)).ToArray())));
+ nodes.Add(new VariableNode(chars.Slice(i + 1, j - (i + 1)).ToString()));
i = j + 1;
}
else if (chars[i] == '}')
@@ -231,15 +219,31 @@ public static List ParseFString(string template)
}
else
{
- int next = nextBracket("{}", i);
- string text = next < 0 ? new string(chars.GetRange(i, chars.Count - i).ToArray()) : new string(chars.GetRange(i, next - i).ToArray());
+ var next = GetNextBracketPosition(ref chars, "{}", i);
+ var text = next < 0
+ ? chars.Slice(i, chars.Length - i).ToString()
+ : chars.Slice(i, next - i).ToString();
+
nodes.Add(new LiteralNode(text));
- i = next < 0 ? chars.Count : next;
+ i = next < 0 ? chars.Length : next;
}
}
+
return nodes;
- }
+ int GetNextBracketPosition(ref ReadOnlySpan source, string bracket, int start)
+ {
+ for (var idx = start; idx < source.Length; idx++)
+ {
+ if (bracket.Contains(source[idx]))
+ {
+ return idx;
+ }
+ }
+
+ return -1;
+ }
+ }
public static string RenderTemplate(string template, TemplateFormatOptions templateFormat, Dictionary inputValues)
{
diff --git a/src/libs/LangChain.Core/Providers/Models/Message.cs b/src/libs/LangChain.Core/Providers/Models/Message.cs
index 1e320cd7..cd8b9227 100644
--- a/src/libs/LangChain.Core/Providers/Models/Message.cs
+++ b/src/libs/LangChain.Core/Providers/Models/Message.cs
@@ -11,6 +11,9 @@ public readonly record struct Message(
MessageRole Role,
string? FunctionName = null)
{
+ public static Message Human(string content) => new(content, MessageRole.Human);
+ public static Message Ai(string content) => new(content, MessageRole.Ai);
+
///
///
///
diff --git a/src/libs/LangChain.Core/Retrievers/BaseRetriever.cs b/src/libs/LangChain.Core/Retrievers/BaseRetriever.cs
index cf99a9e1..e54fecc8 100644
--- a/src/libs/LangChain.Core/Retrievers/BaseRetriever.cs
+++ b/src/libs/LangChain.Core/Retrievers/BaseRetriever.cs
@@ -63,14 +63,18 @@ public virtual async Task> GetRelevantDocumentsAsync(
try
{
var docs = await GetRelevantDocumentsCoreAsync(query, runManager);
- await runManager.HandleRetrieverEndAsync(query, docs.ToList());
+ var docsList = docs.ToList();
+ await runManager.HandleRetrieverEndAsync(query, docsList);
- return docs;
+ return docsList;
}
catch (Exception exception)
{
if (runManager != null)
+ {
await runManager.HandleRetrieverErrorAsync(exception, query);
+ }
+
throw;
}
}
diff --git a/src/libs/LangChain.Core/Schema/ChainValues.cs b/src/libs/LangChain.Core/Schema/ChainValues.cs
index 2267a827..7c831cb9 100644
--- a/src/libs/LangChain.Core/Schema/ChainValues.cs
+++ b/src/libs/LangChain.Core/Schema/ChainValues.cs
@@ -36,4 +36,10 @@ public ChainValues()
}
public Dictionary Value { get; set; }
+
+ public object this[string key]
+ {
+ get => Value[key];
+ set => Value[key] = value;
+ }
}
diff --git a/src/tests/LangChain.Core.UnitTests/Chains/ConversationalRetrieval/ConversationalRetrievalChainTests.cs b/src/tests/LangChain.Core.UnitTests/Chains/ConversationalRetrieval/ConversationalRetrievalChainTests.cs
new file mode 100644
index 00000000..cf0a6541
--- /dev/null
+++ b/src/tests/LangChain.Core.UnitTests/Chains/ConversationalRetrieval/ConversationalRetrievalChainTests.cs
@@ -0,0 +1,144 @@
+using LangChain.Callback;
+using LangChain.Chains.CombineDocuments;
+using LangChain.Chains.ConversationalRetrieval;
+using LangChain.Chains.LLM;
+using LangChain.Docstore;
+using LangChain.Prompts;
+using LangChain.Providers;
+using LangChain.Retrievers;
+using LangChain.Schema;
+using Moq;
+
+namespace LangChain.Core.UnitTests.Chains.ConversationalRetrieval;
+
+[TestFixture]
+public class ConversationalRetrievalChainTests
+{
+ [Test]
+ public async Task Call_Ok()
+ {
+ var combineDocumentsChainInput = new Mock().Object;
+
+ var combineDocsChainMock = new Mock(combineDocumentsChainInput);
+ combineDocsChainMock.Setup(x => x
+ .Run(It.IsAny>(), It.IsAny()))
+ .Returns, ICallbacks>((input, _) => Task.FromResult("Alice"));
+
+ var retrieverMock = new Mock();
+ retrieverMock
+ .Setup(x => x
+ .GetRelevantDocumentsAsync(
+ It.IsAny(),
+ It.IsAny(),
+ It.IsAny(),
+ It.IsAny(),
+ It.IsAny>(),
+ It.IsAny>()))
+ .Returns, Dictionary>((query, _, _, _, _, _) =>
+ {
+ var docs = new List
+ {
+ new Document("first"),
+ new Document("second"),
+ new Document("third")
+ }.AsEnumerable();
+
+ return Task.FromResult(docs);
+ });
+
+ // # This controls how the standalone question is generated.
+ // # Should take `chat_history` and `question` as input variables.
+ var template =
+ "Combine the chat history and follow up question into a standalone question. Chat History: {chat_history}. Follow up question: {question}";
+
+ var prompt = PromptTemplate.FromTemplate(template);
+ var questionGeneratorLlmMock = new Mock();
+ questionGeneratorLlmMock
+ .Setup(v => v.GenerateAsync(It.IsAny(), It.IsAny()))
+ .Returns((request, _) =>
+ {
+ var chatResponse = new ChatResponse(new[] { Message.Ai("Bob's asking what is hist name") }, Usage.Empty);
+ return Task.FromResult(chatResponse);
+ });
+
+ var llmInput = new LlmChainInput(questionGeneratorLlmMock.Object, prompt);
+ var questionGeneratorChain = new LlmChain(llmInput);
+
+ var chainInput = new ConversationalRetrievalChainInput(retrieverMock.Object, combineDocsChainMock.Object, questionGeneratorChain)
+ {
+ ReturnSourceDocuments = true,
+ ReturnGeneratedQuestion = true
+ };
+
+ var chain = new ConversationalRetrievalChain(chainInput);
+
+ var input = new ChainValues
+ {
+ ["question"] = "What is my name?",
+ ["chat_history"] = new List
+ {
+ Message.Human("My name is Alice"),
+ Message.Ai("Hello Alice")
+ }
+ };
+
+ var result = await chain.CallAsync(input);
+
+ result.Should().NotBeNull();
+ result.Value.Should().ContainKey("answer");
+ result.Value["answer"].Should().BeEquivalentTo("Alice");
+
+ result.Value.Should().ContainKey("source_documents");
+ var resultSourceDocuments = result.Value["source_documents"] as List;
+ resultSourceDocuments.Should().NotBeNull();
+ resultSourceDocuments.Should().HaveCount(3);
+ resultSourceDocuments[0].PageContent.Should().BeEquivalentTo("first");
+
+ result.Value.Should().ContainKey("generated_question");
+ result.Value["generated_question"].Should().BeEquivalentTo("Bob's asking what is hist name");
+
+ questionGeneratorLlmMock
+ .Verify(v => v.GenerateAsync(
+ It.Is(request => request.Messages.Count == 1),
+ It.IsAny()));
+ }
+
+ [Test]
+ public async Task ReduceTokensBelowLimit_Ok()
+ {
+ var supportsCountTokensMock = new Mock();
+ supportsCountTokensMock
+ .Setup(v => v.CountTokens(It.IsAny()))
+ .Returns(input => input.Length);
+
+ var chatModelMock = supportsCountTokensMock.As();
+
+ var llmWithCounterMock = new Mock();
+ llmWithCounterMock
+ .SetupGet(v => v.Llm)
+ .Returns(chatModelMock.Object);
+
+ var prompt = PromptTemplate.FromTemplate("{documents}");
+ llmWithCounterMock
+ .SetupGet(v => v.Prompt)
+ .Returns(prompt);
+
+ var combineDocsChainInput = new StuffDocumentsChainInput(llmWithCounterMock.Object);
+ var combineDocsChain = new StuffDocumentsChain(combineDocsChainInput);
+
+ var retriever = new Mock().Object;
+ var questionGeneratorChain = new Mock().Object;
+
+ var chainInput = new ConversationalRetrievalChainInput(retriever, combineDocsChain, questionGeneratorChain)
+ {
+ MaxTokensLimit = 500
+ };
+
+ var chain = new ConversationalRetrievalChain(chainInput);
+
+ var inputDocs = Enumerable.Range(1, 10).Select(_ => new Document(new string('*', 100)));
+ var result = chain.ReduceTokensBelowLimit(inputDocs);
+
+ result.Should().HaveCount(5);
+ }
+}
\ No newline at end of file
diff --git a/src/tests/LangChain.Core.UnitTests/Chains/RetrievalQa/RetrievalQaChainTests.cs b/src/tests/LangChain.Core.UnitTests/Chains/RetrievalQa/RetrievalQaChainTests.cs
index c9a03f63..5de21dc5 100644
--- a/src/tests/LangChain.Core.UnitTests/Chains/RetrievalQa/RetrievalQaChainTests.cs
+++ b/src/tests/LangChain.Core.UnitTests/Chains/RetrievalQa/RetrievalQaChainTests.cs
@@ -40,7 +40,8 @@ public async Task Retrieval_Ok()
x["input_documents"].As>()
.Select(doc => doc.PageContent)
.Intersect(new string[] { "first", "second", "third" })
- .Count() == 3)),
+ .Count() == 3),
+ It.IsAny()),
Times.Once());
}
@@ -76,8 +77,8 @@ private Mock CreateCombineDocumentsChainMock()
var mock = new Mock(new Mock().Object);
mock.Setup(x => x
- .Run(It.IsAny>()))
- .Returns>(input => Task.FromResult("answer"));
+ .Run(It.IsAny>(), It.IsAny()))
+ .Returns, ICallbacks>((input, _) => Task.FromResult("answer"));
return mock;
}