feat: ConversationalRetrievalChain (#85)
Co-authored-by: Evgenii Khoroshev <[email protected]>

Closes #13
khoroshevj authored Dec 1, 2023
1 parent 55b5db9 commit f30ef1d
Showing 16 changed files with 519 additions and 42 deletions.
12 changes: 9 additions & 3 deletions src/libs/LangChain.Core/Base/BaseChain.cs
@@ -63,8 +63,13 @@ public abstract class BaseChain(IChainInputs fields) : IChain
/// Run the chain using a simple input/output.
/// </summary>
/// <param name="input">The dict input to use to execute the chain.</param>
/// <param name="callbacks">
/// Callbacks to use for this chain run. These will be called in
/// addition to callbacks passed to the chain during construction, but only
/// these runtime callbacks will propagate to calls to other objects.
/// </param>
/// <returns>A text value containing the result of the chain.</returns>
public virtual async Task<string> Run(Dictionary<string, object> input)
public virtual async Task<string> Run(Dictionary<string, object> input, ICallbacks? callbacks = null)
{
var keysLengthDifferent = InputKeys.Length != input.Count;

@@ -73,7 +78,7 @@ public virtual async Task<string> Run(Dictionary<string, object> input)
throw new ArgumentException($"Chain {ChainType()} expects {InputKeys.Length} input values but received {input.Count}");
}

var returnValues = await CallAsync(new ChainValues(input));
var returnValues = await CallAsync(new ChainValues(input), callbacks);

var returnValue = returnValues.Value.FirstOrDefault(kv => kv.Key == OutputKeys[0]).Value;

@@ -88,7 +93,8 @@ public virtual async Task<string> Run(Dictionary<string, object> input)
/// <param name="tags"></param>
/// <param name="metadata"></param>
/// <returns></returns>
public async Task<IChainValues> CallAsync(IChainValues values,
public async Task<IChainValues> CallAsync(
IChainValues values,
ICallbacks? callbacks = null,
IReadOnlyList<string>? tags = null,
IReadOnlyDictionary<string, object>? metadata = null)
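To illustrate the new overload, here is a minimal sketch of passing runtime callbacks to a chain run; `chain` and `myHandler` are hypothetical placeholders, not part of this commit:

// Sketch only: `chain` is any concrete chain and `myHandler` a custom BaseCallbackHandler.
// Runtime callbacks are supplied per call, on top of any callbacks configured at construction.
var input = new Dictionary<string, object> { ["question"] = "What is LangChain?" };
var answer = await chain.Run(input, callbacks: new HandlersCallbacks(new List<BaseCallbackHandler> { myHandler }));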
5 changes: 5 additions & 0 deletions src/libs/LangChain.Core/Callback/ICallbacks.cs
@@ -4,6 +4,11 @@ namespace LangChain.Callback;

public interface ICallbacks;

public static class ManagerCallbacksExtensions
{
public static ManagerCallbacks ToCallbacks(this ParentRunManager source) => new ManagerCallbacks(source.GetChild());
}

public record ManagerCallbacks(CallbackManager Value) : ICallbacks;

public record HandlersCallbacks(List<BaseCallbackHandler> Value) : ICallbacks;
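For reference, a minimal sketch of how ToCallbacks is meant to be used from inside a chain's CallAsync; this mirrors the call site in ConversationalRetrievalChain later in this commit, with `runManager` standing for the current run manager:

// Hypothetical call site: wrap the child callback manager of the current run
// so it can be forwarded to a sub-chain as an ICallbacks value.
ManagerCallbacks callbacks = runManager.ToCallbacks(); // same as new ManagerCallbacks(runManager.GetChild())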
5 changes: 4 additions & 1 deletion src/libs/LangChain.Core/Chains/Base/IChain.cs
@@ -7,7 +7,10 @@ public interface IChain
{
string[] InputKeys { get; }
string[] OutputKeys { get; }


Task<string?> Run(string input);
Task<string> Run(Dictionary<string, object> input, ICallbacks? callbacks = null);

Task<IChainValues> CallAsync(IChainValues values,
ICallbacks? callbacks = null,
IReadOnlyList<string>? tags = null,
src/libs/LangChain.Core/Chains/CombineDocuments/StuffDocumentsChain.cs
@@ -18,18 +18,18 @@ namespace LangChain.Chains.CombineDocuments;
/// </summary>
public class StuffDocumentsChain : BaseCombineDocumentsChain
{
private readonly ILlmChain _llmChain;
public readonly ILlmChain LlmChain;
private readonly BasePromptTemplate _documentPrompt;
private readonly string _documentVariableName;
private readonly string _documentSeparator = "\n\n";
private readonly string _documentSeparator;

public StuffDocumentsChain(StuffDocumentsChainInput input) : base(input)
{
_llmChain = input.LlmChain;
LlmChain = input.LlmChain;
_documentPrompt = input.DocumentPrompt;
_documentSeparator = input.DocumentSeparator;

var llmChainVariables = _llmChain.Prompt.InputVariables;
var llmChainVariables = LlmChain.Prompt.InputVariables;

if (input.DocumentVariableName == null)
{
@@ -50,7 +50,7 @@ public StuffDocumentsChain(StuffDocumentsChainInput input) : base(input)
}

public override string[] InputKeys =>
base.InputKeys.Concat(_llmChain.InputKeys.Where(k => k != _documentVariableName)).ToArray();
base.InputKeys.Concat(LlmChain.InputKeys.Where(k => k != _documentVariableName)).ToArray();

public override string ChainType() => "stuff_documents_chain";

@@ -59,17 +59,17 @@ public StuffDocumentsChain(StuffDocumentsChainInput input) : base(input)
IReadOnlyDictionary<string, object> otherKeys)
{
var inputs = await GetInputs(docs, otherKeys);
var predict = await _llmChain.Predict(new ChainValues(inputs.Value));
var predict = await LlmChain.Predict(new ChainValues(inputs.Value));

return (predict.ToString() ?? string.Empty, new Dictionary<string, object>());
}

public override async Task<int?> PromptLength(IReadOnlyList<Document> docs, IReadOnlyDictionary<string, object> otherKeys)
{
if (_llmChain.Llm is ISupportsCountTokens supportsCountTokens)
if (LlmChain.Llm is ISupportsCountTokens supportsCountTokens)
{
var inputs = await GetInputs(docs, otherKeys);
var prompt = await _llmChain.Prompt.FormatPromptValue(inputs);
var prompt = await LlmChain.Prompt.FormatPromptValue(inputs);

return supportsCountTokens.CountTokens(prompt.ToString());
}
@@ -84,7 +84,7 @@ private async Task<InputValues> GetInputs(IReadOnlyList<Document> docs, IReadOnl
var inputs = new Dictionary<string, object>();
foreach (var kv in otherKeys)
{
if (_llmChain.Prompt.InputVariables.Contains(kv.Key))
if (LlmChain.Prompt.InputVariables.Contains(kv.Key))
{
inputs[kv.Key] = kv.Value;
}
113 changes: 113 additions & 0 deletions src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChain.cs
@@ -0,0 +1,113 @@
using LangChain.Abstractions.Schema;
using LangChain.Base;
using LangChain.Callback;
using LangChain.Common;
using LangChain.Docstore;
using LangChain.Providers;
using LangChain.Schema;

namespace LangChain.Chains.ConversationalRetrieval;

/// <summary>
/// Chain for chatting with an index.
/// </summary>
public abstract class BaseConversationalRetrievalChain(BaseConversationalRetrievalChainInput fields) : BaseChain(fields)
{
/// <summary> Chain input fields </summary>
private readonly BaseConversationalRetrievalChainInput _fields = fields;

public override string[] InputKeys => new[] { "question", "chat_history" };

public override string[] OutputKeys
{
get
{
var outputKeys = new List<string> { _fields.OutputKey };
if (_fields.ReturnSourceDocuments)
{
outputKeys.Add("source_documents");
}

if (_fields.ReturnGeneratedQuestion)
{
outputKeys.Add("generated_question");
}

return outputKeys.ToArray();
}
}

protected override async Task<IChainValues> CallAsync(IChainValues values, CallbackManagerForChainRun? runManager)
{
runManager ??= BaseRunManager.GetNoopManager<CallbackManagerForChainRun>();

var question = values.Value["question"].ToString();

var getChatHistory = _fields.GetChatHistory;
var chatHistoryStr = getChatHistory(values.Value["chat_history"] as List<Message>);

string? newQuestion;
if (chatHistoryStr != null)
{
var callbacks = runManager.GetChild();
newQuestion = await _fields.QuestionGenerator.Run(
new Dictionary<string, object>
{
["question"] = question,
["chat_history"] = chatHistoryStr
},
callbacks: new ManagerCallbacks(callbacks));
}
else
{
newQuestion = question;
}

var docs = await GetDocsAsync(newQuestion, values.Value);
var newInputs = new Dictionary<string, object>
{
["chat_history"] = chatHistoryStr,
["input_documents"] = docs
};

if (_fields.RephraseQuestion)
{
newInputs["question"] = newQuestion;
}

newInputs.TryAddKeyValues(values.Value);

var answer = await _fields.CombineDocsChain.Run(
input: newInputs,
callbacks: new ManagerCallbacks(runManager.GetChild()));

var output = new Dictionary<string, object>
{
[_fields.OutputKey] = answer
};

if (_fields.ReturnSourceDocuments)
{
output["source_documents"] = docs;
}

if (_fields.ReturnGeneratedQuestion)
{
output["generated_question"] = newQuestion;
}

return new ChainValues(output);
}

/// <summary>
/// Get docs.
/// </summary>
/// <param name="question"></param>
/// <param name="inputs"></param>
/// <param name="runManager"></param>
/// <returns></returns>
protected abstract Task<List<Document>> GetDocsAsync(
string question,
Dictionary<string, object> inputs,
CallbackManagerForChainRun? runManager = null);
}
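To make the question-generation step concrete: given a chat history rendered as "Human: What is LangChain?\nAssistant: A framework for building LLM applications." and the new question "Does it support .NET?", the QuestionGenerator chain receives both values ("question" and "chat_history") and might produce a standalone question such as "Does LangChain support .NET?". That standalone question is what GetDocsAsync retrieves against, while RephraseQuestion controls whether it or the original question is forwarded to CombineDocsChain.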
63 changes: 63 additions & 0 deletions src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChainInput.cs
@@ -0,0 +1,63 @@
using LangChain.Base;
using LangChain.Chains.CombineDocuments;
using LangChain.Chains.LLM;
using LangChain.Providers;

namespace LangChain.Chains.ConversationalRetrieval;

public class BaseConversationalRetrievalChainInput(
BaseCombineDocumentsChain combineDocsChain,
ILlmChain questionGenerator)
: ChainInputs
{
/// <summary>
/// The chain used to combine any retrieved documents.
/// </summary>
public BaseCombineDocumentsChain CombineDocsChain { get; } = combineDocsChain;

/// <summary>
/// The chain used to generate a new question for the sake of retrieval.
///
/// This chain will take in the current question (with variable `question`)
/// and any chat history (with variable `chat_history`) and will produce
/// a new standalone question to be used later on.
/// </summary>
public ILlmChain QuestionGenerator { get; } = questionGenerator;

/// <summary>
/// The output key to return the final answer of this chain in.
/// </summary>
public string OutputKey { get; set; } = "answer";

/// <summary>
/// Whether or not to pass the new generated question to the combine_docs_chain.
///
/// If True, will pass the new generated question along.
/// If False, will only use the new generated question for retrieval and pass the
/// original question along to the <see cref="CombineDocsChain"/>.
/// </summary>
public bool RephraseQuestion { get; set; } = true;

/// <summary>
/// Return the retrieved source documents as part of the final result.
/// </summary>
public bool ReturnSourceDocuments { get; set; }

/// <summary>
/// Return the generated question as part of the final result.
/// </summary>
public bool ReturnGeneratedQuestion { get; set; }

/// <summary>
/// An optional function to get a string of the chat history.
/// If None is provided, will use a default.
/// </summary>
public Func<IReadOnlyList<Message>, string?> GetChatHistory { get; set; } =
ChatTurnTypeHelper.GetChatHistory;

/// <summary>
/// If specified, the chain will return a fixed response if no docs
/// are found for the question.
/// </summary>
public string? ResponseIfNoDocsFound { get; set; }
}
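A minimal configuration sketch for this input class; `combineDocsChain` and `questionGeneratorChain` are hypothetical, pre-built chains, and only the option wiring below comes from this commit:

// Sketch only: the two constructor arguments are placeholders for chains built elsewhere.
var input = new BaseConversationalRetrievalChainInput(combineDocsChain, questionGeneratorChain)
{
    OutputKey = "answer",            // key under which the final answer is returned
    RephraseQuestion = true,         // forward the generated standalone question to the combine-docs chain
    ReturnSourceDocuments = true,    // include retrieved documents in the result
    ReturnGeneratedQuestion = true   // include the standalone question in the result
};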
26 changes: 26 additions & 0 deletions src/libs/LangChain.Core/Chains/ConversationalRetrieval/ChatTurnTypeHelper.cs
@@ -0,0 +1,26 @@
using System.Text;
using LangChain.Providers;

namespace LangChain.Chains.ConversationalRetrieval;

public static class ChatTurnTypeHelper
{
public static string GetChatHistory(IReadOnlyList<Message> chatHistory)
{
var buffer = new StringBuilder();

foreach (var message in chatHistory)
{
var rolePrefix = message.Role switch
{
MessageRole.Human => "Human: ",
MessageRole.Ai => "Assistant: ",
_ => $"{message.Role}: "
};

buffer.AppendLine($"{rolePrefix}{message.Content}");

Check warning on line 21 (GitHub Actions / Build, test and publish): The behavior of 'StringBuilder.AppendLine(ref StringBuilder.AppendInterpolatedStringHandler)' could vary based on the current user's locale settings. Replace this call in 'ChatTurnTypeHelper.GetChatHistory(IReadOnlyList<Message>)' with a call to 'StringBuilder.AppendLine(IFormatProvider, ref StringBuilder.AppendInterpolatedStringHandler)'. (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1305)
}

return buffer.ToString();
}
}
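For illustration, a two-turn history consisting of a human message "Hi" followed by an AI message "Hello! How can I help?" is rendered as:

Human: Hi
Assistant: Hello! How can I help?

Each turn ends with a line break, and roles other than Human and Ai fall back to the enum name as the prefix.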
68 changes: 68 additions & 0 deletions src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChain.cs
@@ -0,0 +1,68 @@
using LangChain.Callback;
using LangChain.Chains.CombineDocuments;
using LangChain.Docstore;
using LangChain.Providers;

namespace LangChain.Chains.ConversationalRetrieval;

/// <summary>
/// Chain for having a conversation based on retrieved documents.
///
/// This chain takes in chat history (a list of messages) and new questions,
/// and then returns an answer to that question.
/// The algorithm for this chain consists of three parts:
///
/// 1. Use the chat history and the new question to create a "standalone question".
/// This is done so that this question can be passed into the retrieval step to fetch
/// relevant documents. If only the new question was passed in, then relevant context
/// may be lacking. If the whole conversation was passed into retrieval, there may
/// be unnecessary information there that would distract from retrieval.
///
/// 2. This new question is passed to the retriever and relevant documents are
/// returned.
///
/// 3. The retrieved documents are passed to an LLM along with either the new question
/// (default behavior) or the original question and chat history to generate a final
/// response.
/// </summary>
public class ConversationalRetrievalChain(ConversationalRetrievalChainInput fields)
: BaseConversationalRetrievalChain(fields)
{
private readonly ConversationalRetrievalChainInput _fields = fields;

public override string ChainType() => "conversational_retrieval";

protected override async Task<List<Document>> GetDocsAsync(
string question,
Dictionary<string, object> inputs,
CallbackManagerForChainRun? runManager = null)
{
var docs = await _fields.Retriever.GetRelevantDocumentsAsync(
question,
callbacks: runManager?.ToCallbacks());

return ReduceTokensBelowLimit(docs);
}

public List<Document> ReduceTokensBelowLimit(IEnumerable<Document> docs)
{
var docsList = docs.ToList();
var numDocs = docsList.Count;

if (_fields.MaxTokensLimit != null &&
_fields.CombineDocsChain is StuffDocumentsChain stuffDocumentsChain &&
stuffDocumentsChain.LlmChain.Llm is ISupportsCountTokens counter)
{
var tokens = docsList.Select(doc => counter.CountTokens(doc.PageContent)).ToArray();
var tokenCount = tokens.Sum();

while (tokenCount > _fields.MaxTokensLimit)
{
numDocs -= 1;
tokenCount -= tokens[numDocs];
}
}

return docsList.Take(numDocs).ToList();
}
}
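A worked example of the trimming loop above: with MaxTokensLimit = 100 and four documents measuring 40, 35, 30 and 20 tokens (130 in total), the loop drops documents from the end of the list: first the 20-token document (leaving 110 tokens), then the 30-token one (leaving 80), and the method returns only the first two documents.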