Skip to content

Commit

Permalink
Feature/advise changing model (#1362)
Browse files Browse the repository at this point in the history
* feature/advise-changing-model

* rename responses

* added tests

* formatting
  • Loading branch information
gecBurton authored Feb 4, 2025
1 parent e91e194 commit 9f600c1
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 80 deletions.
17 changes: 13 additions & 4 deletions django_app/redbox_app/redbox_core/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ async def receive(self, text_data=None, _bytes_data=None):
user: User = self.scope["user"]
chat_id = self.scope["url_route"]["kwargs"]["chat_id"]
except KeyError:
self.close()
self.send_to_client("error", error_messages.CORE_ERROR_MESSAGE)
await self.close()
await self.send_to_client("error", error_messages.CORE_ERROR_MESSAGE)
raise

chat = await Chat.objects.aget(id=chat_id)
Expand All @@ -81,12 +81,21 @@ async def receive(self, text_data=None, _bytes_data=None):

async def llm_conversation(self, session: Chat) -> None:
"""Initiate & close websocket conversation with the core-api message endpoint."""
await self.send_to_client("session-id", session.id)

token_count = await sync_to_async(session.token_count)()

active_context_window_sizes = await sync_to_async(ChatLLMBackend.active_context_window_sizes)()

if token_count > max(active_context_window_sizes.values()):
await self.send_to_client("error", error_messages.FILES_TOO_LARGE)
return

if token_count > await sync_to_async(session.context_window_size)():
await self.send_to_client("error", "The attached files are too large to work with")
details = "\n".join(
f"* `{k}`: {v} tokens" for k, v in active_context_window_sizes.items() if v >= token_count
)
msg = f"{error_messages.FILES_TOO_LARGE}.\nTry one of the following models:\n{details}"
await self.send_to_client("error", msg)
return

self.route = "chat_with_docs" # if selected_files else "chat"
Expand Down
1 change: 1 addition & 0 deletions django_app/redbox_app/redbox_core/error_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
"Redbox has temporarily exceeded its usage allowance with the AI server. "
'Please try again in a few minutes, and contact <a href="/support/">support</a> if the problem persists.'
)
FILES_TOO_LARGE = "The attached files are too large to work with"
4 changes: 4 additions & 0 deletions django_app/redbox_app/redbox_core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def save(self, *args, **kwargs):
ChatLLMBackend.objects.filter(is_default=True).update(is_default=False)
super().save(*args, **kwargs)

@classmethod
def active_context_window_sizes(cls) -> dict[str, int]:
    """Map each enabled backend's display name to its context window size in tokens."""
    enabled_backends = cls.objects.filter(enabled=True)
    return {str(backend): backend.context_window_size for backend in enabled_backends}


class User(BaseUser, UUIDPrimaryKeyBase):
class UserGrade(models.TextChoices):
Expand Down
25 changes: 25 additions & 0 deletions django_app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ def llm_backend(db): # noqa: ARG001
return gpt_4o


@pytest.fixture()
def big_llm_backend(db):  # noqa: ARG001
    """A non-default LLM backend with a very large (1M-token) context window.

    Requests the ``db`` fixture (like the sibling ``llm_backend`` fixture) so
    database access is enabled even for tests not marked ``django_db``.
    """
    big_llm, _ = ChatLLMBackend.objects.get_or_create(
        name="big-llm",
        provider="azure_openai",
        is_default=False,
        context_window_size=1_000_000,
    )
    return big_llm


