-
-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: ConversationalRetrievalChain (#85)
Co-authored-by: Evgenii Khoroshev <[email protected]> Closes #13
- Loading branch information
1 parent
55b5db9
commit f30ef1d
Showing
16 changed files
with
519 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
113 changes: 113 additions & 0 deletions
113
src/libs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChain.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
using LangChain.Abstractions.Schema; | ||
using LangChain.Base; | ||
using LangChain.Callback; | ||
using LangChain.Common; | ||
using LangChain.Docstore; | ||
using LangChain.Providers; | ||
using LangChain.Schema; | ||
|
||
namespace LangChain.Chains.ConversationalRetrieval; | ||
|
||
/// <summary> | ||
/// Chain for chatting with an index. | ||
/// </summary> | ||
public abstract class BaseConversationalRetrievalChain(BaseConversationalRetrievalChainInput fields) : BaseChain(fields) | ||
{ | ||
/// <summary> Chain input fields </summary> | ||
private readonly BaseConversationalRetrievalChainInput _fields = fields; | ||
|
||
public override string[] InputKeys => new[] { "question", "chat_history" }; | ||
|
||
public override string[] OutputKeys | ||
{ | ||
get | ||
{ | ||
var outputKeys = new List<string> { _fields.OutputKey }; | ||
if (_fields.ReturnSourceDocuments) | ||
{ | ||
outputKeys.Add("source_documents"); | ||
} | ||
|
||
if (_fields.ReturnGeneratedQuestion) | ||
{ | ||
outputKeys.Add("generated_question"); | ||
} | ||
|
||
return outputKeys.ToArray(); | ||
} | ||
} | ||
|
||
protected override async Task<IChainValues> CallAsync(IChainValues values, CallbackManagerForChainRun? runManager) | ||
{ | ||
runManager ??= BaseRunManager.GetNoopManager<CallbackManagerForChainRun>(); | ||
|
||
var question = values.Value["question"].ToString(); | ||
|
||
var getChatHistory = _fields.GetChatHistory; | ||
var chatHistoryStr = getChatHistory(values.Value["chat_history"] as List<Message>); | ||
|
||
string? newQuestion; | ||
if (chatHistoryStr != null) | ||
{ | ||
var callbacks = runManager.GetChild(); | ||
newQuestion = await _fields.QuestionGenerator.Run( | ||
new Dictionary<string, object> | ||
{ | ||
["question"] = question, | ||
["chat_history"] = chatHistoryStr | ||
}, | ||
callbacks: new ManagerCallbacks(callbacks)); | ||
} | ||
else | ||
{ | ||
newQuestion = question; | ||
} | ||
|
||
var docs = await GetDocsAsync(newQuestion, values.Value); | ||
var newInputs = new Dictionary<string, object> | ||
{ | ||
["chat_history"] = chatHistoryStr, | ||
["input_documents"] = docs | ||
}; | ||
|
||
if (_fields.RephraseQuestion) | ||
{ | ||
newInputs["question"] = newQuestion; | ||
} | ||
|
||
newInputs.TryAddKeyValues(values.Value); | ||
|
||
var answer = await _fields.CombineDocsChain.Run( | ||
input: newInputs, | ||
callbacks: new ManagerCallbacks(runManager.GetChild())); | ||
|
||
var output = new Dictionary<string, object> | ||
{ | ||
[_fields.OutputKey] = answer | ||
}; | ||
|
||
if (_fields.ReturnSourceDocuments) | ||
{ | ||
output["source_documents"] = docs; | ||
} | ||
|
||
if (_fields.ReturnGeneratedQuestion) | ||
{ | ||
output["generated_question"] = newQuestion; | ||
} | ||
|
||
return new ChainValues(output); | ||
} | ||
|
||
/// <summary> | ||
/// Get docs. | ||
/// </summary> | ||
/// <param name="question"></param> | ||
/// <param name="inputs"></param> | ||
/// <param name="runManager"></param> | ||
/// <returns></returns> | ||
protected abstract Task<List<Document>> GetDocsAsync( | ||
string question, | ||
Dictionary<string, object> inputs, | ||
CallbackManagerForChainRun? runManager = null); | ||
} |
63 changes: 63 additions & 0 deletions
63
...bs/LangChain.Core/Chains/ConversationalRetrieval/BaseConversationalRetrievalChainInput.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
using LangChain.Base; | ||
using LangChain.Chains.CombineDocuments; | ||
using LangChain.Chains.LLM; | ||
using LangChain.Providers; | ||
|
||
namespace LangChain.Chains.ConversationalRetrieval; | ||
|
||
public class BaseConversationalRetrievalChainInput( | ||
BaseCombineDocumentsChain combineDocsChain, | ||
ILlmChain questionGenerator) | ||
: ChainInputs | ||
{ | ||
/// <summary> | ||
/// The chain used to combine any retrieved documents. | ||
/// </summary> | ||
public BaseCombineDocumentsChain CombineDocsChain { get; } = combineDocsChain; | ||
|
||
/// <summary> | ||
/// The chain used to generate a new question for the sake of retrieval. | ||
/// | ||
/// This chain will take in the current question (with variable `question`) | ||
/// and any chat history (with variable `chat_history`) and will produce | ||
/// a new standalone question to be used later on. | ||
/// </summary> | ||
public ILlmChain QuestionGenerator { get; } = questionGenerator; | ||
|
||
/// <summary> | ||
/// The output key to return the final answer of this chain in. | ||
/// </summary> | ||
public string OutputKey { get; set; } = "answer"; | ||
|
||
/// <summary> | ||
/// Whether or not to pass the new generated question to the combine_docs_chain. | ||
/// | ||
/// If True, will pass the new generated question along. | ||
/// If False, will only use the new generated question for retrieval and pass the | ||
/// original question along to the <see cref="CombineDocsChain"/>. | ||
/// </summary> | ||
public bool RephraseQuestion { get; set; } = true; | ||
|
||
/// <summary> | ||
/// Return the retrieved source documents as part of the final result. | ||
/// </summary> | ||
public bool ReturnSourceDocuments { get; set; } | ||
|
||
/// <summary> | ||
/// Return the generated question as part of the final result. | ||
/// </summary> | ||
public bool ReturnGeneratedQuestion { get; set; } | ||
|
||
/// <summary> | ||
/// An optional function to get a string of the chat history. | ||
/// If None is provided, will use a default. | ||
/// </summary> | ||
public Func<IReadOnlyList<Message>, string?> GetChatHistory { get; set; } = | ||
ChatTurnTypeHelper.GetChatHistory; | ||
|
||
/// <summary> | ||
/// If specified, the chain will return a fixed response if no docs | ||
/// are found for the question. | ||
/// </summary> | ||
public string? ResponseIfNoDocsFound { get; set; } | ||
} |
26 changes: 26 additions & 0 deletions
26
src/libs/LangChain.Core/Chains/ConversationalRetrieval/ChatTurnTypeHelper.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
using System.Text; | ||
using LangChain.Providers; | ||
|
||
namespace LangChain.Chains.ConversationalRetrieval; | ||
|
||
public static class ChatTurnTypeHelper | ||
{ | ||
public static string GetChatHistory(IReadOnlyList<Message> chatHistory) | ||
{ | ||
var buffer = new StringBuilder(); | ||
|
||
foreach (var message in chatHistory) | ||
{ | ||
var rolePrefix = message.Role switch | ||
{ | ||
MessageRole.Human => "Human: ", | ||
MessageRole.Ai => "Assistant: ", | ||
_ => $"{message.Role}: " | ||
}; | ||
|
||
buffer.AppendLine($"{rolePrefix}{message.Content}"); | ||
Check warning on line 21 in src/libs/LangChain.Core/Chains/ConversationalRetrieval/ChatTurnTypeHelper.cs GitHub Actions / Build, test and publish / Build, test and publish
Check warning on line 21 in src/libs/LangChain.Core/Chains/ConversationalRetrieval/ChatTurnTypeHelper.cs GitHub Actions / Build, test and publish / Build, test and publish
|
||
} | ||
|
||
return buffer.ToString(); | ||
} | ||
} |
68 changes: 68 additions & 0 deletions
68
src/libs/LangChain.Core/Chains/ConversationalRetrieval/ConversationalRetrievalChain.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
using LangChain.Callback; | ||
using LangChain.Chains.CombineDocuments; | ||
using LangChain.Docstore; | ||
using LangChain.Providers; | ||
|
||
namespace LangChain.Chains.ConversationalRetrieval; | ||
|
||
/// <summary> | ||
/// Chain for having a conversation based on retrieved documents. | ||
/// | ||
/// This chain takes in chat history (a list of messages) and new questions, | ||
/// and then returns an answer to that question. | ||
/// The algorithm for this chain consists of three parts: | ||
/// | ||
/// 1. Use the chat history and the new question to create a "standalone question". | ||
/// This is done so that this question can be passed into the retrieval step to fetch | ||
/// relevant documents. If only the new question was passed in, then relevant context | ||
/// may be lacking. If the whole conversation was passed into retrieval, there may | ||
/// be unnecessary information there that would distract from retrieval. | ||
/// | ||
/// 2. This new question is passed to the retriever and relevant documents are | ||
/// returned. | ||
/// | ||
/// 3. The retrieved documents are passed to an LLM along with either the new question | ||
/// (default behavior) or the original question and chat history to generate a final | ||
/// response. | ||
/// </summary> | ||
public class ConversationalRetrievalChain(ConversationalRetrievalChainInput fields) | ||
: BaseConversationalRetrievalChain(fields) | ||
{ | ||
private readonly ConversationalRetrievalChainInput _fields = fields; | ||
|
||
public override string ChainType() => "conversational_retrieval"; | ||
|
||
protected override async Task<List<Document>> GetDocsAsync( | ||
string question, | ||
Dictionary<string, object> inputs, | ||
CallbackManagerForChainRun? runManager = null) | ||
{ | ||
var docs = await _fields.Retriever.GetRelevantDocumentsAsync( | ||
question, | ||
callbacks: runManager?.ToCallbacks()); | ||
|
||
return ReduceTokensBelowLimit(docs); | ||
} | ||
|
||
public List<Document> ReduceTokensBelowLimit(IEnumerable<Document> docs) | ||
{ | ||
var docsList = docs.ToList(); | ||
var numDocs = docsList.Count; | ||
|
||
if (_fields.MaxTokensLimit != null && | ||
_fields.CombineDocsChain is StuffDocumentsChain stuffDocumentsChain && | ||
stuffDocumentsChain.LlmChain.Llm is ISupportsCountTokens counter) | ||
{ | ||
var tokens = docsList.Select(doc => counter.CountTokens(doc.PageContent)).ToArray(); | ||
var tokenCount = tokens.Sum(); | ||
|
||
while (tokenCount > _fields.MaxTokensLimit) | ||
{ | ||
numDocs -= 1; | ||
tokenCount -= tokens[numDocs]; | ||
} | ||
} | ||
|
||
return docsList.Take(numDocs).ToList(); | ||
} | ||
} |
Oops, something went wrong.