
Commit

feat(rag): expose similarity_top_k and similarity_score to settings (#1771)

* Added RAG settings to settings.py, vector_store, and chat_service to expose similarity_top_k and similarity_score

* Updated settings in vector and chat service per Ivan's request

* Updated code for mypy
icsy7867 authored Mar 20, 2024
1 parent 774e256 commit 087cb0b
Showing 3 changed files with 33 additions and 1 deletion.
16 changes: 15 additions & 1 deletion private_gpt/server/chat/chat_service.py
@@ -8,6 +8,9 @@
 from llama_index.core.indices import VectorStoreIndex
 from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
 from llama_index.core.llms import ChatMessage, MessageRole
+from llama_index.core.postprocessor import (
+    SimilarityPostprocessor,
+)
 from llama_index.core.storage import StorageContext
 from llama_index.core.types import TokenGen
 from pydantic import BaseModel
@@ -20,6 +23,7 @@
 )
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.server.chunks.chunks_service import Chunk
+from private_gpt.settings.settings import Settings
 
 
 class Completion(BaseModel):
@@ -68,14 +72,18 @@ def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
 
 @singleton
 class ChatService:
+    settings: Settings
+
     @inject
     def __init__(
         self,
+        settings: Settings,
         llm_component: LLMComponent,
         vector_store_component: VectorStoreComponent,
         embedding_component: EmbeddingComponent,
         node_store_component: NodeStoreComponent,
     ) -> None:
+        self.settings = settings
         self.llm_component = llm_component
         self.embedding_component = embedding_component
         self.vector_store_component = vector_store_component
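A note on wiring: Settings reaches ChatService through constructor injection rather than a global import; the @singleton/@inject decorators come from the injector library the project uses. A stripped-down sketch of that pattern (the Settings class below is a dummy stand-in, not PrivateGPT's real model):

    # Stripped-down sketch of injector-style constructor injection;
    # Settings here is a dummy stand-in for illustration only.
    from injector import Injector, inject, singleton

    class Settings:
        similarity_top_k = 2  # dummy attribute for illustration

    @singleton
    class ChatService:
        @inject
        def __init__(self, settings: Settings) -> None:
            self.settings = settings

    service = Injector().get(ChatService)  # injector auto-binds both classes
    print(service.settings.similarity_top_k)  # 2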
@@ -98,16 +106,22 @@ def _chat_engine(
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
     ) -> BaseChatEngine:
+        settings = self.settings
         if use_context:
             vector_index_retriever = self.vector_store_component.get_retriever(
-                index=self.index, context_filter=context_filter
+                index=self.index,
+                context_filter=context_filter,
+                similarity_top_k=self.settings.rag.similarity_top_k,
             )
             return ContextChatEngine.from_defaults(
                 system_prompt=system_prompt,
                 retriever=vector_index_retriever,
                 llm=self.llm_component.llm,  # Takes no effect at the moment
                 node_postprocessors=[
                     MetadataReplacementPostProcessor(target_metadata_key="window"),
+                    SimilarityPostprocessor(
+                        similarity_cutoff=settings.rag.similarity_value
+                    ),
                 ],
             )
         else:
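For context, the SimilarityPostprocessor wired in above drops retrieved nodes whose similarity score falls below the cutoff before they reach the chat engine's context. A minimal standalone sketch of that filtering (node texts and scores are made up for illustration):

    # Minimal sketch of the cutoff behavior; node texts and scores are illustrative.
    from llama_index.core.postprocessor import SimilarityPostprocessor
    from llama_index.core.schema import NodeWithScore, TextNode

    nodes = [
        NodeWithScore(node=TextNode(text="strong match"), score=0.72),
        NodeWithScore(node=TextNode(text="weak match"), score=0.31),
    ]

    # With a cutoff of 0.45 (the value suggested in settings.yaml below),
    # only the first node survives.
    processor = SimilarityPostprocessor(similarity_cutoff=0.45)
    filtered = processor.postprocess_nodes(nodes)
    print([n.get_content() for n in filtered])  # ['strong match']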
12 changes: 12 additions & 0 deletions private_gpt/settings/settings.py
@@ -284,6 +284,17 @@ class UISettings(BaseModel):
     )
 
 
+class RagSettings(BaseModel):
+    similarity_top_k: int = Field(
+        2,
+        description="This value controls the number of documents returned by the RAG pipeline",
+    )
+    similarity_value: float = Field(
+        None,
+        description="If set, any documents retrieved from the RAG must meet a certain match score. Acceptable values are between 0 and 1.",
+    )
+
+
 class PostgresSettings(BaseModel):
     host: str = Field(
         "localhost",
@@ -379,6 +390,7 @@ class Settings(BaseModel):
     azopenai: AzureOpenAISettings
     vectorstore: VectorstoreSettings
     nodestore: NodeStoreSettings
+    rag: RagSettings
     qdrant: QdrantSettings | None = None
     postgres: PostgresSettings | None = None
 
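Taken together, the two new fields bound what reaches the context window: similarity_top_k caps how many chunks the retriever returns, and similarity_value then discards any of those chunks scoring below the cutoff. A rough sketch of the combined effect, with hypothetical chunk names and scores:

    # Rough sketch of how the two settings interact; chunks and scores are hypothetical.
    scored_chunks = [("a", 0.81), ("b", 0.52), ("c", 0.40), ("d", 0.12)]

    similarity_top_k = 2     # retriever keeps at most this many chunks
    similarity_value = 0.45  # postprocessor then drops anything scored below this

    top_k = sorted(scored_chunks, key=lambda c: c[1], reverse=True)[:similarity_top_k]
    kept = [chunk for chunk in top_k if chunk[1] >= similarity_value]
    print(kept)  # [('a', 0.81), ('b', 0.52)]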
6 changes: 6 additions & 0 deletions settings.yaml
@@ -42,6 +42,12 @@ llm:
   tokenizer: mistralai/Mistral-7B-Instruct-v0.2
   temperature: 0.1 # The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual. (Default: 0.1)
 
+rag:
+  similarity_top_k: 2
+  #This value controls how many "top" documents the RAG returns to use in the context.
+  #similarity_value: 0.45
+  #This value is disabled by default. If you enable this setting, the RAG will only use articles that meet a certain percentage score.
+
 llamacpp:
   prompt_style: "mistral"
   llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.2-GGUF
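To sanity-check an edited rag block, it can be parsed straight into the new RagSettings model; a small sketch (this bypasses PrivateGPT's real settings loader, which does more than a plain yaml load):

    # Illustrative check only; PrivateGPT's actual settings loading is more involved.
    import yaml

    from private_gpt.settings.settings import RagSettings

    raw = yaml.safe_load("rag:\n  similarity_top_k: 2\n  similarity_value: 0.45\n")
    rag = RagSettings(**raw["rag"])
    print(rag.similarity_top_k, rag.similarity_value)  # 2 0.45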