@pytest.fixture()
def create_user():
def _create_user(
Expand Down Expand Up @@ -193,6 +204,20 @@ def uploaded_file(chat: Chat, original_file: UploadedFile, s3_client) -> File:
file.delete()


@pytest.fixture()
def large_file(chat: Chat, original_file: UploadedFile, s3_client) -> File:  # noqa: ARG001
    """Yield a persisted File whose token_count (150k) exceeds a typical context window.

    The file is deleted again at teardown.
    """
    file = File.objects.create(
        chat=chat,
        original_file=original_file,
        last_referenced=datetime.now(tz=UTC) - timedelta(days=14),
        status=File.Status.processing,
        token_count=150_000,
    )
    # NOTE: objects.create() already persists the row; the redundant save() was removed.
    yield file
    file.delete()


@pytest.fixture()
def original_file() -> UploadedFile:
    """A minimal in-memory text upload used as the source for File fixtures."""
    payload = b"Lorem Ipsum."
    return SimpleUploadedFile("original_file.txt", payload)
Expand Down
166 changes: 90 additions & 76 deletions django_app/tests/test_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,16 @@ async def test_chat_consumer_with_new_session(chat: Chat, mocked_connect: Connec
assert connected

await communicator.send_json_to({"message": "Hello Hal."})
response1 = await communicator.receive_json_from(timeout=5)
response2 = await communicator.receive_json_from(timeout=5)
response3 = await communicator.receive_json_from(timeout=5)
response4 = await communicator.receive_json_from(timeout=5)
# response5 = await communicator.receive_json_from(timeout=5)
response_1 = await communicator.receive_json_from(timeout=5)
response_2 = await communicator.receive_json_from(timeout=5)
response_3 = await communicator.receive_json_from(timeout=5)

# Then
assert response1["type"] == "session-id"
assert response2["type"] == "text"
assert response2["data"] == "Good afternoon, "
assert response3["type"] == "text"
assert response3["data"] == "Mr. Amor."
assert response4["type"] == "end"
assert response_1["type"] == "text"
assert response_1["data"] == "Good afternoon, "
assert response_2["type"] == "text"
assert response_2["data"] == "Mr. Amor."
assert response_3["type"] == "end"
# Close
await communicator.disconnect()

Expand All @@ -83,19 +80,16 @@ async def test_chat_consumer_staff_user(staff_user: User, chat: Chat, mocked_con
assert connected

await communicator.send_json_to({"message": "Hello Hal.", "output_text": "hello"})
response1 = await communicator.receive_json_from(timeout=5)
response2 = await communicator.receive_json_from(timeout=5)
response3 = await communicator.receive_json_from(timeout=5)
response4 = await communicator.receive_json_from(timeout=5)
# _response5 = await communicator.receive_json_from(timeout=5)
response_1 = await communicator.receive_json_from(timeout=5)
response_2 = await communicator.receive_json_from(timeout=5)
response_3 = await communicator.receive_json_from(timeout=5)

# Then
assert response1["type"] == "session-id"
assert response2["type"] == "text"
assert response2["data"] == "Good afternoon, "
assert response3["type"] == "text"
assert response3["data"] == "Mr. Amor."
assert response4["type"] == "end"
assert response_1["type"] == "text"
assert response_1["data"] == "Good afternoon, "
assert response_2["type"] == "text"
assert response_2["data"] == "Mr. Amor."
assert response_3["type"] == "end"
# Close
await communicator.disconnect()

Expand All @@ -114,11 +108,6 @@ async def test_chat_consumer_with_existing_session(chat: Chat, mocked_connect: C
assert connected

await communicator.send_json_to({"message": "Hello Hal."})
response1 = await communicator.receive_json_from(timeout=5)

# Then
assert response1["type"] == "session-id"
assert response1["data"] == str(chat.id)

# Close
await communicator.disconnect()
Expand All @@ -141,18 +130,16 @@ async def test_chat_consumer_with_naughty_question(chat: Chat, mocked_connect: C
assert connected

await communicator.send_json_to({"message": "Hello Hal. \x00"})
response1 = await communicator.receive_json_from(timeout=5)
response2 = await communicator.receive_json_from(timeout=5)
response3 = await communicator.receive_json_from(timeout=5)
response4 = await communicator.receive_json_from(timeout=5)
response_1 = await communicator.receive_json_from(timeout=5)
response_2 = await communicator.receive_json_from(timeout=5)
response_3 = await communicator.receive_json_from(timeout=5)

# Then
assert response1["type"] == "session-id"
assert response2["type"] == "text"
assert response2["data"] == "Good afternoon, "
assert response3["type"] == "text"
assert response3["data"] == "Mr. Amor."
assert response4["type"] == "end"
assert response_1["type"] == "text"
assert response_1["data"] == "Good afternoon, "
assert response_2["type"] == "text"
assert response_2["data"] == "Mr. Amor."
assert response_3["type"] == "end"
# Close
await communicator.disconnect()

Expand Down Expand Up @@ -197,11 +184,6 @@ async def test_chat_consumer_with_selected_files(
"message": "Third question, with selected files?",
}
)
response1 = await communicator.receive_json_from(timeout=5)

# Then
assert response1["type"] == "session-id"
assert response1["data"] == str(chat_with_files.id)

# Close
await communicator.disconnect()
Expand Down Expand Up @@ -251,11 +233,10 @@ async def test_chat_consumer_with_connection_error(chat: Chat, mocked_breaking_c
assert connected

await communicator.send_json_to({"message": "Hello Hal."})
await communicator.receive_json_from(timeout=5)
response2 = await communicator.receive_json_from(timeout=5)
response_1 = await communicator.receive_json_from(timeout=5)

# Then
assert response2["type"] == "error"
assert response_1["type"] == "error"


@pytest.mark.django_db(transaction=True)
Expand All @@ -277,16 +258,14 @@ async def test_chat_consumer_with_explicit_unhandled_error(
assert connected

await communicator.send_json_to({"message": "Hello Hal."})
response1 = await communicator.receive_json_from(timeout=5)
response2 = await communicator.receive_json_from(timeout=5)
response3 = await communicator.receive_json_from(timeout=5)
response_1 = await communicator.receive_json_from(timeout=5)
response_2 = await communicator.receive_json_from(timeout=5)

# Then
assert response1["type"] == "session-id"
assert response2["type"] == "text"
assert response2["data"] == "Good afternoon, "
assert response3["type"] == "text"
assert response3["data"] == error_messages.CORE_ERROR_MESSAGE
assert response_1["type"] == "text"
assert response_1["data"] == "Good afternoon, "
assert response_2["type"] == "text"
assert response_2["data"] == error_messages.CORE_ERROR_MESSAGE
# Close
await communicator.disconnect()

Expand All @@ -308,16 +287,14 @@ async def test_chat_consumer_with_rate_limited_error(chat: Chat, mocked_connect_
assert connected

await communicator.send_json_to({"message": "Hello Hal."})
response1 = await communicator.receive_json_from(timeout=5)
response2 = await communicator.receive_json_from(timeout=5)
response3 = await communicator.receive_json_from(timeout=5)
response_1 = await communicator.receive_json_from(timeout=5)
response_2 = await communicator.receive_json_from(timeout=5)

# Then
assert response1["type"] == "session-id"
assert response2["type"] == "text"
assert response2["data"] == "Good afternoon, "
assert response3["type"] == "text"
assert response3["data"] == error_messages.RATE_LIMITED
assert response_1["type"] == "text"
assert response_1["data"] == "Good afternoon, "
assert response_2["type"] == "text"
assert response_2["data"] == error_messages.RATE_LIMITED
# Close
await communicator.disconnect()

Expand All @@ -341,38 +318,29 @@ async def test_chat_consumer_with_explicit_no_document_selected_error(
assert connected

await communicator.send_json_to({"message": "Hello Hal."})
response1 = await communicator.receive_json_from(timeout=5)
response2 = await communicator.receive_json_from(timeout=5)
response_1 = await communicator.receive_json_from(timeout=5)

# Then
assert response1["type"] == "session-id"
assert response2["type"] == "text"
assert response2["data"] == error_messages.SELECT_DOCUMENT
assert response_1["type"] == "text"
assert response_1["data"] == error_messages.SELECT_DOCUMENT
# Close
await communicator.disconnect()


@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
async def test_chat_consumer_redbox_state(
chat: Chat, several_files: Sequence[File], chat_with_files: Chat, llm_backend
):
async def test_chat_consumer_redbox_state(several_files: Sequence[File], chat_with_files: Chat, llm_backend):
# Given

# When
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.run") as mock_run:
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = chat.user
communicator.scope["url_route"] = {"kwargs": {"chat_id": chat.id}}
communicator.scope["user"] = chat_with_files.user
communicator.scope["url_route"] = {"kwargs": {"chat_id": chat_with_files.id}}
connected, _ = await communicator.connect()
assert connected

await communicator.send_json_to({"message": "Third question, with selected files?"})
response1 = await communicator.receive_json_from(timeout=5)

# Then
assert response1["type"] == "session-id"
assert response1["data"] == str(chat_with_files.id)

# Close
await communicator.disconnect()
Expand Down Expand Up @@ -569,3 +537,49 @@ def mocked_connect_with_several_files(several_files: Sequence[File]) -> Connect:
@database_sync_to_async
def refresh_from_db(obj: Model) -> None:
    """Reload *obj*'s fields from the database; async-safe via database_sync_to_async."""
    obj.refresh_from_db()


@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
async def test_chat_consumer_with_context_window_error(large_file: File):
    """The consumer emits FILES_TOO_LARGE when the chat's files exceed every model's window."""
    # Given a chat containing large_file
    chat = large_file.chat

    # When
    communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
    communicator.scope["user"] = chat.user
    communicator.scope["url_route"] = {"kwargs": {"chat_id": chat.id}}
    connected, _ = await communicator.connect()
    assert connected

    await communicator.send_json_to({"message": "Hello Hal."})
    response = await communicator.receive_json_from(timeout=5)

    # Then
    assert response["type"] == "error"
    assert response["data"] == error_messages.FILES_TOO_LARGE

    # Close
    await communicator.disconnect()


@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
async def test_chat_consumer_with_context_window_error_with_suggestion(large_file: File, big_llm_backend):
    """When a larger enabled backend exists, the error message suggests switching to it."""
    # Given a chat containing large_file and an enabled big_llm_backend
    chat = large_file.chat

    # When
    communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
    communicator.scope["user"] = chat.user
    communicator.scope["url_route"] = {"kwargs": {"chat_id": chat.id}}
    connected, _ = await communicator.connect()
    assert connected

    await communicator.send_json_to({"message": "Hello Hal."})
    response = await communicator.receive_json_from(timeout=5)

    # Then: the error both states the problem and ends with the big model's suggestion line
    assert response["type"] == "error"
    assert response["data"].startswith(error_messages.FILES_TOO_LARGE)
    assert response["data"].endswith(f"`{big_llm_backend}`: {big_llm_backend.context_window_size} tokens")

    # Close
    await communicator.disconnect()

0 comments on commit 9f600c1

Please sign in to comment.