diff --git a/core_api/src/dependencies.py b/core_api/src/dependencies.py index 2bda0e913..d55107a7a 100644 --- a/core_api/src/dependencies.py +++ b/core_api/src/dependencies.py @@ -96,6 +96,7 @@ def get_parameterised_retriever( index_name=f"{env.elastic_root_index}-chunk", params=default_params, embedding_model=get_embedding_model(env), + embedding_field_name=env.embedding_document_field_name, ).configurable_fields( params=ConfigurableField( id="params", name="Retriever parameters", description="A dictionary of parameters to use for the retriever." diff --git a/core_api/src/retriever/queries.py b/core_api/src/retriever/queries.py index 0e20b779a..4608e6381 100644 --- a/core_api/src/retriever/queries.py +++ b/core_api/src/retriever/queries.py @@ -46,12 +46,14 @@ def get_all(query: ESQuery) -> dict[str, Any]: } ) return { - "_source": {"excludes": ["embedding"]}, + "_source": {"excludes": ["*embedding"]}, "query": {"bool": {"must": {"match_all": {}}, "filter": query_filter}}, } -def get_some(embedding_model: Embeddings, params: ESParams, query: ESQuery) -> dict[str, Any]: +def get_some( + embedding_model: Embeddings, params: ESParams, embedding_field_name: str, query: ESQuery +) -> dict[str, Any]: vector = embedding_model.embed_query(query["question"]) query_filter = [ @@ -92,7 +94,7 @@ def get_some(embedding_model: Embeddings, params: ESParams, query: ESQuery) -> d }, { "knn": { - "field": "embedding", + "field": embedding_field_name, "query_vector": vector, "num_candidates": params["num_candidates"], "filter": query_filter, diff --git a/core_api/src/retriever/retrievers.py b/core_api/src/retriever/retrievers.py index ef8ea4cbe..a3df811cb 100644 --- a/core_api/src/retriever/retrievers.py +++ b/core_api/src/retriever/retrievers.py @@ -26,6 +26,7 @@ def hit_to_doc(hit: dict[str, Any]) -> Document: class ParameterisedElasticsearchRetriever(ElasticsearchRetriever): params: ESParams embedding_model: Embeddings + embedding_field_name: str = "embedding" def __init__(self, **kwargs: Any) -> None: # Hack to pass validation before overwrite @@ -33,7 +34,7 @@ def __init__(self, **kwargs: Any) -> None: kwargs["body_func"] = get_some kwargs["document_mapper"] = hit_to_doc super().__init__(**kwargs) - self.body_func = partial(get_some, self.embedding_model, self.params) + self.body_func = partial(get_some, self.embedding_model, self.params, self.embedding_field_name) class AllElasticsearchRetriever(ElasticsearchRetriever): diff --git a/core_api/tests/conftest.py b/core_api/tests/conftest.py index b3ecf380d..db1be0fd1 100644 --- a/core_api/tests/conftest.py +++ b/core_api/tests/conftest.py @@ -237,6 +237,7 @@ def parameterised_retriever(env, es_client, embedding_model_dim): index_name=f"{env.elastic_root_index}-chunk", params=default_params, embedding_model=FakeEmbeddings(size=embedding_model_dim), + embedding_field_name=env.embedding_document_field_name, ).configurable_fields( params=ConfigurableField( id="params", name="Retriever parameters", description="A dictionary of parameters to use for the retriever."