diff --git a/src/libs/LangChain.Core/Chains/Chain.cs b/src/libs/LangChain.Core/Chains/Chain.cs index 921314f6..64aa00c1 100644 --- a/src/libs/LangChain.Core/Chains/Chain.cs +++ b/src/libs/LangChain.Core/Chains/Chain.cs @@ -1,6 +1,8 @@ using LangChain.Abstractions.Chains.Base; using LangChain.Chains.HelperChains; using LangChain.Chains.StackableChains; +using LangChain.Chains.StackableChains.Agents; +using LangChain.Chains.StackableChains.ReAct; using LangChain.Indexes; using LangChain.Memory; using LangChain.Providers; @@ -130,4 +132,23 @@ public static STTChain STT(ISpeechToTextModel model, { return new STTChain(model, settings, inputKey, outputKey); } + + public static ReActAgentExecutorChain ReActAgentExecutor(IChatModel model, string reActPrompt = null, + int maxActions = 5, string inputKey = "input", + string outputKey = "final_answer") + { + return new ReActAgentExecutorChain(model, reActPrompt, maxActions, inputKey, outputKey); + } + + public static ReActParserChain ReActParser( + string inputKey = "text", string outputKey = "answer") + { + return new ReActParserChain(inputKey, outputKey); + } + + public static GroupChat GroupChat( + IList agents, string? stopPhrase = null, int messagesLimit = 10, string inputKey = "input", string outputKey = "output") + { + return new GroupChat(agents, stopPhrase, messagesLimit, inputKey, outputKey); + } } \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/Agents/AgentExecutorChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/Agents/AgentExecutorChain.cs new file mode 100644 index 00000000..8ef13afa --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/Agents/AgentExecutorChain.cs @@ -0,0 +1,50 @@ +using LangChain.Abstractions.Chains.Base; +using LangChain.Abstractions.Schema; +using LangChain.Chains.HelperChains; +using LangChain.Memory; +using LangChain.Providers; +using LangChain.Schema; + +namespace LangChain.Chains.StackableChains.Agents; + +public class AgentExecutorChain: BaseStackableChain +{ + public string HistoryKey { get; } + private readonly BaseStackableChain _originalChain; + + private BaseStackableChain _chainWithHistory; + + public string Name { get; private set; } + + /// + /// Messages of this agent will not be added to the history + /// + public bool IsObserver { get; set; } = false; + + public AgentExecutorChain(BaseStackableChain originalChain, string name, string historyKey="history", + string outputKey = "final_answer") + { + Name = name; + HistoryKey = historyKey; + _originalChain = originalChain; + + InputKeys = new[] { historyKey}; + OutputKeys = new[] { outputKey }; + + SetHistory(""); + } + + public void SetHistory(string history) + { + + _chainWithHistory = + Chain.Set(history, HistoryKey) + |_originalChain; + } + + protected override async Task InternalCall(IChainValues values) + { + var res=await _chainWithHistory.CallAsync(values); + return res; + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs b/src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs new file mode 100644 index 00000000..4a592b4c --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs @@ -0,0 +1,93 @@ +using LangChain.Abstractions.Schema; +using LangChain.Chains.HelperChains; +using LangChain.Chains.HelperChains.Exceptions; +using LangChain.Memory; +using LangChain.Providers; + +namespace LangChain.Chains.StackableChains.Agents; + +public class GroupChat:BaseStackableChain +{ + private readonly IList _agents; + + private readonly string _stopPhrase; + private readonly int _messagesLimit; + private readonly string _inputKey; + private readonly string _outputKey; + + + int _currentAgentId=0; + private readonly ConversationBufferMemory _conversationBufferMemory; + + + public bool ThrowOnLimit { get; set; } = false; + public GroupChat(IList agents, string? stopPhrase=null, int messagesLimit=10, string inputKey="input", string outputKey="output") + { + _agents = agents; + + _stopPhrase = stopPhrase; + _messagesLimit = messagesLimit; + _inputKey = inputKey; + _outputKey = outputKey; + _conversationBufferMemory = new ConversationBufferMemory(new ChatMessageHistory()) { AiPrefix = "", HumanPrefix = "", SystemPrefix = "", SaveHumanMessages = false }; + InputKeys = new[] { inputKey }; + OutputKeys = new[] { outputKey }; + + } + + public IReadOnlyList GetHistory() + { + return _conversationBufferMemory.ChatHistory.Messages; + } + + + protected override async Task InternalCall(IChainValues values) + { + + await _conversationBufferMemory.Clear().ConfigureAwait(false); + foreach (var agent in _agents) + { + agent.SetHistory(""); + } + var firstAgent = _agents[0]; + var firstAgentMessage = (string)values.Value[_inputKey]; + await _conversationBufferMemory.ChatHistory.AddMessage(new Message($"{firstAgent.Name}: {firstAgentMessage}", + MessageRole.System)).ConfigureAwait(false); + int messagesCount = 1; + while (messagesCount<_messagesLimit) + { + var agent = GetNextAgent(); + agent.SetHistory(_conversationBufferMemory.BufferAsString+"\n"+$"{agent.Name}:"); + var res = await agent.CallAsync(values).ConfigureAwait(false); + var message = (string)res.Value[agent.OutputKeys[0]]; + if (message.Contains(_stopPhrase)) + { + break; + } + + if (!agent.IsObserver) + { + await _conversationBufferMemory.ChatHistory.AddMessage(new Message($"{agent.Name}: {message}", + MessageRole.System)).ConfigureAwait(false); + } + } + + var result = _conversationBufferMemory.ChatHistory.Messages.Last(); + messagesCount = _conversationBufferMemory.ChatHistory.Messages.Count; + if (ThrowOnLimit && messagesCount >= _messagesLimit) + { + throw new InvalidOperationException($"Message limit reached:{_messagesLimit}"); + } + values.Value.Add(_outputKey, result); + return values; + + } + + AgentExecutorChain GetNextAgent() + { + _currentAgentId++; + if (_currentAgentId >= _agents.Count) + _currentAgentId = 0; + return _agents[_currentAgentId]; + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/Agents/PromptedAgent.cs b/src/libs/LangChain.Core/Chains/StackableChains/Agents/PromptedAgent.cs new file mode 100644 index 00000000..7a8769d6 --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/Agents/PromptedAgent.cs @@ -0,0 +1,24 @@ +using LangChain.Chains.HelperChains; +using LangChain.Providers; + +namespace LangChain.Chains.StackableChains.Agents; + +public class PromptedAgent: AgentExecutorChain +{ + public const string Template = + @"{system} +{history}"; + + private static BaseStackableChain MakeChain(string name, string system, IChatModel model, string outputKey) + { + return Chain.Set(system, "system") + | Chain.Template(Template) + | Chain.LLM(model,outputKey: outputKey); + } + + + public PromptedAgent(string name, string prompt, IChatModel model, string outputKey = "final_answer") : base(MakeChain(name,prompt,model, outputKey),name, "history", outputKey) + { + + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs new file mode 100644 index 00000000..3ed4e1c0 --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs @@ -0,0 +1,134 @@ +using LangChain.Abstractions.Chains.Base; +using LangChain.Abstractions.Schema; +using LangChain.Chains.HelperChains; +using LangChain.Chains.StackableChains.ReAct; +using LangChain.Memory; +using LangChain.Providers; +using LangChain.Schema; +using System.Reflection; +using static LangChain.Chains.Chain; + +namespace LangChain.Chains.StackableChains.Agents; + +public class ReActAgentExecutorChain : BaseStackableChain +{ + public const string DefaultPrompt = + @"Answer the following questions as best you can. You have access to the following tools: + +{tools} + +Use the following format: + +Question: the input question you must answer +Thought: you should always think about what to do +Action: the action to take, should be one of [{tool_names}] +Action Input: the input to the action +Observation: the result of the action +(this Thought/Action/Action Input/Observation can repeat multiple times) +Thought: I now know the final answer +Final Answer: the final answer to the original input question +Always add [END] after final answer + +Begin! + +Question: {input} +Thought:{history}"; + + private IChain? _chain = null; + private bool _useCache; + Dictionary _tools = new(); + private readonly IChatModel _model; + private readonly string _reActPrompt; + private readonly int _maxActions; + private readonly ConversationBufferMemory _conversationBufferMemory; + + + public ReActAgentExecutorChain(IChatModel model, string reActPrompt = null, int maxActions = 5, string inputKey = "answer", + string outputKey = "final_answer") + { + reActPrompt ??= DefaultPrompt; + _model = model; + _reActPrompt = reActPrompt; + _maxActions = maxActions; + + InputKeys = new[] { inputKey }; + OutputKeys = new[] { outputKey }; + + _conversationBufferMemory = new ConversationBufferMemory(new ChatMessageHistory()) { AiPrefix = "", HumanPrefix = "", SystemPrefix = "", SaveHumanMessages = false }; + + } + + private string? _userInput = null; + private const string ReActAnswer = "answer"; + private void InitializeChain() + { + string tool_names = string.Join(",", _tools.Select(x => x.Key)); + string tools = string.Join("\n", _tools.Select(x => $"{x.Value.Name}, {x.Value.Description}")); + + var chain = + Set(() => _userInput, "input") + | Set(tools, "tools") + | Set(tool_names, "tool_names") + | Set(() => _conversationBufferMemory.BufferAsString, "history") + | Template(_reActPrompt) + | Chain.LLM(_model).UseCache(_useCache) + | UpdateMemory(_conversationBufferMemory, requestKey: "input", responseKey: "text") + | ReActParser(inputKey: "text", outputKey: ReActAnswer); + + _chain = chain; + } + + protected override async Task InternalCall(IChainValues values) + { + + var input = (string)values.Value[InputKeys[0]]; + var values_chain = new ChainValues(); + + _userInput = input; + + + if (_chain == null) + { + InitializeChain(); + } + + for (int i = 0; i < _maxActions; i++) + { + var res = await _chain.CallAsync(values_chain); + if (res.Value[ReActAnswer] is AgentAction) + { + var action = (AgentAction)res.Value[ReActAnswer]; + var tool = _tools[action.Action]; + var tool_res = tool.ToolCall(action.ActionInput); + await _conversationBufferMemory.ChatHistory.AddMessage(new Message("Observation: " + tool_res, MessageRole.System)) + .ConfigureAwait(false); + await _conversationBufferMemory.ChatHistory.AddMessage(new Message("Thought:", MessageRole.System)) + .ConfigureAwait(false); + continue; + } + else if (res.Value[ReActAnswer] is AgentFinish) + { + var finish = (AgentFinish)res.Value[ReActAnswer]; + values.Value.Add(OutputKeys[0], finish.Output); + return values; + } + } + + + + return values; + } + + public ReActAgentExecutorChain UseCache(bool enabled = true) + { + _useCache = enabled; + return this; + } + + + public ReActAgentExecutorChain UseTool(ReActAgentTool tool) + { + _tools.Add(tool.Name, tool); + return this; + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs index 968f337b..483970e7 100644 --- a/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs +++ b/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs @@ -140,6 +140,12 @@ public async Task Run() return res.Value[resultKey].ToString(); } + public async Task Run(string resultKey) + { + var res = await CallAsync(new ChainValues()).ConfigureAwait(false); + return (T)res.Value[resultKey]; + } + /// /// /// diff --git a/src/libs/LangChain.Core/Chains/StackableChains/ReAct/ReActAgentTool.cs b/src/libs/LangChain.Core/Chains/StackableChains/ReAct/ReActAgentTool.cs new file mode 100644 index 00000000..971f1163 --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/ReAct/ReActAgentTool.cs @@ -0,0 +1,17 @@ +namespace LangChain.Chains.StackableChains.ReAct; + +public class ReActAgentTool +{ + public ReActAgentTool(string name, string description, Func func) + { + Name = name; + Description = description; + ToolCall = func; + } + + public string Name { get; set; } + public string Description { get; set; } + + public Func ToolCall { get; set; } + +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/ReAct/ReActParserChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/ReAct/ReActParserChain.cs new file mode 100644 index 00000000..5a8f55ed --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/ReAct/ReActParserChain.cs @@ -0,0 +1,91 @@ +using System.Text.RegularExpressions; +using LangChain.Abstractions.Schema; +using LangChain.Chains.HelperChains; +using LangChain.Schema; + +namespace LangChain.Chains.StackableChains.ReAct; + +public class ReActParserChain : BaseStackableChain +{ + + + + public ReActParserChain(string inputKey="text",string outputText="answer") + { + InputKeys = new[] { inputKey }; + OutputKeys = new[] { outputText }; + } + + private const string FinalAnswerAction = "Final Answer:"; + private const string MissingActionAfterThoughtErrorMessage = "Invalid Format: Missing 'Action:' after 'Thought:"; + private const string MissingActionInputAfterActionErrorMessage = "Invalid Format: Missing 'Action Input:' after 'Action:'"; + private const string FinalAnswerAndParsableActionErrorMessage = "Parsing LLM output produced both a final answer and a parse-able action:"; + + + + public object Parse(string text) + { + bool includesAnswer = text.Contains(FinalAnswerAction); + string regex = @"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"; + Match actionMatch = Regex.Match(text, regex, RegexOptions.Singleline); + + if (actionMatch.Success) + { + if (includesAnswer) + { + throw new OutputParserException($"{FinalAnswerAndParsableActionErrorMessage}: {text}"); + } + string action = actionMatch.Groups[1].Value.Trim(); + string actionInput = actionMatch.Groups[2].Value.Trim().Trim('\"'); + + return new AgentAction(action, actionInput, text); + } + else if (includesAnswer) + { + return new AgentFinish(text.Split(new[] { FinalAnswerAction }, StringSplitOptions.None)[^1].Trim(), text); + } + + if (!Regex.IsMatch(text, @"Action\s*\d*\s*:[\s]*(.*?)", RegexOptions.Singleline)) + { + throw new OutputParserException($"Could not parse LLM output: `{text}`", MissingActionAfterThoughtErrorMessage); + } + else if (!Regex.IsMatch(text, @"[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)", RegexOptions.Singleline)) + { + throw new OutputParserException($"Could not parse LLM output: `{text}`", MissingActionInputAfterActionErrorMessage); + } + else + { + throw new OutputParserException($"Could not parse LLM output: `{text}`"); + } + } + + protected override Task InternalCall(IChainValues values) + { + values.Value[this.OutputKeys[0]] = Parse(values.Value[this.InputKeys[0]].ToString()); + return Task.FromResult(values); + } +} + +public class AgentAction(string action, string actionInput, string text) +{ + public string Action => action; + public string ActionInput => actionInput; + public string Text => text; + + public override string ToString() + { + return $"Action: {action}, Action Input: {actionInput}"; + } +} + +public class AgentFinish(string output, string text) +{ + public string Output => output; + public string Text => text; + + public override string ToString() + { + return $"Final Answer: {output}"; + } +} + diff --git a/src/libs/LangChain.Core/LangChain.Core.csproj b/src/libs/LangChain.Core/LangChain.Core.csproj index 7f5e8dda..e5e62311 100644 --- a/src/libs/LangChain.Core/LangChain.Core.csproj +++ b/src/libs/LangChain.Core/LangChain.Core.csproj @@ -36,7 +36,6 @@ - diff --git a/src/libs/LangChain.Core/Memory/BaseChatMemory.cs b/src/libs/LangChain.Core/Memory/BaseChatMemory.cs index f6b72e13..2d36887a 100644 --- a/src/libs/LangChain.Core/Memory/BaseChatMemory.cs +++ b/src/libs/LangChain.Core/Memory/BaseChatMemory.cs @@ -10,7 +10,7 @@ public abstract class BaseChatMemory( /// /// /// - protected BaseChatMessageHistory ChatHistory { get; set; } = chatHistory; + public BaseChatMessageHistory ChatHistory { get; set; } = chatHistory; /// /// diff --git a/src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs b/src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs index b98d936c..f9cd1377 100644 --- a/src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs +++ b/src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs @@ -9,18 +9,25 @@ public class ConversationBufferMemory : BaseChatMemory /// /// /// - public string HumanPrefix { get; set; } = "Human"; + public string HumanPrefix { get; set; } = "Human: "; /// /// /// - public string AiPrefix { get; set; } = "AI"; + public string AiPrefix { get; set; } = "AI: "; + + public string SystemPrefix { get; set; } = "System: "; /// /// /// public string MemoryKey { get; set; } = "history"; + public bool SaveHumanMessages { get; set; } = true; + + + + /// public ConversationBufferMemory(BaseChatMessageHistory chatHistory) : base(chatHistory) { @@ -48,16 +55,22 @@ private string GetBufferString(IEnumerable messages) foreach (var m in messages) { + + if (m.Role==MessageRole.Human&&!SaveHumanMessages) + { + continue; + } + string role = m.Role switch { MessageRole.Human => HumanPrefix, MessageRole.Ai => AiPrefix, - MessageRole.System => "System", + MessageRole.System => SystemPrefix, MessageRole.FunctionCall => "Function", _ => throw new ArgumentException($"Unsupported message type: {m.GetType().Name}") }; - string message = $"{role}: {m.Content}"; + string message = $"{role}{m.Content}"; // TODO: Add special case for a function call stringMessages.Add(message); diff --git a/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpConfiguration.cs b/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpConfiguration.cs index fb978027..8aae2aa2 100644 --- a/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpConfiguration.cs +++ b/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpConfiguration.cs @@ -33,9 +33,10 @@ public class LLamaSharpConfiguration /// public int MaxTokens { get; set; } = 600; - /// - /// - /// - public IReadOnlyList AntiPrompts { get; set; } = new[] { ">", "Human: "}; + public float RepeatPenalty { get; set; } = 1; + + public List AntiPrompts { get; set; } = new() { ">", "Human: "}; + + } \ No newline at end of file diff --git a/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModelBase.cs b/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModelBase.cs index b705cd84..c685ea99 100644 --- a/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModelBase.cs +++ b/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModelBase.cs @@ -37,7 +37,7 @@ public abstract class LLamaSharpModelBase : IChatModel /// /// /// - protected ModelParams Parameters { get; } + public ModelParams Parameters { get; } /// /// diff --git a/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModelInstruction.cs b/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModelInstruction.cs index 9b6cec7a..9e3814be 100644 --- a/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModelInstruction.cs +++ b/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModelInstruction.cs @@ -10,6 +10,8 @@ namespace LangChain.Providers.LLamaSharp; [CLSCompliant(false)] public class LLamaSharpModelInstruction : LLamaSharpModelBase { + + /// /// /// @@ -51,6 +53,11 @@ private static string SanitizeOutput(string res) /// public event Action TokenGenerated = delegate { }; + /// + /// Occurs before prompt is sent to the model. + /// + public event Action PromptSent = delegate { }; + /// /// /// @@ -59,7 +66,7 @@ private static string SanitizeOutput(string res) /// public override async Task GenerateAsync(ChatRequest request, CancellationToken cancellationToken = default) { - var prompt = ToPrompt(request.Messages) + "\n"; + var prompt = ToPrompt(request.Messages); var watch = Stopwatch.StartNew(); @@ -72,18 +79,29 @@ public override async Task GenerateAsync(ChatRequest request, Canc Temperature = Configuration.Temperature, AntiPrompts = Configuration.AntiPrompts, MaxTokens = Configuration.MaxTokens, - + RepeatPenalty = Configuration.RepeatPenalty }; - + PromptSent(prompt); var buf = ""; await foreach (var text in ex.InferAsync(prompt, inferenceParams, cancellationToken)) { buf += text; + foreach (string antiPrompt in Configuration.AntiPrompts) + { + if (buf.EndsWith(antiPrompt)) + { + buf = buf.Substring(0, buf.Length - antiPrompt.Length); + break; + } + } + TokenGenerated(text); } + + buf = LLamaSharpModelInstruction.SanitizeOutput(buf); var result = request.Messages.ToList(); result.Add(buf.AsAiMessage());