Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REDBOX-337 - Tests for chat file selection #565

Merged
merged 13 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion django_app/redbox_app/redbox_core/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
class UserResource(admin.ModelAdmin):
fields = ["email", "is_superuser", "is_staff", "last_login"]
list_display = ["email", "is_superuser", "is_staff", "last_login"]
date_hierarchy = "last_login"


class FileResource(admin.ModelAdmin):
list_display = ["original_file_name", "user", "status"]
list_display = ["original_file_name", "user", "status", "created_at"]
list_filter = ["user", "status"]
date_hierarchy = "created_at"


class ChatMessageResource(admin.ModelAdmin):
Expand Down
11 changes: 7 additions & 4 deletions django_app/redbox_app/redbox_core/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@


class ChatConsumer(AsyncWebsocketConsumer):
async def receive(self, text_data):
data = json.loads(text_data)
async def receive(self, text_data=None, bytes_data=None):
data = json.loads(text_data or bytes_data)
logger.debug("received %s from browser", data)
user_message_text: str = data.get("message", "")
session_id: str | None = data.get("sessionId", None)
selected_file_uuids: Sequence[UUID] = [UUID(u) for u in data.get("selectedFiles", [])]
Expand All @@ -32,6 +33,7 @@ async def receive(self, text_data):
await self.save_message(session, user_message_text, ChatRoleEnum.user, selected_files=selected_files)

await self.llm_conversation(selected_files, session, user)
await self.close()

async def llm_conversation(self, selected_files: Sequence[File], session: ChatHistory, user: User) -> None:
session_messages = await self.get_messages(session)
Expand All @@ -44,7 +46,6 @@ async def llm_conversation(self, selected_files: Sequence[File], session: ChatHi
"message_history": message_history,
"selected_files": [{"uuid": f.core_file_uuid} for f in selected_files],
}
logger.debug("sending to core-api: %s", message)
await self.send_to_server(core_websocket, message)
await self.send_to_client({"type": "session-id", "data": str(session.id)})
reply, source_files = await self.receive_llm_responses(user, core_websocket)
Expand All @@ -57,7 +58,7 @@ async def receive_llm_responses(
source_files: MutableSequence[File] = []
async for raw_message in core_websocket:
message = json.loads(raw_message, object_hook=lambda d: SimpleNamespace(**d))
logger.debug("Received: %s", message)
logger.debug("received %s from core-api", message)
if message.resource_type == "text":
full_reply.append(await self.handle_text(message))
elif message.resource_type == "documents":
Expand All @@ -81,10 +82,12 @@ async def handle_text(self, message: SimpleNamespace) -> str:
return message.data

async def send_to_client(self, data):
logger.debug("sending %s to browser", data)
await self.send(json.dumps(data, default=str))

@staticmethod
async def send_to_server(websocket, data):
logger.debug("sending %s to core-api", data)
return await websocket.send(json.dumps(data, default=str))

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion django_app/redbox_app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

SECRET_KEY = env.str("DJANGO_SECRET_KEY")
ENVIRONMENT = Environment[env.str("ENVIRONMENT").upper()]
WEBSOCKET_SCHEME = "ws" if ENVIRONMENT.is_local else "wss"
WEBSOCKET_SCHEME = "ws" if ENVIRONMENT.is_test else "wss"

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = env.bool("DEBUG")
Expand Down
4 changes: 2 additions & 2 deletions django_app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ def chat_history_with_files(chat_history: ChatHistory, several_files: Sequence[F
chat_message = ChatMessage.objects.create(chat_history=chat_history, text="An answer.", role=ChatRoleEnum.ai)
chat_message.source_files.set(several_files[0::2])
chat_message = ChatMessage.objects.create(
chat_history=chat_history, text="Another question?", role=ChatRoleEnum.user
chat_history=chat_history, text="A second question?", role=ChatRoleEnum.user
)
chat_message.selected_files.set(several_files[0:2])
chat_message = ChatMessage.objects.create(chat_history=chat_history, text="Another answer.", role=ChatRoleEnum.ai)
chat_message = ChatMessage.objects.create(chat_history=chat_history, text="A second answer.", role=ChatRoleEnum.ai)
chat_message.source_files.set([several_files[2]])
return chat_history

Expand Down
47 changes: 39 additions & 8 deletions django_app/tests/test_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def get_chat_message_text(user: User, role: ChatRoleEnum) -> Sequence[str]:
return [m.text for m in ChatMessage.objects.filter(chat_history__users=user, role=role)]


@pytest.mark.xfail()
@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
async def test_chat_consumer_with_selected_files(
Expand All @@ -87,6 +88,7 @@ async def test_chat_consumer_with_selected_files(
mocked_connect_with_several_files: Connect,
):
# Given
selected_files: Sequence[File] = several_files[2:]

# When
with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_several_files):
Expand All @@ -95,11 +97,12 @@ async def test_chat_consumer_with_selected_files(
connected, _ = await communicator.connect()
assert connected

selected_file_core_uuids: Sequence[str] = [str(f.core_file_uuid) for f in selected_files]
await communicator.send_json_to(
{
"message": "Third question, with selected files?",
"sessionId": str(chat_history_with_files.id),
"selectedFiles": [str(f.core_file_uuid) for f in several_files[2:]],
"selectedFiles": selected_file_core_uuids,
}
)
response1 = await communicator.receive_json_from(timeout=5)
Expand All @@ -111,13 +114,41 @@ async def test_chat_consumer_with_selected_files(
# Close
await communicator.disconnect()

assert await get_chat_message_text(alice, ChatRoleEnum.user) == [
"A question?",
"Another question?",
"Third question, with selected files?",
]
assert await get_chat_message_text(alice, ChatRoleEnum.ai) == ["An answer.", "Another answer.", "Third answer."]
# TODO (@brunns): Assert selected files sent to core, and saved to model.
# Then

# TODO (@brunns): Assert selected files sent to core.
# Requires fix for https://github.com/django/channels/issues/1091
mocked_websocket = mocked_connect_with_several_files.return_value.__aenter__.return_value
expected = json.dumps(
{
"message_history": [
{"role": "user", "text": "A question?"},
{"role": "ai", "text": "An answer."},
{"role": "user", "text": "A second question?"},
{"role": "ai", "text": "A second answer."},
{"role": "user", "text": "Third question, with selected files?"},
],
"selected_files": selected_file_core_uuids,
}
)
mocked_websocket.send.assert_called_with(expected)

# TODO (@brunns): Assert selected files saved to model.
# Requires fix for https://github.com/django/channels/issues/1091
all_messages = get_chat_messages(alice)
last_user_message = [m for m in all_messages if m.rule == ChatRoleEnum.user][-1]
assert last_user_message.selected_files == selected_files


@database_sync_to_async
def get_chat_messages(user: User) -> Sequence[ChatMessage]:
return list(
ChatMessage.objects.filter(chat_history__users=user)
.order_by("created_at")
.prefetch_related("chat_history")
.prefetch_related("source_files")
.prefetch_related("selected_files")
)


@pytest.fixture()
Expand Down
Loading
Loading