diff --git a/django_app/redbox_app/redbox_core/client.py b/django_app/redbox_app/redbox_core/client.py
index d5cd28ddf..6933ace7f 100644
--- a/django_app/redbox_app/redbox_core/client.py
+++ b/django_app/redbox_app/redbox_core/client.py
@@ -7,9 +7,10 @@
 from botocore.exceptions import ClientError
 from dataclasses_json import Undefined, dataclass_json
 from django.conf import settings
+from django.forms.models import model_to_dict
 from yarl import URL
 
-from redbox_app.redbox_core.models import User
+from redbox_app.redbox_core.models import AISettings, User
 
 logger = logging.getLogger(__name__)
 
@@ -85,12 +86,22 @@ def upload_file(self, name: str, user: User) -> FileOperation:
         response.raise_for_status()
         return FileOperation.schema().loads(response.content)
 
+    def get_ai_settings(self, user: User) -> AISettings:
+        return model_to_dict(
+            user.ai_settings,
+            fields=[field.name for field in user.ai_settings._meta.fields if field.name != "label"],  # noqa: SLF001
+        )
+
     def rag_chat(
         self, message_history: list[dict[str, str]], selected_files: list[dict[str, str]], user: User
     ) -> CoreChatResponse:
         response = requests.post(
             self.url / "chat/rag",
-            json={"message_history": message_history, "selected_files": selected_files},
+            json={
+                "message_history": message_history,
+                "selected_files": selected_files,
+                "ai_settings": self.get_ai_settings(user),
+            },
             headers={"Authorization": user.get_bearer_token()},
             timeout=60,
         )
diff --git a/django_app/redbox_app/redbox_core/consumers.py b/django_app/redbox_app/redbox_core/consumers.py
index 5b77e0a12..679d65d5e 100644
--- a/django_app/redbox_app/redbox_core/consumers.py
+++ b/django_app/redbox_app/redbox_core/consumers.py
@@ -10,13 +10,14 @@
 from channels.generic.websocket import AsyncWebsocketConsumer
 from dataclasses_json import Undefined, dataclass_json
 from django.conf import settings
+from django.forms.models import model_to_dict
 from django.utils import timezone
 from websockets import ConnectionClosedError, WebSocketClientProtocol
 from websockets.client import connect
 from yarl import URL
 
 from redbox_app.redbox_core import error_messages
-from redbox_app.redbox_core.models import ChatHistory, ChatMessage, ChatRoleEnum, Citation, File, User
+from redbox_app.redbox_core.models import AISettings, ChatHistory, ChatMessage, ChatRoleEnum, Citation, File, User
 
 OptFileSeq = Sequence[File] | None
 logger = logging.getLogger(__name__)
@@ -79,6 +80,7 @@ async def llm_conversation(
         message = {
             "message_history": message_history,
             "selected_files": [{"uuid": f.core_file_uuid} for f in selected_files],
+            "ai_settings": await self.get_ai_settings(user),
         }
         await self.send_to_server(core_websocket, message)
         await self.send_to_client("session-id", session.id)
@@ -219,6 +221,14 @@ def get_sources_with_files(
     def file_save(file):
         return file.save()
 
+    @staticmethod
+    @database_sync_to_async
+    def get_ai_settings(user: User) -> AISettings:
+        return model_to_dict(
+            user.ai_settings,
+            fields=[field.name for field in user.ai_settings._meta.fields if field.name != "label"],  # noqa: SLF001
+        )
+
 
 class CoreError(Exception):
     message: str
diff --git a/django_app/redbox_app/redbox_core/migrations/0028_aisettings.py b/django_app/redbox_app/redbox_core/migrations/0028_aisettings.py
new file mode 100644
index 000000000..c047c187b
--- /dev/null
+++ b/django_app/redbox_app/redbox_core/migrations/0028_aisettings.py
@@ -0,0 +1,131 @@
+# Generated by Django 5.0.7 on 2024-08-07 12:33
+
+import uuid
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+def create_default_ai_settings(apps, schema_editor):
+    AISettings = apps.get_model("redbox_core", "AISettings")
+    AISettings.objects.create(label="default")
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("redbox_core", "0027_alter_file_status"),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name="AISettings",
+            fields=[
+                (
+                    "id",
+                    models.UUIDField(
+                        default=uuid.uuid4,
+                        editable=False,
+                        primary_key=True,
+                        serialize=False,
+                    ),
+                ),
+                ("created_at", models.DateTimeField(auto_now_add=True)),
+                ("modified_at", models.DateTimeField(auto_now=True)),
+                ("label", models.CharField(max_length=50, unique=True)),
+                ("context_window_size", models.PositiveIntegerField(default=8000)),
+                ("rag_k", models.PositiveIntegerField(default=30)),
+                ("rag_num_candidates", models.PositiveIntegerField(default=10)),
+                ("rag_desired_chunk_size", models.PositiveIntegerField(default=300)),
+                ("elbow_filter_enabled", models.BooleanField(default=False)),
+                (
+                    "chat_system_prompt",
+                    models.TextField(
+                        default="You are an AI assistant called Redbox tasked with answering questions and providing information objectively."
+                    ),
+                ),
+                (
+                    "chat_question_prompt",
+                    models.TextField(default="{question}\n=========\n Response: "),
+                ),
+                ("stuff_chunk_context_ratio", models.FloatField(default=0.75)),
+                (
+                    "chat_with_docs_system_prompt",
+                    models.TextField(
+                        default="You are an AI assistant called Redbox tasked with answering questions on user provided documents and providing information objectively."
+                    ),
+                ),
+                (
+                    "chat_with_docs_question_prompt",
+                    models.TextField(
+                        default="Question: {question}. \n\n Documents: \n\n {formatted_documents} \n\n Answer: "
+                    ),
+                ),
+                (
+                    "chat_with_docs_reduce_system_prompt",
+                    models.TextField(
+                        default="You are an AI assistant tasked with answering questions on user provided documents. Your goal is to answer the user question based on a list of summaries in a coherent manner. Please follow these guidelines while answering the question: \n1) Identify and highlight key points,\n2) Avoid repetition,\n3) Ensure the answer is easy to understand,\n4) Maintain the original context and meaning.\n"
+                    ),
+                ),
+                (
+                    "retrieval_system_prompt",
+                    models.TextField(
+                        default="Given the following conversation and extracted parts of a long document and a question, create a final answer. \nIf you don't know the answer, just say that you don't know. Don't try to make up an answer. If a user asks for a particular format to be returned, such as bullet points, then please use that format. If a user asks for bullet points you MUST give bullet points. If the user asks for a specific number or range of bullet points you MUST give that number of bullet points. \nUse **bold** to highlight the most question relevant parts in your response. If dealing with lots of data return it in markdown table format. "
+                    ),
+                ),
+                (
+                    "retrieval_question_prompt",
+                    models.TextField(
+                        default="{question} \n=========\n{formatted_documents}\n=========\nFINAL ANSWER: "
+                    ),
+                ),
+                (
+                    "condense_system_prompt",
+                    models.TextField(
+                        default="Given the following conversation and a follow up question, generate a follow up question to be a standalone question. You are only allowed to generate one question in response. Include sources from the chat history in the standalone question created, when they are available. If you don't know the answer, just say that you don't know, don't try to make up an answer. \n"
+                    ),
+                ),
+                (
+                    "condense_question_prompt",
+                    models.TextField(
+                        default="{question}\n=========\n Standalone question: "
+                    ),
+                ),
+                ("map_max_concurrency", models.PositiveIntegerField(default=128)),
+                (
+                    "chat_map_system_prompt",
+                    models.TextField(
+                        default="You are an AI assistant tasked with summarizing documents. Your goal is to extract the most important information and present it in a concise and coherent manner. Please follow these guidelines while summarizing: \n1) Identify and highlight key points,\n2) Avoid repetition,\n3) Ensure the summary is easy to understand,\n4) Maintain the original context and meaning.\n"
+                    ),
+                ),
+                (
+                    "chat_map_question_prompt",
+                    models.TextField(
+                        default="Question: {question}. \n Documents: \n {formatted_documents} \n\n Answer: "
+                    ),
+                ),
+                (
+                    "reduce_system_prompt",
+                    models.TextField(
+                        default="You are an AI assistant tasked with summarizing documents. Your goal is to write a concise summary of a list of summaries in a concise and coherent manner. Please follow these guidelines while summarizing: \n1) Identify and highlight key points,\n2) Avoid repetition,\n3) Ensure the summary is easy to understand,\n4) Maintain the original context and meaning.\n"
+                    ),
+                ),
+                ("llm_max_tokens", models.PositiveIntegerField(default=1024)),
+                ("match_boost", models.PositiveIntegerField(default=1)),
+                ("knn_boost", models.PositiveIntegerField(default=1)),
+                ("similarity_threshold", models.PositiveIntegerField(default=0)),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+        migrations.RunPython(create_default_ai_settings, migrations.RunPython.noop),
+        migrations.AddField(
+            model_name="user",
+            name="ai_settings",
+            field=models.ForeignKey(
+                default="default",
+                on_delete=django.db.models.deletion.SET_DEFAULT,
+                to="redbox_core.aisettings",
+                to_field="label",
+            ),
+        ),
+    ]
diff --git a/django_app/redbox_app/redbox_core/models.py b/django_app/redbox_app/redbox_core/models.py
index 69a24d0d1..daba69bf7 100644
--- a/django_app/redbox_app/redbox_core/models.py
+++ b/django_app/redbox_app/redbox_core/models.py
@@ -17,6 +17,7 @@
 from jose import jwt
 from yarl import URL
 
+from redbox_app.redbox_core import prompts
 from redbox_app.redbox_core.utils import get_date_group
 
 logger = logging.getLogger(__name__)
@@ -44,6 +45,36 @@ def sanitise_string(string: str | None) -> str | None:
     return string.replace("\x00", "\ufffd") if string else string
 
 
+class AISettings(UUIDPrimaryKeyBase, TimeStampedModel):
+    label = models.CharField(max_length=50, unique=True)
+    context_window_size = models.PositiveIntegerField(default=8_000)
+    rag_k = models.PositiveIntegerField(default=30)
+    rag_num_candidates = models.PositiveIntegerField(default=10)
+    rag_desired_chunk_size = models.PositiveIntegerField(default=300)
+    elbow_filter_enabled = models.BooleanField(default=False)
+    chat_system_prompt = models.TextField(default=prompts.CHAT_SYSTEM_PROMPT)
+    chat_question_prompt = models.TextField(default=prompts.CHAT_QUESTION_PROMPT)
+    stuff_chunk_context_ratio = models.FloatField(default=0.75)
+    chat_with_docs_system_prompt = models.TextField(default=prompts.CHAT_WITH_DOCS_SYSTEM_PROMPT)
+    chat_with_docs_question_prompt = models.TextField(default=prompts.CHAT_WITH_DOCS_QUESTION_PROMPT)
+    chat_with_docs_reduce_system_prompt = models.TextField(default=prompts.CHAT_WITH_DOCS_REDUCE_SYSTEM_PROMPT)
+    retrieval_system_prompt = models.TextField(default=prompts.RETRIEVAL_SYSTEM_PROMPT)
+    retrieval_question_prompt = models.TextField(default=prompts.RETRIEVAL_QUESTION_PROMPT)
+    condense_system_prompt = models.TextField(default=prompts.CONDENSE_SYSTEM_PROMPT)
+    condense_question_prompt = models.TextField(default=prompts.CONDENSE_QUESTION_PROMPT)
+    map_max_concurrency = models.PositiveIntegerField(default=128)
+    chat_map_system_prompt = models.TextField(default=prompts.CHAT_MAP_SYSTEM_PROMPT)
+    chat_map_question_prompt = models.TextField(default=prompts.CHAT_MAP_QUESTION_PROMPT)
+    reduce_system_prompt = models.TextField(default=prompts.REDUCE_SYSTEM_PROMPT)
+    llm_max_tokens = models.PositiveIntegerField(default=1024)
+    match_boost = models.PositiveIntegerField(default=1)
+    knn_boost = models.PositiveIntegerField(default=1)
+    similarity_threshold = models.PositiveIntegerField(default=0)
+
+    def __str__(self) -> str:
+        return str(self.label)
+
+
 class BusinessUnit(UUIDPrimaryKeyBase):
     name = models.TextField(max_length=64, null=False, blank=False, unique=True)
 
@@ -129,6 +160,7 @@ class AIExperienceLevel(models.TextChoices):
     name = models.CharField(null=True, blank=True)
     ai_experience = models.CharField(null=True, blank=True, max_length=25, choices=AIExperienceLevel)
     profession = models.CharField(null=True, blank=True, max_length=4, choices=Profession)
+    ai_settings = models.ForeignKey(AISettings, on_delete=models.SET_DEFAULT, default="default", to_field="label")
     objects = BaseUserManager()
 
     def __str__(self) -> str:  # pragma: no cover
diff --git a/django_app/redbox_app/redbox_core/prompts.py b/django_app/redbox_app/redbox_core/prompts.py
new file mode 100644
index 000000000..353f1ed17
--- /dev/null
+++ b/django_app/redbox_app/redbox_core/prompts.py
@@ -0,0 +1,68 @@
+CHAT_SYSTEM_PROMPT = (
+    "You are an AI assistant called Redbox tasked with answering questions and providing information objectively."
+)
+
+CHAT_WITH_DOCS_SYSTEM_PROMPT = (
+    "You are an AI assistant called Redbox tasked with answering questions on user provided documents and "
+    "providing information objectively."
+)
+
+CHAT_WITH_DOCS_REDUCE_SYSTEM_PROMPT = (
+    "You are an AI assistant tasked with answering questions on user provided documents. "
+    "Your goal is to answer the user question based on a list of summaries in a coherent manner. "
+    "Please follow these guidelines while answering the question: \n"
+    "1) Identify and highlight key points,\n"
+    "2) Avoid repetition,\n"
+    "3) Ensure the answer is easy to understand,\n"
+    "4) Maintain the original context and meaning.\n"
+)
+
+RETRIEVAL_SYSTEM_PROMPT = (
+    "Given the following conversation and extracted parts of a long document and a question, create a final answer. \n"
+    "If you don't know the answer, just say that you don't know. Don't try to make up an answer. "
+    "If a user asks for a particular format to be returned, such as bullet points, then please use that format. "
+    "If a user asks for bullet points you MUST give bullet points. "
+    "If the user asks for a specific number or range of bullet points you MUST give that number of bullet points. \n"
+    "Use **bold** to highlight the most question relevant parts in your response. "
+    "If dealing with lots of data return it in markdown table format. "
+)
+
+CHAT_MAP_SYSTEM_PROMPT = (
+    "You are an AI assistant tasked with summarizing documents. "
+    "Your goal is to extract the most important information and present it in "
+    "a concise and coherent manner. Please follow these guidelines while summarizing: \n"
+    "1) Identify and highlight key points,\n"
+    "2) Avoid repetition,\n"
+    "3) Ensure the summary is easy to understand,\n"
+    "4) Maintain the original context and meaning.\n"
+)
+
+REDUCE_SYSTEM_PROMPT = (
+    "You are an AI assistant tasked with summarizing documents. "
+    "Your goal is to write a concise summary of a list of summaries in "
+    "a concise and coherent manner. Please follow these guidelines while summarizing: \n"
+    "1) Identify and highlight key points,\n"
+    "2) Avoid repetition,\n"
+    "3) Ensure the summary is easy to understand,\n"
+    "4) Maintain the original context and meaning.\n"
+)
+
+CONDENSE_SYSTEM_PROMPT = (
+    "Given the following conversation and a follow up question, generate a follow "
+    "up question to be a standalone question. "
+    "You are only allowed to generate one question in response. "
+    "Include sources from the chat history in the standalone question created, "
+    "when they are available. "
+    "If you don't know the answer, just say that you don't know, "
+    "don't try to make up an answer. \n"
+)
+
+CHAT_QUESTION_PROMPT = "{question}\n=========\n Response: "
+
+CHAT_WITH_DOCS_QUESTION_PROMPT = "Question: {question}. \n\n Documents: \n\n {formatted_documents} \n\n Answer: "
+
+RETRIEVAL_QUESTION_PROMPT = "{question} \n=========\n{formatted_documents}\n=========\nFINAL ANSWER: "
+
+CHAT_MAP_QUESTION_PROMPT = "Question: {question}. \n Documents: \n {formatted_documents} \n\n Answer: "
+
+CONDENSE_QUESTION_PROMPT = "{question}\n=========\n Standalone question: "
diff --git a/django_app/tests/conftest.py b/django_app/tests/conftest.py
index d0328ddf7..bf19c5da6 100644
--- a/django_app/tests/conftest.py
+++ b/django_app/tests/conftest.py
@@ -12,6 +12,7 @@
 
 from redbox_app.redbox_core import client
 from redbox_app.redbox_core.models import (
+    AISettings,
     BusinessUnit,
     ChatHistory,
     ChatMessage,
@@ -34,6 +35,8 @@ def _collect_static():
 
 @pytest.fixture()
 def create_user():
+    AISettings.objects.get_or_create(label="default")
+
    def _create_user(email, date_joined_iso, is_staff=False):
        date_joined = datetime.fromisoformat(date_joined_iso).astimezone(UTC)
        return User.objects.create_user(email=email, date_joined=date_joined, is_staff=is_staff)
diff --git a/django_app/tests/test_consumers.py b/django_app/tests/test_consumers.py
index 2091b6e49..d1c18496e 100644
--- a/django_app/tests/test_consumers.py
+++ b/django_app/tests/test_consumers.py
@@ -15,6 +15,7 @@
 from redbox_app.redbox_core import error_messages
 from redbox_app.redbox_core.consumers import ChatConsumer
 from redbox_app.redbox_core.models import ChatHistory, ChatMessage, ChatRoleEnum, File, User
+from redbox_app.redbox_core.prompts import CHAT_MAP_QUESTION_PROMPT
 
 logger = logging.getLogger(__name__)
 
@@ -257,6 +258,8 @@ async def test_chat_consumer_with_selected_files(
 
     # TODO (@brunns): Assert selected files sent to core.
     # Requires fix for https://github.com/django/channels/issues/1091
+    # Fix now merged in https://github.com/django/channels/pull/2101, but not yet released
+    # Retry this when a version of Channels after 4.1.0 is released
     mocked_websocket = mocked_connect_with_several_files.return_value.__aenter__.return_value
     expected = json.dumps(
         {
@@ -268,6 +271,7 @@ async def test_chat_consumer_with_selected_files(
                 {"role": "user", "text": "Third question, with selected files?"},
             ],
             "selected_files": selected_file_core_uuids,
+            "ai_settings": await ChatConsumer.get_ai_settings(alice),
         }
     )
     mocked_websocket.send.assert_called_with(expected)
@@ -354,6 +358,27 @@ async def test_chat_consumer_with_explicit_no_document_selected_error(
     await communicator.disconnect()
 
 
+@pytest.mark.django_db()
+@pytest.mark.asyncio()
+async def test_chat_consumer_get_ai_settings(
+    alice: User, mocked_connect_with_explicit_no_document_selected_error: Connect
+):
+    with patch("redbox_app.redbox_core.consumers.connect", new=mocked_connect_with_explicit_no_document_selected_error):
+        communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
+        communicator.scope["user"] = alice
+        connected, _ = await communicator.connect()
+        assert connected
+
+        ai_settings = await ChatConsumer.get_ai_settings(alice)
+
+        assert ai_settings["chat_map_question_prompt"] == CHAT_MAP_QUESTION_PROMPT
+        with pytest.raises(KeyError):
+            ai_settings["label"]
+
+        # Close
+        await communicator.disconnect()
+
+
 @database_sync_to_async
 def get_chat_messages(user: User) -> Sequence[ChatMessage]:
     return list(
diff --git a/django_app/tests/test_migrations.py b/django_app/tests/test_migrations.py
index 699ded1a7..cbdad40b7 100644
--- a/django_app/tests/test_migrations.py
+++ b/django_app/tests/test_migrations.py
@@ -128,3 +128,19 @@ def test_0027_alter_file_status(migrator):
 
     # Cleanup:
     migrator.reset()
+
+
+@pytest.mark.django_db()
+def test_0028_aisettings(migrator):
+    old_state = migrator.apply_initial_migration(("redbox_core", "0027_alter_file_status"))
+
+    User = old_state.apps.get_model("redbox_core", "User")
+    User.objects.create(email="someone@example.com")
+
+    new_state = migrator.apply_tested_migration(
+        ("redbox_core", "0028_aisettings"),
+    )
+    NewUser = new_state.apps.get_model("redbox_core", "User")  # noqa: N806
+
+    for user in NewUser.objects.all():
+        assert user.ai_settings.label == "default"
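
Reviewer note (not part of the diff): a minimal sketch of how a second settings profile could be created and assigned once `0028_aisettings` has been applied. It assumes only the `AISettings` and `User` models defined above; the `experimental` label, the field overrides, and the email address are illustrative values, not anything shipped in this change.

```python
# Hypothetical usage sketch (e.g. in a Django shell or a data migration),
# assuming the AISettings and User models exactly as introduced in this diff.
from redbox_app.redbox_core.models import AISettings, User

# Create (or fetch) a second settings profile alongside the "default" row that
# migration 0028 inserts. The field overrides here are purely illustrative.
experimental, _ = AISettings.objects.get_or_create(
    label="experimental",
    defaults={"rag_k": 50, "elbow_filter_enabled": True},
)

# Point a user at it. The ForeignKey is keyed on the unique `label` column
# (to_field="label"); if this row is ever deleted, affected users fall back
# to "default" via on_delete=SET_DEFAULT.
user = User.objects.get(email="someone@example.com")  # illustrative address
user.ai_settings = experimental
user.save()

# Both the HTTP client and the websocket consumer will now serialise this row
# (minus its label) into the "ai_settings" payload sent to core.
```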