From 15753927f5eb8f536a3eaf2826959c147c722f84 Mon Sep 17 00:00:00 2001 From: jamesrichards Date: Mon, 1 Jul 2024 13:30:01 +0000 Subject: [PATCH 1/3] Providing embedding field name to retriever --- core_api/src/retriever/queries.py | 6 +++--- core_api/src/retriever/retrievers.py | 3 ++- core_api/tests/conftest.py | 1 + 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core_api/src/retriever/queries.py b/core_api/src/retriever/queries.py index 0e20b779a..9e08b2a0b 100644 --- a/core_api/src/retriever/queries.py +++ b/core_api/src/retriever/queries.py @@ -46,12 +46,12 @@ def get_all(query: ESQuery) -> dict[str, Any]: } ) return { - "_source": {"excludes": ["embedding"]}, + "_source": {"excludes": ["*embedding"]}, "query": {"bool": {"must": {"match_all": {}}, "filter": query_filter}}, } -def get_some(embedding_model: Embeddings, params: ESParams, query: ESQuery) -> dict[str, Any]: +def get_some(embedding_model: Embeddings, params: ESParams, embedding_field_name: str, query: ESQuery) -> dict[str, Any]: vector = embedding_model.embed_query(query["question"]) query_filter = [ @@ -92,7 +92,7 @@ def get_some(embedding_model: Embeddings, params: ESParams, query: ESQuery) -> d }, { "knn": { - "field": "embedding", + "field": embedding_field_name, "query_vector": vector, "num_candidates": params["num_candidates"], "filter": query_filter, diff --git a/core_api/src/retriever/retrievers.py b/core_api/src/retriever/retrievers.py index ef8ea4cbe..a3df811cb 100644 --- a/core_api/src/retriever/retrievers.py +++ b/core_api/src/retriever/retrievers.py @@ -26,6 +26,7 @@ def hit_to_doc(hit: dict[str, Any]) -> Document: class ParameterisedElasticsearchRetriever(ElasticsearchRetriever): params: ESParams embedding_model: Embeddings + embedding_field_name: str = "embedding" def __init__(self, **kwargs: Any) -> None: # Hack to pass validation before overwrite @@ -33,7 +34,7 @@ def __init__(self, **kwargs: Any) -> None: kwargs["body_func"] = get_some kwargs["document_mapper"] = hit_to_doc super().__init__(**kwargs) - self.body_func = partial(get_some, self.embedding_model, self.params) + self.body_func = partial(get_some, self.embedding_model, self.params, self.embedding_field_name) class AllElasticsearchRetriever(ElasticsearchRetriever): diff --git a/core_api/tests/conftest.py b/core_api/tests/conftest.py index b3ecf380d..f66b08ddd 100644 --- a/core_api/tests/conftest.py +++ b/core_api/tests/conftest.py @@ -237,6 +237,7 @@ def parameterised_retriever(env, es_client, embedding_model_dim): index_name=f"{env.elastic_root_index}-chunk", params=default_params, embedding_model=FakeEmbeddings(size=embedding_model_dim), + embedding_field_name=env.embedding_document_field_name ).configurable_fields( params=ConfigurableField( id="params", name="Retriever parameters", description="A dictionary of parameters to use for the retriever." From 0a37c2948ea16962a3bf31c91b21b9362f37d225 Mon Sep 17 00:00:00 2001 From: jamesrichards Date: Mon, 1 Jul 2024 13:31:07 +0000 Subject: [PATCH 2/3] Providing embedding field name to retriever --- core_api/src/dependencies.py | 1 + 1 file changed, 1 insertion(+) diff --git a/core_api/src/dependencies.py b/core_api/src/dependencies.py index 2bda0e913..bf02ece6f 100644 --- a/core_api/src/dependencies.py +++ b/core_api/src/dependencies.py @@ -96,6 +96,7 @@ def get_parameterised_retriever( index_name=f"{env.elastic_root_index}-chunk", params=default_params, embedding_model=get_embedding_model(env), + embedding_field_name=env.embedding_document_field_name ).configurable_fields( params=ConfigurableField( id="params", name="Retriever parameters", description="A dictionary of parameters to use for the retriever." From d3af31caa36ceb786babfef6833a52db7ba8e794 Mon Sep 17 00:00:00 2001 From: jamesrichards Date: Mon, 1 Jul 2024 14:32:21 +0100 Subject: [PATCH 3/3] Ruff --- core_api/src/dependencies.py | 2 +- core_api/src/retriever/queries.py | 4 +++- core_api/tests/conftest.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/core_api/src/dependencies.py b/core_api/src/dependencies.py index bf02ece6f..d55107a7a 100644 --- a/core_api/src/dependencies.py +++ b/core_api/src/dependencies.py @@ -96,7 +96,7 @@ def get_parameterised_retriever( index_name=f"{env.elastic_root_index}-chunk", params=default_params, embedding_model=get_embedding_model(env), - embedding_field_name=env.embedding_document_field_name + embedding_field_name=env.embedding_document_field_name, ).configurable_fields( params=ConfigurableField( id="params", name="Retriever parameters", description="A dictionary of parameters to use for the retriever." diff --git a/core_api/src/retriever/queries.py b/core_api/src/retriever/queries.py index 9e08b2a0b..4608e6381 100644 --- a/core_api/src/retriever/queries.py +++ b/core_api/src/retriever/queries.py @@ -51,7 +51,9 @@ def get_all(query: ESQuery) -> dict[str, Any]: } -def get_some(embedding_model: Embeddings, params: ESParams, embedding_field_name: str, query: ESQuery) -> dict[str, Any]: +def get_some( + embedding_model: Embeddings, params: ESParams, embedding_field_name: str, query: ESQuery +) -> dict[str, Any]: vector = embedding_model.embed_query(query["question"]) query_filter = [ diff --git a/core_api/tests/conftest.py b/core_api/tests/conftest.py index f66b08ddd..db1be0fd1 100644 --- a/core_api/tests/conftest.py +++ b/core_api/tests/conftest.py @@ -237,7 +237,7 @@ def parameterised_retriever(env, es_client, embedding_model_dim): index_name=f"{env.elastic_root_index}-chunk", params=default_params, embedding_model=FakeEmbeddings(size=embedding_model_dim), - embedding_field_name=env.embedding_document_field_name + embedding_field_name=env.embedding_document_field_name, ).configurable_fields( params=ConfigurableField( id="params", name="Retriever parameters", description="A dictionary of parameters to use for the retriever."