Skip to content

Commit

Permalink
feature/added MessagesState (#1166)
Browse files (browse the repository at this point in the history)
  • Loading branch information
gecBurton authored Nov 12, 2024
1 parent a5106ae commit b8a5806
Show file tree
Hide file tree
Showing 12 changed files with 60 additions and 82 deletions.
26 changes: 7 additions & 19 deletions redbox-core/redbox/chains/runnables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,13 @@
import re
from typing import Any, Callable, Iterable, Iterator

from langchain_core.callbacks.manager import (
CallbackManagerForLLMRun,
dispatch_custom_event,
)
from langchain_core.callbacks.manager import CallbackManagerForLLMRun, dispatch_custom_event
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import (
Runnable,
RunnableGenerator,
RunnableLambda,
RunnablePassthrough,
chain,
)
from langchain_core.runnables import Runnable, RunnableGenerator, RunnableLambda, RunnablePassthrough, chain
from tiktoken import Encoding

from redbox.api.format import format_documents, format_toolstate
Expand All @@ -26,11 +17,7 @@
from redbox.models.chain import ChainChatMessage, PromptSet, RedboxState, get_prompts
from redbox.models.errors import QuestionLengthError
from redbox.models.graph import RedboxEventType
from redbox.transform import (
flatten_document_state,
tool_calls_to_toolstate,
get_all_metadata,
)
from redbox.transform import flatten_document_state, get_all_metadata, tool_calls_to_toolstate

log = logging.getLogger()
re_string_pattern = re.compile(r"(\S+)")
Expand Down Expand Up @@ -79,6 +66,7 @@ def _chat_prompt_from_messages(state: RedboxState) -> Runnable:

prompt_template_context = (
state["request"].model_dump()
| {"messages": state.get("messages")}
| {
"text": state.get("text"),
"formatted_documents": format_documents(flatten_document_state(state.get("documents"))),
Expand Down Expand Up @@ -190,7 +178,7 @@ class CannedChatLLM(BaseChatModel):
Based on https://python.langchain.com/v0.2/docs/how_to/custom_chat_model/
"""

text: str
messages: list[AIMessage]

def _generate(
self,
Expand All @@ -211,7 +199,7 @@ def _generate(
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
message = AIMessage(content=self.text)
message = AIMessage(content=self.messages[-1].content)

generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
Expand All @@ -235,7 +223,7 @@ def _stream(
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
for token in re_string_pattern.split(self.text):
for token in re_string_pattern.split(self.messages[-1].content):
chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))

if run_manager:
Expand Down
2 changes: 1 addition & 1 deletion redbox-core/redbox/graph/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def build_strings_end_text_conditional(*strings: str) -> Runnable:
regex = re.compile(pattern, re.IGNORECASE)

def _strings_end_text_conditional(state: RedboxState) -> str:
matches = regex.findall(state["text"][-100:]) # padding for waffle
matches = regex.findall(state["messages"][-1].content[-100:]) # padding for waffle
unique_matches = set(matches)

if len(unique_matches) == 1:
Expand Down
31 changes: 11 additions & 20 deletions redbox-core/redbox/graph/nodes/processes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import re
import textwrap
from collections.abc import Callable
from functools import reduce
from typing import Any, Iterable
Expand All @@ -9,6 +10,7 @@
from langchain.schema import StrOutputParser
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel
from langchain_core.tools import StructuredTool
from langchain_core.vectorstores import VectorStoreRetriever
Expand All @@ -18,19 +20,8 @@
from redbox.chains.runnables import CannedChatLLM, build_llm_chain
from redbox.graph.nodes.tools import get_log_formatter_for_retrieval_tool, has_injected_state, is_valid_tool
from redbox.models import ChatRoute
from redbox.models.chain import (
DocumentState,
PromptSet,
RedboxState,
RequestMetadata,
merge_redbox_state_updates,
)
from redbox.models.graph import (
ROUTE_NAME_TAG,
SOURCE_DOCUMENTS_TAG,
RedboxActivityEvent,
RedboxEventType,
)
from redbox.models.chain import DocumentState, PromptSet, RedboxState, RequestMetadata, merge_redbox_state_updates
from redbox.models.graph import ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG, RedboxActivityEvent, RedboxEventType
from redbox.transform import combine_documents, flatten_document_state

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -119,7 +110,7 @@ def _merge(state: RedboxState) -> dict[str, Any]:
prompt_set=prompt_set, llm=llm, final_response_chain=final_response_chain
).invoke(merge_state)

merged_document.page_content = merge_response["text"]
merged_document.page_content = merge_response["messages"][-1].content
request_metadata = merge_response["metadata"]
merged_document.metadata["token_count"] = len(tokeniser.encode(merged_document.page_content))

Expand Down Expand Up @@ -193,7 +184,7 @@ def build_set_self_route_from_llm_answer(

@RunnableLambda
def _set_self_route_from_llm_answer(state: RedboxState):
llm_response = state["text"]
llm_response = state["messages"][-1].content
if conditional(llm_response):
return true_condition_state_update
else:
Expand All @@ -211,22 +202,22 @@ def build_passthrough_pattern() -> Runnable[RedboxState, dict[str, Any]]:
@RunnableLambda
def _passthrough(state: RedboxState) -> dict[str, Any]:
return {
"text": state["request"].question,
"messages": [HumanMessage(content=state["request"].question)],
}

return _passthrough


def build_set_text_pattern(text: str, final_response_chain: bool = False) -> Runnable[RedboxState, dict[str, Any]]:
"""Returns a Runnable that can arbitrarily set state["text"] to a value."""
llm = CannedChatLLM(text=text)
"""Returns a Runnable that can arbitrarily set state["messages"] to a value."""
llm = CannedChatLLM(messages=[AIMessage(content=text)])
_llm = llm.with_config(tags=["response_flag"]) if final_response_chain else llm

@RunnableLambda
def _set_text(state: RedboxState) -> dict[str, Any]:
set_text_chain = _llm | StrOutputParser()

return {"text": set_text_chain.invoke(text)}
return {"messages": state.get("messages", []) + [HumanMessage(content=set_text_chain.invoke(text))]}

return _set_text

Expand Down Expand Up @@ -356,7 +347,7 @@ def _log_node(state: RedboxState):
group_id: {doc_id: d.metadata for doc_id, d in group_documents.items()}
for group_id, group_documents in state["documents"]
},
"text": (state["text"] if len(state["text"]) < 32 else f"{state['text'][:29]}..."),
"messages": (textwrap.shorten(state["messages"][-1].content, width=32, placeholder="...")),
"route": state["route_name"],
"message": message,
}
Expand Down
2 changes: 0 additions & 2 deletions redbox-core/redbox/graph/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
from redbox.models.graph import ROUTABLE_KEYWORDS, RedboxActivityEvent
from redbox.transform import structure_documents_by_file_name, structure_documents_by_group_and_indices

# Subgraphs


def get_self_route_graph(retriever: VectorStoreRetriever, prompt_set: PromptSet, debug: bool = False):
builder = StateGraph(RedboxState)
Expand Down
16 changes: 5 additions & 11 deletions redbox-core/redbox/models/chain.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
from datetime import UTC, datetime
from enum import StrEnum
from functools import reduce
from typing import (
Annotated,
Literal,
NotRequired,
Required,
TypedDict,
get_args,
get_origin,
)
from typing import Annotated, Literal, NotRequired, Required, TypedDict, get_args, get_origin
from uuid import UUID, uuid4

from langchain_core.documents import Document
from langchain_core.messages import ToolCall
from langgraph.graph import MessagesState
from langgraph.managed.is_last_step import RemainingStepsManager
from pydantic import BaseModel, Field

Expand Down Expand Up @@ -268,10 +261,9 @@ def tool_calls_reducer(current: ToolState, update: ToolState | None) -> ToolStat
return reduced


class RedboxState(TypedDict):
class RedboxState(MessagesState):
request: Required[RedboxQuery]
documents: Annotated[NotRequired[DocumentState], document_reducer]
text: NotRequired[str | None]
route_name: NotRequired[str | None]
tool_calls: Annotated[NotRequired[ToolState], tool_calls_reducer]
metadata: Annotated[NotRequired[RequestMetadata], metadata_reducer]
Expand Down Expand Up @@ -374,6 +366,8 @@ def merge_redbox_state_updates(current: RedboxState, update: RedboxState) -> Red
if is_dict_type(annotation):
# If it's annotated and a subclass of dict, apply a custom reducer function
merged_state[update_key] = dict_reducer(current=current_value or {}, update=update_value or {})
elif current_value is None:
merged_state[update_key] = update_value
else:
# If it's annotated and not a dict, apply its reducer function
_, reducer_func = get_args(annotation)
Expand Down
2 changes: 1 addition & 1 deletion redbox-core/redbox/models/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@
"The following context and previous actions are provided to assist you. \n\n"
"Previous tool calls: \n\n <ToolCalls> \n\n {tool_calls} </ToolCalls> \n\n "
"Document snippets: \n\n <Documents> \n\n {formatted_documents} </Documents> \n\n "
"Previous agent's response: \n\n <AIResponse> \n\n {text} \n\n </AIResponse> \n\n "
"Previous agent's response: \n\n <AIResponse> \n\n {messages} \n\n </AIResponse> \n\n "
"User question: \n\n {question}"
)

Expand Down
2 changes: 1 addition & 1 deletion redbox-core/redbox/retriever/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ParameterisedElasticsearchRetriever(BaseRetriever):
def _get_relevant_documents(
self, query: RedboxState, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
query_text = query["text"]
query_text = query["messages"][-1].content
query_vector = self.embedding_model.embed_query(query_text)
selected_files = query["request"].s3_keys
permitted_files = query["request"].permitted_s3_keys
Expand Down
4 changes: 2 additions & 2 deletions redbox-core/redbox/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tiktoken
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.documents import Document
from langchain_core.messages import ToolCall, AnyMessage
from langchain_core.messages import ToolCall, AnyMessage, AIMessage
from langchain_core.runnables import RunnableLambda

from redbox.models.chain import (
Expand Down Expand Up @@ -169,9 +169,9 @@ def get_all_metadata(obj: dict):
citations = []

out = {
"messages": [AIMessage(content=text)],
"tool_calls": text_and_tools["tool_calls"],
"metadata": to_request_metadata(obj),
"text": text,
"citations": citations,
}
return out
Expand Down
8 changes: 4 additions & 4 deletions redbox-core/tests/graph/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,11 +582,11 @@ async def documents_response_handler(documents: list[Document]):
expected_text = expected_text.content if isinstance(expected_text, AIMessage) else expected_text

assert (
final_state["text"] == llm_response
), f"Text response from streaming: '{llm_response}' did not match final state text '{final_state["text"]}'"
final_state["messages"][-1].content == llm_response
), f"Text response from streaming: '{llm_response}' did not match final state text '{final_state["messages"]}'"
assert (
final_state["text"] == expected_text
), f"Expected text: '{expected_text}' did not match received text '{final_state["text"]}'"
final_state["messages"][-1].content == expected_text
), f"Expected text: '{expected_text}' did not match received text '{final_state["messages"]}'"

assert (
final_state.get("route_name") == test_case.test_data.expected_route
Expand Down
Loading

0 comments on commit b8a5806

Please sign in to comment.