Skip to content

Commit

Permalink
feat: redis message history (#86)
Browse files Browse the repository at this point in the history
Co-authored-by: Evgenii Khoroshev <[email protected]>
  • Loading branch information
khoroshevj and Evgenii Khoroshev authored Dec 4, 2023
1 parent f30ef1d commit e14cdb0
Show file tree
Hide file tree
Showing 11 changed files with 236 additions and 37 deletions.
7 changes: 7 additions & 0 deletions LangChain.sln
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Utilities.Postgre
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Utilities.Postgres.IntegrationTests", "src\tests\LangChain.Utilities.Postgres.IntegrationTests\LangChain.Utilities.Postgres.IntegrationTests.csproj", "{A652E4C6-6988-40BD-A726-2F5A3783C129}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Databases.Redis.IntegrationTests", "src\tests\LangChain.Databases.Redis.IntegrationTests\LangChain.Databases.Redis.IntegrationTests.csproj", "{E19562A0-9AAA-4C75-BE78-648E7148A4CD}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -348,6 +350,10 @@ Global
{A652E4C6-6988-40BD-A726-2F5A3783C129}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A652E4C6-6988-40BD-A726-2F5A3783C129}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A652E4C6-6988-40BD-A726-2F5A3783C129}.Release|Any CPU.Build.0 = Release|Any CPU
{E19562A0-9AAA-4C75-BE78-648E7148A4CD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{E19562A0-9AAA-4C75-BE78-648E7148A4CD}.Debug|Any CPU.Build.0 = Debug|Any CPU
{E19562A0-9AAA-4C75-BE78-648E7148A4CD}.Release|Any CPU.ActiveCfg = Release|Any CPU
{E19562A0-9AAA-4C75-BE78-648E7148A4CD}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -407,6 +413,7 @@ Global
{7D47EC2D-2F03-4284-A07D-E56486B885C6} = {788567AF-444A-488F-BCED-C3B9F03CC38D}
{2A01AC56-7850-48FD-B32F-A7AAF0E86F84} = {788567AF-444A-488F-BCED-C3B9F03CC38D}
{A652E4C6-6988-40BD-A726-2F5A3783C129} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
{E19562A0-9AAA-4C75-BE78-648E7148A4CD} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {5C00D0F1-6138-4ED9-846B-97E43D6DFF1C}
Expand Down
1 change: 1 addition & 0 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
<PackageVersion Include="NUnit3TestAdapter" Version="4.5.0" />
<PackageVersion Include="PdfPig" Version="0.1.9-alpha-20231119-4537e" />
<PackageVersion Include="PolySharp" Version="1.13.2" />
<PackageVersion Include="StackExchange.Redis" Version="2.7.4" />
<PackageVersion Include="System.Net.Http" Version="4.3.4" />
<PackageVersion Include="System.Text.Json" Version="8.0.0" />
<PackageVersion Include="Tiktoken" Version="1.1.3" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@

<ItemGroup>
<PackageReference Include="Microsoft.SemanticKernel.Connectors.Memory.Redis" />
<PackageReference Include="StackExchange.Redis" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\LangChain.Core\LangChain.Core.csproj" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
using System.Text.Json;
using LangChain.Memory;
using LangChain.Providers;
using StackExchange.Redis;

namespace LangChain.Databases;

/// <summary>
/// Chat message history stored in a Redis database.
/// </summary>
public class RedisChatMessageHistory : BaseChatMessageHistory
{
private readonly string _sessionId;
private readonly string _keyPrefix;
private readonly TimeSpan? _ttl;
private readonly Lazy<ConnectionMultiplexer> _multiplexer;

/// <inheritdoc />
public RedisChatMessageHistory(
string sessionId,
string connectionString,
string keyPrefix = "message_store:",
TimeSpan? ttl = null)
{
_sessionId = sessionId;
_keyPrefix = keyPrefix;
_ttl = ttl;

_multiplexer = new Lazy<ConnectionMultiplexer>(
() =>
{
var multiplexer = ConnectionMultiplexer.Connect(connectionString);

return multiplexer;
},
LazyThreadSafetyMode.ExecutionAndPublication);
}

/// <summary>
/// Construct the record key to use
/// </summary>
private string Key => _keyPrefix + _sessionId;

/// <summary>
/// Retrieve the messages from Redis
/// TODO: use async methods
/// </summary>
public override IReadOnlyList<Message> Messages
{
get
{
var database = _multiplexer.Value.GetDatabase();
var values = database.ListRange(Key, start: 0, stop: -1);
var messages = values.Select(v => JsonSerializer.Deserialize<Message>(v.ToString())).Reverse();

return messages.ToList();
}
}

/// <summary>
/// Append the message to the record in Redis
/// </summary>
public override async Task AddMessage(Message message)
{
var database = _multiplexer.Value.GetDatabase();
await database.ListLeftPushAsync(Key, JsonSerializer.Serialize(message)).ConfigureAwait(false);
if (_ttl.HasValue)
{
await database.KeyExpireAsync(Key, _ttl).ConfigureAwait(false);
}
}

/// <summary>
/// Clear session memory from Redis
/// </summary>
public override async Task Clear()
{
var database = _multiplexer.Value.GetDatabase();
await database.KeyDeleteAsync(Key).ConfigureAwait(false);
}
}
3 changes: 1 addition & 2 deletions src/libs/LangChain.Core/Memory/BaseChatMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ public override async Task SaveContext(InputValues inputValues, OutputValues out

public override Task Clear()
{
ChatHistory.Clear();
return Task.CompletedTask;
return ChatHistory.Clear();
}
}
5 changes: 2 additions & 3 deletions src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@ namespace LangChain.Memory;

public abstract class BaseChatMessageHistory
{
public IList<Message> Messages { get; set; } = new List<Message>();

public async Task AddUserMessage(string message)
{
await AddMessage(message.AsHumanMessage());

}

public async Task AddAiMessage(string message)
{
await AddMessage(message.AsAiMessage());
}

public abstract IReadOnlyList<Message> Messages { get; }

public abstract Task AddMessage(Message message);

public abstract Task Clear();
Expand Down
7 changes: 5 additions & 2 deletions src/libs/LangChain.Core/Memory/ChatMessageHistory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@ namespace LangChain.Memory;

public class ChatMessageHistory : BaseChatMessageHistory
{
private readonly List<Message> _messages = new List<Message>();
public override IReadOnlyList<Message> Messages => _messages;

public override Task AddMessage(Message message)
{
Messages.Add(message);
_messages.Add(message);
return Task.CompletedTask;
}

public override Task Clear()
{
Messages.Clear();
_messages.Clear();
return Task.CompletedTask;
}
}
38 changes: 12 additions & 26 deletions src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using LangChain.Providers;
using System.Net.Mail;
using LangChain.Schema;

namespace LangChain.Memory;
Expand All @@ -16,40 +15,28 @@ public ConversationBufferMemory(BaseChatMessageHistory chatHistory) : base(chatH
ChatHistory = chatHistory;
}


// note: buffer property can't be implemented because of Any type as return type

public string BufferAsString => GetBufferString(BufferAsMessages);

public IList<Message> BufferAsMessages => ChatHistory.Messages;
public IReadOnlyList<Message> BufferAsMessages => ChatHistory.Messages;

public override List<string> MemoryVariables => new List<string> {MemoryKey};
public override List<string> MemoryVariables => new List<string> { MemoryKey };

private string GetBufferString(
IEnumerable<Message> messages)
private string GetBufferString(IEnumerable<Message> messages)
{
List<string> stringMessages = new List<string>();
var stringMessages = new List<string>();

foreach (var m in messages)
{
string role;
switch (m.Role)
string role = m.Role switch
{
case MessageRole.Human:
role = HumanPrefix;
break;
case MessageRole.Ai:
role = AiPrefix;
break;
case MessageRole.System:
role = "System";
break;
case MessageRole.FunctionCall:
role = "Function";
break;
default:
throw new ArgumentException($"Unsupported message type: {m.GetType().Name}");
}
MessageRole.Human => HumanPrefix,
MessageRole.Ai => AiPrefix,
MessageRole.System => "System",
MessageRole.FunctionCall => "Function",
_ => throw new ArgumentException($"Unsupported message type: {m.GetType().Name}")
};

string message = $"{role}: {m.Content}";
// TODO: Add special case for a function call
Expand All @@ -60,9 +47,8 @@ private string GetBufferString(
return string.Join("\n", stringMessages);
}


public override OutputValues LoadMemoryVariables(InputValues? inputValues)
{
return new OutputValues(new Dictionary<string, object> {{MemoryKey, BufferAsString}});
return new OutputValues(new Dictionary<string, object> { { MemoryKey, BufferAsString } });
}
}
27 changes: 23 additions & 4 deletions src/libs/LangChain.Core/Memory/MemoryExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ namespace LangChain.Memory;

public static class MemoryExtensions
{
public static IReadOnlyCollection<Message> WithHistory(this IReadOnlyCollection<Message> messages, BaseMemory? memory)
public static IReadOnlyCollection<Message> WithHistory(
this IReadOnlyCollection<Message> messages,
BaseMemory? memory)
{
if (memory == null)
{
Expand All @@ -22,9 +24,26 @@ public static IReadOnlyCollection<Message> WithHistory(this IReadOnlyCollection<
}
}

return new[]
var result = new Message[messages.Count + 1];
result[0] = history.AsHumanMessage();
messages.CopyTo(result, startIndex: 1);

return result;
}

private static void CopyTo<T>(this IReadOnlyCollection<T> source, T[] destination, int startIndex)
{
if (destination.Length > source.Count + startIndex)
{
history.AsHumanMessage(),
}.Concat(messages).ToArray();
throw new ArgumentException(
$"{nameof(destination)} required to have min length of {source.Count + startIndex}, but was {destination.Length}");
}

var i = 0;
foreach (var item in source)
{
destination[startIndex + i] = item;
i++;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\..\libs\Databases\LangChain.Databases.Redis\LangChain.Databases.Redis.csproj" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using LangChain.Providers;

namespace LangChain.Databases.Redis.IntegrationTests;

/// <summary>
/// In order to run tests please run redis locally, e.g. with docker
/// docker run -p 6379:6379 redis
/// </summary>
[TestFixture]
[Explicit]
public class RedisChatMessageHistoryTests
{
private readonly string _connectionString = "127.0.0.1:6379";

[Test]
public void GetMessages_EmptyHistory_Ok()
{
var sessionId = "GetMessages_EmptyHistory_Ok";
var history = new RedisChatMessageHistory(
sessionId,
_connectionString,
ttl: TimeSpan.FromSeconds(30));

var existing = history.Messages;

existing.Should().BeEmpty();
}

[Test]
public async Task AddMessage_Ok()
{
var sessionId = "RedisChatMessageHistoryTests_AddMessage_Ok";
var history = new RedisChatMessageHistory(
sessionId,
_connectionString,
ttl: TimeSpan.FromSeconds(30));

var humanMessage = Message.Human("Hi, AI");
await history.AddMessage(humanMessage);
var aiMessage = Message.Ai("Hi, human");
await history.AddMessage(aiMessage);

var actual = history.Messages;

actual.Should().HaveCount(2);

actual[0].Role.Should().Be(humanMessage.Role);
actual[0].Content.Should().BeEquivalentTo(humanMessage.Content);

actual[1].Role.Should().Be(aiMessage.Role);
actual[1].Content.Should().BeEquivalentTo(aiMessage.Content);
}

[Test]
public async Task Ttl_Ok()
{
var sessionId = "Ttl_Ok";
var history = new RedisChatMessageHistory(
sessionId,
_connectionString,
ttl: TimeSpan.FromSeconds(2));

var humanMessage = Message.Human("Hi, AI");
await history.AddMessage(humanMessage);

await Task.Delay(2_500);

var existing = history.Messages;

existing.Should().BeEmpty();
}

[Test]
public async Task Clear_Ok()
{
var sessionId = "Ttl_Ok";
var history = new RedisChatMessageHistory(
sessionId,
_connectionString,
ttl: TimeSpan.FromSeconds(30));

await history.Clear();

var existing = history.Messages;

existing.Should().BeEmpty();
}
}

0 comments on commit e14cdb0

Please sign in to comment.