From 1c13e434b7e44a34dd4a4d2c4ee66095ee6f8b37 Mon Sep 17 00:00:00 2001 From: Cheney Zhang Date: Fri, 10 Jan 2025 18:11:17 +0800 Subject: [PATCH] feat: support bm25 milvus function (#33) This PR introduced some major refactors: - Introduce the abstract class `BaseMilvusBuiltInFunction`, which is a light wrapper of [Milvus Function](https://milvus.io/docs/manage-collections.md#Function). - Introduce `Bm25BuiltInFunction` extended from `BaseMilvusBuiltInFunction` , which includes the Milvus `FunctionType.BM25` settings and the configs of Milvus analyzer. We can use this `Bm25BuiltInFunction` to implement [Full text search](https://milvus.io/docs/full-text-search.md) in Milvus - In the future, Milvus will support more built-in Functions which support text-in(instead of vector-in) abilities, without transporting text to embedding on the user's end because it does this on the server's end automatically (here is a `FunctionType.TEXTEMBEDDING` [example](https://github.com/milvus-io/pymilvus/blob/master/examples/text_embedding.py)). So in the future we can implement more subclass from `BaseMilvusBuiltInFunction` to support the text-in functions in Milvus. - The how-to-use introduction is on the way, and there are some use case examples in the unittest `test_builtin_bm25_function()`. Simply speaking, we can pass in any customized Langchain embedding functions or milvus built-in functions to the Milvus class initialization function to build multi index fields in Milvus. Some use case examples will be like these: ```python from langchain_milvus import Milvus, BM25BuiltInFunction from langchain_openai import OpenAIEmbeddings embedding = OpenAIEmbeddings() vectorstore = Milvus.from_documents( documents=docs, embedding=embedding, builtin_function=BM25BuiltInFunction( output_field_names="sparse" ), #"dense" field is used for similarity search for OpenAI dense embedding, "sparse" field is used for BM25 full-text search vector_field=["dense", "sparse"], connection_args={ "uri": URI, }, drop_old=True, ) ``` or with multi embedding fields and bm25 function: ```python from langchain_voyageai import VoyageAIEmbeddings embedding = OpenAIEmbeddings() embedding2 = VoyageAIEmbeddings(model="voyage-3") vectorstore = Milvus.from_documents( documents=docs, embedding=[embedding, embedding2], builtin_function=BM25BuiltInFunction( input_field_names="text", output_field_names="sparse" ), text_field="text", vector_field=["dense", "dense2", "sparse"], connection_args={ "uri": URI, }, drop_old=True, ) ``` --------- Signed-off-by: ChengZi --- libs/milvus/langchain_milvus/__init__.py | 6 + libs/milvus/langchain_milvus/function.py | 74 ++ .../milvus/langchain_milvus/utils/constant.py | 4 + .../langchain_milvus/vectorstores/milvus.py | 915 +++++++++++++----- .../langchain_milvus/vectorstores/zilliz.py | 78 +- libs/milvus/poetry.lock | 55 +- libs/milvus/pyproject.toml | 2 +- .../vectorstores/test_milvus.py | 131 ++- libs/milvus/tests/unit_tests/test_imports.py | 2 + 9 files changed, 883 insertions(+), 384 deletions(-) create mode 100644 libs/milvus/langchain_milvus/function.py create mode 100644 libs/milvus/langchain_milvus/utils/constant.py diff --git a/libs/milvus/langchain_milvus/__init__.py b/libs/milvus/langchain_milvus/__init__.py index b19bc1d..1bc46b8 100644 --- a/libs/milvus/langchain_milvus/__init__.py +++ b/libs/milvus/langchain_milvus/__init__.py @@ -1,3 +1,7 @@ +from langchain_milvus.function import ( + BaseMilvusBuiltInFunction, + BM25BuiltInFunction, +) from langchain_milvus.retrievers import ( MilvusCollectionHybridSearchRetriever, ZillizCloudPipelineRetriever, @@ -9,4 +13,6 @@ "Zilliz", "ZillizCloudPipelineRetriever", "MilvusCollectionHybridSearchRetriever", + "BaseMilvusBuiltInFunction", + "BM25BuiltInFunction", ] diff --git a/libs/milvus/langchain_milvus/function.py b/libs/milvus/langchain_milvus/function.py new file mode 100644 index 0000000..4b47682 --- /dev/null +++ b/libs/milvus/langchain_milvus/function.py @@ -0,0 +1,74 @@ +import uuid +from abc import ABC +from typing import Any, Dict, List, Optional, Union + +from pymilvus import Function, FunctionType + +from langchain_milvus.utils.constant import SPARSE_VECTOR_FIELD, TEXT_FIELD + + +class BaseMilvusBuiltInFunction(ABC): + """ + Base class for Milvus built-in functions. + + See: + https://milvus.io/docs/manage-collections.md#Function + """ + + def __init__(self) -> None: + self._function: Optional[Function] = None + + @property + def function(self) -> Function: + return self._function + + @property + def input_field_names(self) -> Union[str, List[str]]: + return self.function.input_field_names + + @property + def output_field_names(self) -> Union[str, List[str]]: + return self.function.output_field_names + + @property + def type(self) -> FunctionType: + return self.function.type + + +class BM25BuiltInFunction(BaseMilvusBuiltInFunction): + """ + Milvus BM25 built-in function. + + See: + https://milvus.io/docs/full-text-search.md + """ + + def __init__( + self, + *, + input_field_names: str = TEXT_FIELD, + output_field_names: str = SPARSE_VECTOR_FIELD, + analyzer_params: Optional[Dict[Any, Any]] = None, + enable_match: bool = False, + function_name: Optional[str] = None, + ): + super().__init__() + if not function_name: + function_name = f"bm25_function_{str(uuid.uuid4())[:8]}" + self._function = Function( + name=function_name, + input_field_names=input_field_names, + output_field_names=output_field_names, + function_type=FunctionType.BM25, + ) + self.analyzer_params: Optional[Dict[Any, Any]] = analyzer_params + self.enable_match = enable_match + + def get_input_field_schema_kwargs(self) -> dict: + field_schema_kwargs: Dict[Any, Any] = { + "enable_analyzer": True, + "enable_match": self.enable_match, + } + if self.analyzer_params is not None: + field_schema_kwargs["analyzer_params"] = self.analyzer_params + return field_schema_kwargs diff --git a/libs/milvus/langchain_milvus/utils/constant.py b/libs/milvus/langchain_milvus/utils/constant.py new file mode 100644 index 0000000..d6985a8 --- /dev/null +++ b/libs/milvus/langchain_milvus/utils/constant.py @@ -0,0 +1,4 @@ +VECTOR_FIELD = "vector" +SPARSE_VECTOR_FIELD = "sparse" +TEXT_FIELD = "text" +PRIMARY_FIELD = "pk" diff --git a/libs/milvus/langchain_milvus/vectorstores/milvus.py b/libs/milvus/langchain_milvus/vectorstores/milvus.py index bbcef1d..33ab4c4 100644 --- a/libs/milvus/langchain_milvus/vectorstores/milvus.py +++ b/libs/milvus/langchain_milvus/vectorstores/milvus.py @@ -20,20 +20,24 @@ from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore from pymilvus import ( + AnnSearchRequest, Collection, CollectionSchema, DataType, FieldSchema, + FunctionType, MilvusClient, MilvusException, RRFRanker, + SearchResult, WeightedRanker, utility, ) from pymilvus.client.types import LoadState # type: ignore from pymilvus.orm.types import infer_dtype_bydata # type: ignore -from langchain_milvus import MilvusCollectionHybridSearchRetriever +from langchain_milvus.function import BaseMilvusBuiltInFunction, BM25BuiltInFunction +from langchain_milvus.utils.constant import PRIMARY_FIELD, TEXT_FIELD, VECTOR_FIELD from langchain_milvus.utils.sparse import BaseSparseEmbedding logger = logging.getLogger(__name__) @@ -267,7 +271,7 @@ class Milvus(VectorStore): def __init__( self, - embedding_function: Union[EmbeddingType, List[EmbeddingType]], # type: ignore + embedding_function: Optional[Union[EmbeddingType, List[EmbeddingType]]], collection_name: str = "LangChainCollection", collection_description: str = "", collection_properties: Optional[dict[str, Any]] = None, @@ -278,9 +282,9 @@ def __init__( drop_old: Optional[bool] = False, auto_id: bool = False, *, - primary_field: str = "pk", - text_field: str = "text", - vector_field: Union[str, List[str]] = "vector", + primary_field: str = PRIMARY_FIELD, + text_field: str = TEXT_FIELD, + vector_field: Union[str, List[str]] = VECTOR_FIELD, enable_dynamic_field: bool = False, metadata_field: Optional[str] = None, partition_key_field: Optional[str] = None, @@ -290,6 +294,9 @@ def __init__( num_shards: Optional[int] = None, vector_schema: Optional[Union[dict[str, Any], List[dict[str, Any]]]] = None, metadata_schema: Optional[dict[str, Any]] = None, + builtin_function: Optional[ + Union[BaseMilvusBuiltInFunction, List[BaseMilvusBuiltInFunction]] + ] = None, ): """Initialize the Milvus vector store.""" # Default search params when one is not provided. @@ -325,7 +332,17 @@ def __init__( "SPARSE_WAND": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}}, } - self.embedding_func = embedding_function + if not embedding_function and not builtin_function: + raise ValueError( + "Either `embedding_function` or `builtin_function` should be provided." + ) + + self.embedding_func: Optional[ + Union[EmbeddingType, List[EmbeddingType]] + ] = self._from_list(embedding_function) + self.builtin_func: Optional[ + Union[BaseMilvusBuiltInFunction, List[BaseMilvusBuiltInFunction]] + ] = self._from_list(builtin_function) self.collection_name = collection_name self.collection_description = collection_description self.collection_properties = collection_properties @@ -339,24 +356,7 @@ def __init__( # In order for compatibility, the text field will need to be called "text" self._text_field = text_field - if isinstance(self.embedding_func, list): - if len(self.embedding_func) == 1: - self.embedding_func = self.embedding_func[0] - else: - self.embedding_func = cast(List[EmbeddingType], self.embedding_func) - if not isinstance(vector_field, list): - vector_field = [ - f"vector_{i + 1}" for i, e in enumerate(self.embedding_func) - ] - logger.warning( - "When multiple embeddings function are used, one should provide" - "matching `vector_field` names. " - "Using generated vector names %s", - vector_field, - ) - - # In order for compatibility, the vector field needs to be called "vector" - self._vector_field = vector_field + self._check_vector_field(vector_field, vector_schema) if metadata_field: logger.warning( "DeprecationWarning: `metadata_field` is about to be deprecated, " @@ -376,7 +376,6 @@ def __init__( self.timeout = timeout self.num_shards = num_shards self.metadata_schema = metadata_schema - self.vector_schema = vector_schema # Create the connection to the server if connection_args is None: @@ -408,8 +407,155 @@ def __init__( timeout=timeout, ) + def _check_vector_field( + self, + vector_field: Union[str, List[str]], + vector_schema: Optional[Union[dict[str, Any], List[dict[str, Any]]]] = None, + ) -> None: + """ + Check the validity of vector_field and vector_schema, + as well as the relationships with embedding_func and builtin_func. + """ + assert len(self._as_list(vector_field)) == len( + set(self._as_list(vector_field)) + ), "Vector field names should be unique." + + vector_fields_from_function = [] + for builtin_function in self._as_list(self.builtin_func): + vector_fields_from_function.extend( + self._as_list(builtin_function.output_field_names) + ) + # Check there are not overlapping fields + assert len(vector_fields_from_function) == len( + set(vector_fields_from_function) + ), "When using builtin functions, there should be no overlapping fields." + + embedding_fields_expected = [] + for field in self._as_list(vector_field): + if field not in vector_fields_from_function: + embedding_fields_expected.append(field) + + # Number of customized fields <= number of embedding functions + if len(embedding_fields_expected) <= len(self._as_list(self.embedding_func)): # type: ignore[arg-type] + vector_fields_from_embedding = embedding_fields_expected + appending_fields = [] + for i in range( + len(embedding_fields_expected), + len(self._as_list(self.embedding_func)), # type: ignore[arg-type] + ): + appending_fields.append(f"vector_{i + 1}") + vector_fields_from_embedding.extend(appending_fields) + if len(appending_fields) > 0: + logger.warning( + "When multiple embeddings function are used, one should provide " + "matching `vector_field` names. " + "Using generated vector names %s", + appending_fields, + ) + # Number of customized fields > number of embedding functions + else: + raise ValueError( + f"Too many custom fields: {embedding_fields_expected}." + f" They cannot be mapped to a limited number of embedding functions," + f" nor do they belong to any build-in function." + ) + + assert ( + len(set(vector_fields_from_function) & set(vector_fields_from_embedding)) + == 0 + ), ( + "Vector fields from embeddings and vector fields from builtin functions " + "should not overlap." + ) + all_vector_fields = vector_fields_from_embedding + vector_fields_from_function + # For backward compatibility, the vector field needs to be called "vector", + # and it can be either a list or a string. + self._vector_field: Union[str, List[str]] = cast( + Union[str, List[str]], self._from_list(all_vector_fields) + ) + self._vector_fields_from_embedding: List[str] = vector_fields_from_embedding + self._vector_fields_from_function: List[str] = vector_fields_from_function + + # Check vector schema and prepare vector schema map + self.vector_schema = vector_schema + self._vector_schema_map: Dict[str, dict] = {} + if self.vector_schema: + if len(self._as_list(self.vector_schema)) == 1: + assert len(self._as_list(self._vector_field)) == 1, ( + "When only one custom vector_schema is provided, " + "it should keep the vector store has only one vector field." + ) + vector_field_ = cast(str, self._from_list(self._vector_field)) + vector_schema_ = cast(dict, self._from_list(self.vector_schema)) + self._vector_schema_map[vector_field_] = vector_schema_ + else: + if self._is_embedding_only or self._is_function_only: + assert len(self._as_list(self._vector_field)) == len( + self._as_list(self.vector_schema) + ), ( + "You should provide the same number of custom `vector_schema`s " + "as the number of corresponding `vector_field`s." + ) + else: + # If both embedding and builtin functions are provided, + # it must specify vector_schema for each vector field. + assert len(self._as_list(vector_field)) == len( + self._as_list(self.vector_schema) + ), ( + "When multiple custom `vector_schema`s are provided, " + "you should provide the same number of corresponding " + "`vector_field`s." + ) + for field, vector_schema in zip( + self._as_list(vector_field), self._as_list(self.vector_schema) + ): + self._vector_schema_map[field] = vector_schema + else: + self._vector_schema_map = { + field: {} for field in self._as_list(self._vector_field) + } + + # Check index param and prepare index param map + self._index_param_map: Dict[str, dict] = {} + if self.index_params: + if len(self._as_list(self.index_params)) == 1: + assert len(self._as_list(self._vector_field)) == 1, ( + "When only one custom index_params is provided, " + "it should keep the vector store has only one vector field." + ) + vector_field_ = cast(str, self._from_list(self._vector_field)) + index_params_ = cast(dict, self._from_list(self.index_params)) + self._index_param_map[vector_field_] = index_params_ + else: + if self._is_embedding_only or self._is_function_only: + assert len(self._as_list(self._vector_field)) == len( + self._as_list(self.index_params) + ), ( + "You should provide the same number of custom `index_params`s " + "as the number of corresponding `vector_field`s." + ) + else: + # If both embedding and builtin functions are provided, + # it must specify index_params for each vector field. + assert len(self._as_list(vector_field)) == len( + self._as_list(self.index_params) + ), ( + "When multiple custom `index_params`s are provided, " + "you should provide the same number of corresponding " + "`vector_field`s." + ) + for field, index_params in zip( + self._as_list(vector_field), self._as_list(self.index_params) + ): + self._index_param_map[field] = index_params + else: + self._index_param_map = { + field: {} for field in self._as_list(self._vector_field) + } + @property - def embeddings(self) -> Union[EmbeddingType, List[EmbeddingType]]: # type: ignore + def embeddings(self) -> Optional[Union[EmbeddingType, List[EmbeddingType]]]: # type: ignore + """Get embedding function(s).""" return self.embedding_func @property @@ -417,17 +563,58 @@ def client(self) -> MilvusClient: """Get client.""" return self._milvus_client + @property + def vector_fields(self) -> List[str]: + """Get vector field(s).""" + return self._as_list(self._vector_field) + @property def _is_multi_vector(self) -> bool: - return isinstance(self.embedding_func, list) + """Whether the sum of embedding functions and builtin functions is multi.""" + return isinstance(self._vector_field, list) and len(self._vector_field) > 1 + + @property + def _is_multi_embedding(self) -> bool: + """Whether there are multi embedding functions in this instance.""" + return isinstance(self.embedding_func, list) and len(self.embedding_func) > 1 + + @property + def _is_multi_function(self) -> bool: + """Whether there are multi builtin functions in this instance.""" + return isinstance(self.builtin_func, list) and len(self.builtin_func) > 1 + + @property + def _is_embedding_only(self) -> bool: + """Whether there are only embedding function(s) but no builtin function(s).""" + return ( + len(self._as_list(self.embedding_func)) > 0 # type: ignore[arg-type] + and len(self._as_list(self.builtin_func)) == 0 + ) + + @property + def _is_function_only(self) -> bool: + """Whether there are only builtin function(s) but no embedding function(s).""" + return ( + len(self._as_list(self.embedding_func)) == 0 # type: ignore[arg-type] + and len(self._as_list(self.builtin_func)) > 0 + ) @property def _is_sparse(self) -> bool: - embedding_func: List[EmbeddingType] = self._as_list(self.embedding_func) - if self._is_sparse_embedding(embedding_func[0]): - return True - else: - return False + """Detect whether there is only one sparse embedding/builtin function""" + if self._is_embedding_only: + embedding_func = self._as_list(self.embedding_func) # type: ignore[arg-type] + if len(embedding_func) == 1 and self._is_sparse_embedding( + embedding_func[0] # type: ignore[arg-type] + ): + return True + if self._is_function_only: + builtin_func = self._as_list(self.builtin_func) + if len(builtin_func) == 1 and isinstance( + builtin_func[0], BM25BuiltInFunction + ): + return True + return False @staticmethod def _is_sparse_embedding(embeddings_function: EmbeddingType) -> bool: @@ -455,16 +642,55 @@ def _init( def _create_collection( self, embeddings: List[list], metadatas: Optional[list[dict]] = None ) -> None: + metadata_fields = self._prepare_metadata_fields(metadatas) + text_fields = self._prepare_text_fields() + primary_key_fields = self._prepare_primary_key_fields() + vector_fields = self._prepare_vector_fields(embeddings) + + fields = text_fields + primary_key_fields + vector_fields + metadata_fields + + # Create the schema for the collection + schema = CollectionSchema( + fields, + description=self.collection_description, + partition_key_field=self._partition_key_field, + enable_dynamic_field=self.enable_dynamic_field, + functions=[func.function for func in self._as_list(self.builtin_func)], + ) + + # Create the collection + try: + if self.num_shards is not None: + # Issue with defaults: + # https://github.com/milvus-io/pymilvus/blob/59bf5e811ad56e20946559317fed855330758d9c/pymilvus/client/prepare.py#L82-L85 + self.col = Collection( + name=self.collection_name, + schema=schema, + consistency_level=self.consistency_level, + using=self.alias, + num_shards=self.num_shards, + ) + else: + self.col = Collection( + name=self.collection_name, + schema=schema, + consistency_level=self.consistency_level, + using=self.alias, + ) + # Set the collection properties if they exist + if self.collection_properties is not None: + self.col.set_properties(self.collection_properties) + except MilvusException as e: + logger.error( + "Failed to create collection: %s error: %s", self.collection_name, e + ) + raise e + + def _prepare_metadata_fields( + self, metadatas: Optional[list[dict]] = None + ) -> List[FieldSchema]: fields = [] - vector_fields: List[str] = self._as_list(self._vector_field) # If enable_dynamic_field, we don't need to create fields, and just pass it. - # In the future, when metadata_field is deprecated, - # This logical structure will be simplified like this: - # ``` - # if not self.enable_dynamic_field and metadatas: - # for key, value in metadatas[0].items(): - # ... - # ``` if self.enable_dynamic_field: # If both dynamic fields and partition key field are enabled if self._partition_key_field is not None: @@ -480,7 +706,9 @@ def _create_collection( # Determine metadata schema if metadatas: # Create FieldSchema for each entry in metadata. + vector_fields: List[str] = self._as_list(self._vector_field) for key, value in metadatas[0].items(): + # Check if the key is reserved if ( key in [ @@ -522,8 +750,19 @@ def _create_collection( raise ValueError(f"Unrecognized datatype for {key}.") # Datatype is a string/varchar equivalent elif dtype == DataType.VARCHAR: + kwargs = {} + for function in self._as_list(self.builtin_func): + if isinstance(function, BM25BuiltInFunction): + if function.input_field_names == self._text_field: + kwargs = ( + function.get_input_field_schema_kwargs() + ) + break + fields.append( - FieldSchema(key, DataType.VARCHAR, max_length=65_535) + FieldSchema( + key, DataType.VARCHAR, max_length=65_535, **kwargs + ) ) # infer_dtype_bydata currently can't recognize array type, # so this line can not be accessed. @@ -537,12 +776,24 @@ def _create_collection( ) else: fields.append(FieldSchema(key, dtype)) + return fields + + def _prepare_text_fields(self) -> List[FieldSchema]: + fields = [] + kwargs = {} + for function in self._as_list(self.builtin_func): + if isinstance(function, BM25BuiltInFunction): + if self._from_list(function.input_field_names) == self._text_field: + kwargs = function.get_input_field_schema_kwargs() + break - # Create the text field fields.append( - FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535) + FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535, **kwargs) ) - # Create the primary key field + return fields + + def _prepare_primary_key_fields(self) -> List[FieldSchema]: + fields = [] if self.auto_id: fields.append( FieldSchema( @@ -559,18 +810,27 @@ def _create_collection( max_length=65_535, ) ) + return fields + def _prepare_vector_fields(self, embeddings: List[list]) -> List[FieldSchema]: + fields = [] embeddings_functions: List[EmbeddingType] = self._as_list(self.embedding_func) - vector_schemas: List[dict[str, Any]] = ( - self._as_list(self.vector_schema) - if self.vector_schema - else [{} for _ in range(len(embeddings_functions))] + + assert ( + len(self._vector_fields_from_embedding) + == len(embeddings_functions) + == len(embeddings) + ), ( + "The number of `self._vector_fields_from_embedding`, " + "`embeddings_functions`, and `embeddings` should be the same." + f"Got {len(self._vector_fields_from_embedding)}, " + f"{len(embeddings_functions)}, and {len(embeddings)}." ) - for vector_field, vector_schema, embedding_func, vector_field_embeddings in zip( - vector_fields, vector_schemas, embeddings_functions, embeddings + # Loop through the embedding functions + for vector_field, embedding_func, embedding in zip( + self._vector_fields_from_embedding, embeddings_functions, embeddings ): - dim = len(vector_field_embeddings[0]) - # Create the vector field + vector_schema = self._vector_schema_map.get(vector_field, None) if vector_schema and "dtype" in vector_schema: fields.append( self._get_field_schema_from_dict(vector_field, vector_schema) @@ -585,48 +845,31 @@ def _create_collection( fields.append( FieldSchema( vector_field, - infer_dtype_bydata(vector_field_embeddings[0]), - dim=dim, + infer_dtype_bydata(embedding[0]), + dim=len(embedding[0]), ) ) - - # Create the schema for the collection - schema = CollectionSchema( - fields, - description=self.collection_description, - partition_key_field=self._partition_key_field, - enable_dynamic_field=self.enable_dynamic_field, - ) - - # Create the collection - try: - if self.num_shards is not None: - # Issue with defaults: - # https://github.com/milvus-io/pymilvus/blob/59bf5e811ad56e20946559317fed855330758d9c/pymilvus/client/prepare.py#L82-L85 - self.col = Collection( - name=self.collection_name, - schema=schema, - consistency_level=self.consistency_level, - using=self.alias, - num_shards=self.num_shards, - ) + # Loop through the built-in functions + for vector_field, builtin_function in zip( + self._vector_fields_from_function, self._as_list(self.builtin_func) + ): + vector_schema = self._vector_schema_map.get(vector_field, None) + if vector_schema and "dtype" in vector_schema: + field = self._get_field_schema_from_dict(vector_field, vector_schema) + elif isinstance(builtin_function, BM25BuiltInFunction): + field = FieldSchema(vector_field, DataType.SPARSE_FLOAT_VECTOR) else: - self.col = Collection( - name=self.collection_name, - schema=schema, - consistency_level=self.consistency_level, - using=self.alias, + raise ValueError( + "Unsupported embedding function type: " + f"{type(builtin_function)} for field: {vector_field}." ) - # Set the collection properties if they exist - if self.collection_properties is not None: - self.col.set_properties(self.collection_properties) - except MilvusException as e: - logger.error( - "Failed to create collection: %s error: %s", self.collection_name, e - ) - raise e + field.is_function_output = True + fields.append(field) + return fields - def _get_field_schema_from_dict(self, field_name: str, schema_dict: dict): # type: ignore[no-untyped-def] + def _get_field_schema_from_dict( + self, field_name: str, schema_dict: dict + ) -> FieldSchema: assert "dtype" in schema_dict, ( f"Please provide `dtype` in the schema dict. " f"Existing keys are: {schema_dict.keys()}" @@ -654,62 +897,103 @@ def _get_index(self, field_name: Optional[str] = None) -> Optional[dict[str, Any return x.to_dict() return None + def _get_indexes( + self, field_names: Optional[List[str]] = None + ) -> List[dict[str, Any]]: + """Return the list of vector index information""" + index_list = [] + if not field_names: + field_names = self._as_list(self._vector_field) + for field_name in field_names: + index = self._get_index(field_name) + if index is not None: + index_list.append(index) + return index_list + def _create_index(self) -> None: """Create an index on the collection""" - if isinstance(self.col, Collection) and self._get_index() is None: + if isinstance(self.col, Collection): embeddings_functions: List[EmbeddingType] = self._as_list( self.embedding_func ) - vector_fields: List[str] = self._as_list(self._vector_field) - if self.index_params is None: - indexes_params: List[dict] = [ - {} for _ in range(len(embeddings_functions)) - ] - else: - indexes_params = self._as_list(self.index_params) - for i, embeddings_func in enumerate(embeddings_functions): - if not self._get_index(vector_fields[i]): + # For backward compatibility + if type(self) is Milvus: + default_index_params = { + "metric_type": "L2", + "index_type": "HNSW", + "params": {"M": 8, "efConstruction": 64}, + } + else: # Zilliz, which is subclass of Milvus + default_index_params = { + "metric_type": "L2", + "index_type": "AUTOINDEX", + "params": {}, + } + for vector_field, embeddings_func in zip( + self._vector_fields_from_embedding, embeddings_functions + ): + if not self._get_index(vector_field): try: - # If no index params, use a default HNSW based one - if not indexes_params[i]: + if not self._index_param_map.get(vector_field, None): if self._is_sparse_embedding(embeddings_func): - indexes_params[i] = { + index_params = { "metric_type": "IP", "index_type": "SPARSE_INVERTED_INDEX", "params": {"drop_ratio_build": 0.2}, } else: - indexes_params[i] = { - "metric_type": "L2", - "index_type": "HNSW", - "params": {"M": 8, "efConstruction": 64}, + index_params = default_index_params + self._index_param_map[vector_field] = index_params + else: + index_params = self._index_param_map[vector_field] + self.col.create_index( + vector_field, + index_params=index_params, + using=self.alias, + ) + logger.debug( + "Successfully created an index" + "on %s field on collection: %s", + vector_field, + self.collection_name, + ) + except MilvusException as e: + logger.error( + "Failed to create an index on collection: %s", + self.collection_name, + ) + raise e + for vector_field, builtin_function in zip( + self._vector_fields_from_function, self._as_list(self.builtin_func) + ): + if not self._get_index(vector_field): + try: + if not self._index_param_map.get(vector_field, None): + if builtin_function.type == FunctionType.BM25: + index_params = { + "metric_type": "BM25", + "index_type": "AUTOINDEX", + "params": {}, } - - try: - self.col.create_index( - vector_fields[i], - index_params=indexes_params[i], - using=self.alias, - ) - - # If default did not work, most likely on Zilliz Cloud - except MilvusException: - # Use AUTOINDEX based index - index_params = { - "metric_type": "L2", - "index_type": "AUTOINDEX", - "params": {}, - } - self.col.create_index( - vector_fields[i], - index_params=index_params, - using=self.alias, - ) + else: + raise ValueError( + "Unsupported built-in function type: " + f"{builtin_function.type} for field: " + f"{vector_field}." + ) + self._index_param_map[vector_field] = index_params + else: + index_params = self._index_param_map[vector_field] + self.col.create_index( + vector_field, + index_params=index_params, + using=self.alias, + ) logger.debug( "Successfully created an index" "on %s field on collection: %s", - vector_fields[i], + vector_field, self.collection_name, ) except MilvusException as e: @@ -718,10 +1002,10 @@ def _create_index(self) -> None: self.collection_name, ) raise e - if self._is_multi_vector: - self.index_params = indexes_params - else: - self.index_params = indexes_params[0] + index_params_list: List[dict] = [] + for field in self._as_list(self._vector_field): + index_params_list.append(self._index_param_map.get(field, {})) + self.index_params = self._from_list(index_params_list) def _create_search_params(self) -> None: """Generate search params based on the current index type""" @@ -741,10 +1025,7 @@ def _create_search_params(self) -> None: ) search_params["metric_type"] = metric_type search_params_list.append(search_params) - if self._is_multi_vector: - self.search_params = search_params_list - else: - self.search_params = search_params_list[0] + self.search_params = self._from_list(search_params_list) def _load( self, @@ -756,7 +1037,7 @@ def _load( timeout = self.timeout or timeout if ( isinstance(self.col, Collection) - and self._get_index() is not None + and self._get_indexes() and utility.load_state(self.collection_name, using=self.alias) == LoadState.NotLoad ): @@ -833,7 +1114,7 @@ def add_texts( embeddings.append(embedding_func.embed_documents(texts)) except NotImplementedError: embeddings.append([embedding_func.embed_query(x) for x in texts]) - + # Currently, it is field-wise # assuming [f1, f2] embeddings functions and [a, b, c] as texts: # embeddings = [ # [f1(a), f1(b), f1(c)], @@ -844,11 +1125,12 @@ def add_texts( # [f1(a), f1(b), f1(c)] # ] - if len(embeddings) == 0: + if len(texts) == 0: logger.debug("Nothing to insert, skipping.") return [] - if self._is_multi_vector: + # Transpose it into row-wise + if self._is_multi_embedding: # transposed_embeddings = [ # [f1(a), f2(a)], # [f1(b), f2(b)], @@ -864,7 +1146,7 @@ def add_texts( # f1(b), # f1(c) # ] - transposed_embeddings = embeddings[0] + transposed_embeddings = embeddings[0] if len(embeddings) > 0 else [] return self.add_embeddings( texts=texts, @@ -898,7 +1180,7 @@ def add_embeddings( Args: texts (List[str]): the texts to insert - embeddings (List[List[Union[float, List[float]]]]): + embeddings (List[List[float]] | List[List[List[float]]]): A vector embeddings for each text (in case of a single vector) or list of vectors for each text (in case of multi-vector) metadatas (Optional[List[dict]]): Metadata dicts attached to each of @@ -917,15 +1199,24 @@ def add_embeddings( List[str]: The resulting keys for each inserted element. """ - if not self._is_multi_vector: - embeddings = [[embedding] for embedding in embeddings] # type: ignore - # Transpose embeddings to make it a list of embeddings of each type. - embeddings = [ # type: ignore - [embeddings[j][i] for j in range(len(embeddings))] - for i in range(len(embeddings[0])) - ] - - vector_fields: List[str] = self._as_list(self._vector_field) + if embeddings: + # row-wise -> field-wise + if not self._is_multi_embedding: + embeddings = [[embedding] for embedding in embeddings] # type: ignore + # transposed_embeddings = [ + # [f1(a), f2(a)], + # [f1(b), f2(b)], + # [f1(c), f2(c)] + # ] + # Transpose embeddings to make it a list of embeddings of each type. + embeddings = [ # type: ignore + [embeddings[j][i] for j in range(len(embeddings))] + for i in range(len(embeddings[0])) + ] + # embeddings = [ + # [f1(a), f1(b), f1(c)], + # [f2(a), f2(b), f2(c)] + # ] # If the collection hasn't been initialized yet, perform all steps to do so if not isinstance(self.col, Collection): @@ -958,7 +1249,9 @@ def add_embeddings( entity_dict[self._text_field] = text - for vector_field, vector_field_embeddings in zip(vector_fields, embeddings): # type: ignore + for vector_field, vector_field_embeddings in zip( # type: ignore + self._vector_fields_from_embedding, embeddings + ): entity_dict[vector_field] = vector_field_embeddings[i] if self._metadata_field and not self.enable_dynamic_field: @@ -998,22 +1291,23 @@ def add_embeddings( def _collection_search( self, - embedding: List[float] | Dict[int, float], + embedding_or_text: List[float] | Dict[int, float] | str, k: int = 4, param: Optional[dict] = None, expr: Optional[str] = None, timeout: Optional[float] = None, **kwargs: Any, - ) -> "pymilvus.client.abstract.SearchResult | None": # type: ignore[name-defined] # noqa: F821 - """Perform a search on an embedding and return milvus search results. + ) -> Optional[SearchResult]: + """Perform a search on an embedding or a query text and return milvus search + results. For more information about the search parameters, take a look at the pymilvus documentation found here: - https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md + https://milvus.io/api-reference/pymilvus/v2.5.x/ORM/Collection/search.md Args: - embedding (List[float] | Dict[int, float]): The embedding vector being - searched. + embedding_or_text (List[float] | Dict[int, float] | str): The embedding + vector or query text being searched. k (int, optional): The amount of results to return. Defaults to 4. param (dict): The search params for the specified index. Defaults to None. @@ -1029,38 +1323,129 @@ def _collection_search( logger.debug("No existing collection to search.") return None - assert not isinstance(self.search_params, list) and not isinstance( - self._vector_field, list - ), "_collection_search does not support multi-vector search." + assert not self._is_multi_vector, ( + "_collection_search does not support multi-vector search. " + "You can use _collection_hybrid_search instead." + ) if param is None: - param = self.search_params + assert len(self._as_list(self.search_params)) == 1, ( + "The number of search params is larger than 1, " + "please check the search_params in this Milvus instance." + ) + param = self._as_list(self.search_params)[0] - # Determine result metadata fields with PK. if self.enable_dynamic_field: output_fields = ["*"] else: - output_fields = self.fields[:] - output_fields.remove(self._vector_field) - timeout = self.timeout or timeout - # Perform the search. - res = self.col.search( - data=[embedding], + output_fields = self._remove_forbidden_fields(self.fields[:]) + col_search_res = self.col.search( + data=[embedding_or_text], anns_field=self._vector_field, param=param, limit=k, expr=expr, output_fields=output_fields, - timeout=timeout, + timeout=self.timeout or timeout, **kwargs, ) - return res + return col_search_res + + def _collection_hybrid_search( + self, + query: str, + k: int = 4, + param: Optional[dict | list[dict]] = None, + expr: Optional[str] = None, + fetch_k: Optional[int] = 4, + ranker_type: Optional[Literal["rrf", "weighted"]] = None, + ranker_params: Optional[dict] = None, + timeout: Optional[float] = None, + **kwargs: Any, + ) -> Optional[SearchResult]: + """ + Perform a hybrid search on a query string and return milvus search results. + + For more information about the search parameters, take a look at the pymilvus + documentation found here: + https://milvus.io/api-reference/pymilvus/v2.5.x/ORM/Collection/hybrid_search.md + + Args: + query (str): The text being searched. + k (int, optional): The amount of results to return. Defaults to 4. + param (dict | list[dict], optional): The search params for the specified + index. Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + fetch_k (int, optional): The amount of pre-fetching results for each query. + Defaults to 4. + ranker_type (str, optional): The type of ranker to use. Defaults to None. + ranker_params (dict, optional): The parameters for the ranker. + Defaults to None. + timeout (float, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.hybrid_search() keyword arguments. + + Returns: + pymilvus.client.abstract.SearchResult: Milvus search result. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return None + + search_requests = [] + reranker = self._create_ranker( + ranker_type=ranker_type, + ranker_params=ranker_params or {}, + ) + if not param: + param_list = self._as_list(self.search_params) + else: + assert len(self._as_list(param)) == len( + self._as_list(self.search_params) + ), ( + f"The number of search params ({len(self._as_list(param))})" + f" does not match the number of vector fields " + f"({len(self._as_list(self._vector_field))})." + f" All vector fields are: {(self._as_list(self._vector_field))}," + " please provide a list of search params for each vector field." + ) + param_list = self._as_list(param) + for field, param_dict in zip(self._vector_field, param_list): + search_data: List[float] | Dict[int, float] | str + if field in self._vector_fields_from_embedding: + embedding_func: EmbeddingType = self._as_list(self.embedding_func)[ # type: ignore + self._vector_fields_from_embedding.index(field) + ] + search_data = embedding_func.embed_query(query) + else: + search_data = query + request = AnnSearchRequest( + data=[search_data], + anns_field=field, + param=param_dict, + limit=fetch_k, + expr=expr, + ) + search_requests.append(request) + if self.enable_dynamic_field: + output_fields = ["*"] + else: + output_fields = self._remove_forbidden_fields(self.fields[:]) + col_search_res = self.col.hybrid_search( + reqs=search_requests, + rerank=reranker, + limit=k, + output_fields=output_fields, + timeout=self.timeout or timeout, + **kwargs, + ) + return col_search_res def similarity_search( self, query: str, k: int = 4, - param: Optional[dict] = None, + param: Optional[dict | list[dict]] = None, expr: Optional[str] = None, timeout: Optional[float] = None, **kwargs: Any, @@ -1070,7 +1455,7 @@ def similarity_search( Args: query (str): The text to search. k (int, optional): How many results to return. Defaults to 4. - param (dict, optional): The search params for the index type. + param (dict | list[dict], optional): The search params for the index type. Defaults to None. expr (str, optional): Filtering expression. Defaults to None. timeout (int, optional): How long to wait before timeout error. @@ -1126,7 +1511,7 @@ def similarity_search_with_score( self, query: str, k: int = 4, - param: Optional[dict] = None, + param: Optional[dict | list[dict]] = None, expr: Optional[str] = None, timeout: Optional[float] = None, **kwargs: Any, @@ -1135,59 +1520,68 @@ def similarity_search_with_score( For more information about the search parameters, take a look at the pymilvus documentation found here: - https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md + https://milvus.io/api-reference/pymilvus/v2.5.x/ORM/Collection/search.md Args: query (str): The text being searched. k (int, optional): The amount of results to return. Defaults to 4. - param (dict): The search params for the specified index. - Defaults to None. + param (dict | list[dict], optional): The search params for the specified + index. Defaults to None. expr (str, optional): Filtering expression. Defaults to None. timeout (float, optional): How long to wait before timeout error. Defaults to None. - kwargs: Collection.search() keyword arguments. + kwargs: Collection.search() or hybrid_search() keyword arguments. Returns: - List[float], List[Tuple[Document, any, any]]: + List[Tuple[Document, float]]: List of result doc and score. """ if self.col is None: logger.debug("No existing collection to search.") return [] - if isinstance(self.embedding_func, list): # is multi-vector - ranker = self._create_ranker( - kwargs.pop("ranker_type", None), kwargs.pop("ranker_params", {}) - ) - hybrid_retriever = MilvusCollectionHybridSearchRetriever( - collection=self.col, - rerank=ranker, - anns_fields=self._vector_field, - field_embeddings=self.embeddings, - field_search_params=param or self.search_params, - field_exprs=expr, - top_k=k, - text_field=self._text_field, - timeout=timeout, - **kwargs, + + if self._is_multi_vector: + col_search_res = self._collection_hybrid_search( + query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs ) - res = [] - col_search_res = hybrid_retriever.hybrid_search(query) - for result in col_search_res[0]: - data = {x: result.entity.get(x) for x in result.entity.fields} - doc = self._parse_document(data) - res.append((doc, result.score)) + else: - # Embed the query text. - embedding = self.embedding_func.embed_query(query) - timeout = self.timeout or timeout - res = self.similarity_search_with_score_by_vector( - embedding=embedding, - k=k, - param=param, - expr=expr, - timeout=timeout, - **kwargs, + assert len(self._as_list(param)) <= 1, ( + "When there is only one vector field, you can not provide multiple " + "search param dicts." ) - return res + param = cast(Optional[dict], self._from_list(param)) + if ( + len(self._as_list(self.embedding_func)) == 1 # type: ignore[arg-type] + and len(self._as_list(self.builtin_func)) == 0 + ): + embedding = self._as_list(self.embedding_func)[0].embed_query(query) # type: ignore + col_search_res = self._collection_search( + embedding_or_text=embedding, + k=k, + param=param, + expr=expr, + timeout=timeout, + **kwargs, + ) + elif ( + len(self._as_list(self.embedding_func)) == 0 # type: ignore[arg-type] + and len(self._as_list(self.builtin_func)) == 1 + ): + col_search_res = self._collection_search( + embedding_or_text=query, + k=k, + param=param, + expr=expr, + timeout=timeout, + **kwargs, + ) + else: + raise RuntimeError( + "Check either it's multi vectors or single vector with " + "only one embedding/builtin function." + ) + + return self._parse_documents_from_search_results(col_search_res) def similarity_search_with_score_by_vector( self, @@ -1202,7 +1596,7 @@ def similarity_search_with_score_by_vector( For more information about the search parameters, take a look at the pymilvus documentation found here: - https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md + https://milvus.io/api-reference/pymilvus/v2.5.x/ORM/Collection/search.md Args: embedding (List[float] | Dict[int, float]): The embedding vector being @@ -1219,18 +1613,14 @@ def similarity_search_with_score_by_vector( List[Tuple[Document, float]]: Result doc and score. """ col_search_res = self._collection_search( - embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs + embedding_or_text=embedding, + k=k, + param=param, + expr=expr, + timeout=timeout, + **kwargs, ) - if col_search_res is None: - return [] - ret = [] - for result in col_search_res[0]: - data = {x: result.entity.get(x) for x in result.entity.fields} - doc = self._parse_document(data) - pair = (doc, result.score) - ret.append(pair) - - return ret + return self._parse_documents_from_search_results(col_search_res) def max_marginal_relevance_search( self, @@ -1269,10 +1659,16 @@ def max_marginal_relevance_search( logger.debug("No existing collection to search.") return [] - assert not isinstance( - self.embedding_func, list - ), "MMR is not-suported in multi-vector settings" - embedding = self.embedding_func.embed_query(query) + assert ( + len(self._as_list(self.embedding_func)) == 1 # type: ignore[arg-type] + ), "You must set only one embedding function for MMR search." + if len(self._vector_fields_from_function) > 0: + logger.warning( + "MMR search will only use the embedding function, " + "without the built-in functions." + ) + + embedding = self._as_list(self.embedding_func)[0].embed_query(query) # type: ignore timeout = self.timeout or timeout return self.max_marginal_relevance_search_by_vector( embedding=embedding, @@ -1319,7 +1715,7 @@ def max_marginal_relevance_search_by_vector( List[Document]: Document results for search. """ col_search_res = self._collection_search( - embedding=embedding, + embedding_or_text=embedding, k=fetch_k, param=param, expr=expr, @@ -1377,7 +1773,7 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]: raise ValueError( "No index params provided. Could not determine relevance function." ) - if self._is_multi_vector: + if self._is_multi_embedding or self._is_multi_function: raise ValueError( "No supported normalization function for multi vectors. " "Could not determine relevance function." @@ -1408,13 +1804,19 @@ def _map_ip_to_similarity(ip_score: float) -> float: """ return (ip_score + 1) / 2.0 - if self.index_params is None: + if not self.index_params: logger.warning( "No index params provided. Could not determine relevance function. " "Use L2 distance as default." ) return _map_l2_to_similarity indexes_params = self._as_list(self.index_params) + if len(indexes_params) > 1: + raise ValueError( + "No supported normalization function for multi vectors. " + "Could not determine relevance function." + ) + # In the left case, the len of indexes_params is 1. metric_type = indexes_params[0]["metric_type"] if metric_type == "L2": return _map_l2_to_similarity @@ -1454,7 +1856,7 @@ def delete( # type: ignore[no-untyped-def] def from_texts( cls, texts: List[str], - embedding: Union[EmbeddingType, List[EmbeddingType]], # type: ignore + embedding: Optional[Union[EmbeddingType, List[EmbeddingType]]], metadatas: Optional[List[dict]] = None, collection_name: str = "LangChainCollection", connection_args: Optional[Dict[str, Any]] = None, @@ -1465,13 +1867,17 @@ def from_texts( *, ids: Optional[List[str]] = None, auto_id: bool = False, + builtin_function: Optional[ + Union[BaseMilvusBuiltInFunction, List[BaseMilvusBuiltInFunction]] + ] = None, **kwargs: Any, ) -> Milvus: """Create a Milvus collection, indexes it with HNSW, and insert data. Args: texts (List[str]): Text data. - embedding (Union[Embeddings, BaseSparseEmbedding]): Embedding function. + embedding (Optional[Union[Embeddings, BaseSparseEmbedding]]): Embedding + function. metadatas (Optional[List[dict]]): Metadata for each text if it exists. Defaults to None. collection_name (str, optional): Collection name to use. Defaults to @@ -1490,7 +1896,10 @@ def from_texts( auto_id (bool): Whether to enable auto id for primary key. Defaults to False. If False, you need to provide text ids (string less than 65535 bytes). If True, Milvus will generate unique integers as primary keys. - + builtin_function (Optional[Union[BaseMilvusBuiltInFunction, + List[BaseMilvusBuiltInFunction]]]): + Built-in function to use. Defaults to None. + **kwargs: Other parameters in Milvus Collection. Returns: Milvus: Milvus Vector Store """ @@ -1512,6 +1921,7 @@ def from_texts( search_params=search_params, drop_old=drop_old, auto_id=auto_id, + builtin_function=builtin_function, **kwargs, ) vector_db.add_texts(texts=texts, metadatas=metadatas, ids=ids) @@ -1527,6 +1937,20 @@ def _parse_document(self, data: dict) -> Document: metadata=data.pop(self._metadata_field) if self._metadata_field else data, ) + def _parse_documents_from_search_results( + self, + col_search_res: SearchResult, + ) -> List[Tuple[Document, float]]: + if not col_search_res: + return [] + ret = [] + for result in col_search_res[0]: + data = {x: result.entity.get(x) for x in result.entity.fields} + doc = self._parse_document(data) + pair = (doc, result.score) + ret.append(pair) + return ret + def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: """Run more documents through the embeddings and add to the vectorstore. @@ -1599,17 +2023,26 @@ def upsert( # type: ignore raise exc @staticmethod - def _as_list(value: Union[T, List[T]]) -> List[T]: + def _as_list(value: Optional[Union[T, List[T]]]) -> List[T]: + """Try to cast a value to a list""" + if not value: + return [] return [value] if not isinstance(value, list) else value + @staticmethod + def _from_list(value: Optional[Union[T, List[T]]]) -> Optional[Union[T, List[T]]]: + """Try to cast a list to a single value""" + if isinstance(value, list) and len(value) == 1: + return value[0] + return value + def _create_ranker( self, ranker_type: Optional[Literal["rrf", "weighted"]], ranker_params: dict, ) -> Union[WeightedRanker, RRFRanker]: """A Ranker factory method""" - embeddings_functions: List[EmbeddingType] = self._as_list(self.embedding_func) - default_weights = [1.0] * len(embeddings_functions) + default_weights = [1.0] * len(self._as_list(self._vector_field)) if not ranker_type: return WeightedRanker(*default_weights) @@ -1630,3 +2063,13 @@ def _create_ranker( "rrf", ) raise ValueError("Unrecognized ranker of type %s", ranker_type) + + def _remove_forbidden_fields(self, fields: List[str]) -> List[str]: + """Bm25 function fields are not allowed as output fields in Milvus.""" + forbidden_fields = [] + for builtin_function in self._as_list(self.builtin_func): + if builtin_function.type == FunctionType.BM25: + forbidden_fields.extend( + self._as_list(builtin_function.output_field_names) + ) + return [field for field in fields if field not in forbidden_fields] diff --git a/libs/milvus/langchain_milvus/vectorstores/zilliz.py b/libs/milvus/langchain_milvus/vectorstores/zilliz.py index f68d5d3..865f898 100644 --- a/libs/milvus/langchain_milvus/vectorstores/zilliz.py +++ b/libs/milvus/langchain_milvus/vectorstores/zilliz.py @@ -1,11 +1,8 @@ from __future__ import annotations import logging -from typing import List, Optional, Union, cast -from pymilvus import Collection, MilvusException - -from langchain_milvus.vectorstores.milvus import EmbeddingType, Milvus +from langchain_milvus.vectorstores.milvus import Milvus logger = logging.getLogger(__name__) @@ -73,74 +70,5 @@ class Zilliz(Milvus): ValueError: If the pymilvus python package is not installed. """ - def _create_index(self) -> None: - """Create an index on the collection""" - - self.index_params = cast(Optional[Union[dict, List[dict]]], self.index_params) # type: ignore - - if isinstance(self.col, Collection) and self._get_index() is None: - embeddings_functions: List[EmbeddingType] = self._as_list( - self.embedding_func - ) - vector_fields: List[str] = self._as_list(self._vector_field) - if self.index_params is None: - indexes_params: List[dict] = [ - {} for _ in range(len(embeddings_functions)) - ] - else: - indexes_params = self._as_list(self.index_params) - - for i, embeddings_func in enumerate(embeddings_functions): - if not self._get_index(vector_fields[i]): - try: - # If no index params, use a default *AutoIndex* based one - if not indexes_params[i]: - if self._is_sparse_embedding(embeddings_func): - indexes_params[i] = { - "metric_type": "IP", - "index_type": "SPARSE_INVERTED_INDEX", - "params": {"drop_ratio_build": 0.2}, - } - else: - indexes_params[i] = { - "metric_type": "L2", - "index_type": "AUTOINDEX", - "params": {}, - } - - try: - self.col.create_index( - vector_fields[i], - index_params=indexes_params[i], - using=self.alias, - ) - - # If default did not work, most likely Milvus self-hosted - except MilvusException: - # Use HNSW based index - index_params = { - "metric_type": "L2", - "index_type": "HNSW", - "params": {"M": 8, "efConstruction": 64}, - } - self.col.create_index( - vector_fields[i], - index_params=index_params, - using=self.alias, - ) - logger.debug( - "Successfully created an index" - "on %s field on collection: %s", - vector_fields[i], - self.collection_name, - ) - except MilvusException as e: - logger.error( - "Failed to create an index on collection: %s", - self.collection_name, - ) - raise e - if self._is_multi_vector: - self.index_params = indexes_params - else: - self.index_params = indexes_params[0] + # For backwards compatibility, this class is preserved. + # But it is recommended to use Milvus instead. diff --git a/libs/milvus/poetry.lock b/libs/milvus/poetry.lock index 981cce9..4e7bb8c 100644 --- a/libs/milvus/poetry.lock +++ b/libs/milvus/poetry.lock @@ -188,27 +188,6 @@ humanfriendly = ">=9.1" [package.extras] cron = ["capturer (>=2.4)"] -[[package]] -name = "environs" -version = "9.5.0" -description = "simplified environment variable parsing" -optional = false -python-versions = ">=3.6" -files = [ - {file = "environs-9.5.0-py2.py3-none-any.whl", hash = "sha256:1e549569a3de49c05f856f40bce86979e7d5ffbbc4398e7f338574c220189124"}, - {file = "environs-9.5.0.tar.gz", hash = "sha256:a76307b36fbe856bdca7ee9161e6c466fd7fcffc297109a118c59b54e27e30c9"}, -] - -[package.dependencies] -marshmallow = ">=3.0.0" -python-dotenv = "*" - -[package.extras] -dev = ["dj-database-url", "dj-email-url", "django-cache-url", "flake8 (==4.0.1)", "flake8-bugbear (==21.9.2)", "mypy (==0.910)", "pre-commit (>=2.4,<3.0)", "pytest", "tox"] -django = ["dj-database-url", "dj-email-url", "django-cache-url"] -lint = ["flake8 (==4.0.1)", "flake8-bugbear (==21.9.2)", "mypy (==0.910)", "pre-commit (>=2.4,<3.0)"] -tests = ["dj-database-url", "dj-email-url", "django-cache-url", "pytest"] - [[package]] name = "exceptiongroup" version = "1.2.2" @@ -564,25 +543,6 @@ pydantic = [ ] requests = ">=2,<3" -[[package]] -name = "marshmallow" -version = "3.22.0" -description = "A lightweight library for converting complex datatypes to and from native Python datatypes." -optional = false -python-versions = ">=3.8" -files = [ - {file = "marshmallow-3.22.0-py3-none-any.whl", hash = "sha256:71a2dce49ef901c3f97ed296ae5051135fd3febd2bf43afe0ae9a82143a494d9"}, - {file = "marshmallow-3.22.0.tar.gz", hash = "sha256:4972f529104a220bb8637d595aa4c9762afbe7f7a77d82dc58c1615d70c5823e"}, -] - -[package.dependencies] -packaging = ">=17.0" - -[package.extras] -dev = ["marshmallow[tests]", "pre-commit (>=3.5,<4.0)", "tox"] -docs = ["alabaster (==1.0.0)", "autodocsumm (==0.2.13)", "sphinx (==8.0.2)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"] -tests = ["pytest", "pytz", "simplejson"] - [[package]] name = "milvus-lite" version = "2.4.10" @@ -592,7 +552,6 @@ python-versions = ">=3.7" files = [ {file = "milvus_lite-2.4.10-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:fc4246d3ed7d1910847afce0c9ba18212e93a6e9b8406048436940578dfad5cb"}, {file = "milvus_lite-2.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:74a8e07c5e3b057df17fbb46913388e84df1dc403a200f4e423799a58184c800"}, - {file = "milvus_lite-2.4.10-py3-none-manylinux2014_aarch64.whl", hash = "sha256:240c7386b747bad696ecb5bd1f58d491e86b9d4b92dccee3315ed7256256eddc"}, {file = "milvus_lite-2.4.10-py3-none-manylinux2014_x86_64.whl", hash = "sha256:211d2e334a043f9282bdd9755f76b9b2d93b23bffa7af240919ffce6a8dfe325"}, ] @@ -1115,21 +1074,21 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" [[package]] name = "pymilvus" -version = "2.4.7" +version = "2.5.0" description = "Python Sdk for Milvus" optional = false python-versions = ">=3.8" files = [ - {file = "pymilvus-2.4.7-py3-none-any.whl", hash = "sha256:1e5d377bd40fa7eb459d3958dbd96201758f5cf997d41eb3d2d169d0b7fa462e"}, - {file = "pymilvus-2.4.7.tar.gz", hash = "sha256:9ef460b940782a42e1b7b8ae0da03d8cc02d9d80044d13f4b689a7c935ec7aa7"}, + {file = "pymilvus-2.5.0-py3-none-any.whl", hash = "sha256:a0e8653d8fe78019abfda79b3404ef7423f312501e8cbd7dc728051ce8732652"}, + {file = "pymilvus-2.5.0.tar.gz", hash = "sha256:4da14a3bd957a4921166f9355fd1f1ac5c5e4e80b46f12f64d9c9a6dcb8cb395"}, ] [package.dependencies] -environs = "<=9.5.0" -grpcio = ">=1.49.1" -milvus-lite = {version = ">=2.4.0,<2.5.0", markers = "sys_platform != \"win32\""} +grpcio = ">=1.49.1,<=1.67.1" +milvus-lite = {version = ">=2.4.0", markers = "sys_platform != \"win32\""} pandas = ">=1.2.4" protobuf = ">=3.20.0" +python-dotenv = ">=1.0.1,<2.0.0" setuptools = ">69" ujson = ">=2.0.0" @@ -2241,4 +2200,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "cf3207ab2986e65ab302be4cc6f3bc35f1bca94beb62fa1fd223605ef04cb327" +content-hash = "14a3ea86ebbedf5795c02a65bcd208248cba201ace5761794f3df53330e8b3f2" diff --git a/libs/milvus/pyproject.toml b/libs/milvus/pyproject.toml index cd18c95..d9a2406 100644 --- a/libs/milvus/pyproject.toml +++ b/libs/milvus/pyproject.toml @@ -26,7 +26,7 @@ ignore_missing_imports = "True" [tool.poetry.dependencies] python = ">=3.9,<4.0" -pymilvus = "^2.4.3" +pymilvus = "^2.5.0" langchain-core = ">=0.2.38,<0.4" [tool.coverage.run] diff --git a/libs/milvus/tests/integration_tests/vectorstores/test_milvus.py b/libs/milvus/tests/integration_tests/vectorstores/test_milvus.py index 5580aac..55d6e2c 100644 --- a/libs/milvus/tests/integration_tests/vectorstores/test_milvus.py +++ b/libs/milvus/tests/integration_tests/vectorstores/test_milvus.py @@ -1,11 +1,11 @@ """Test Milvus functionality.""" - import tempfile from typing import Any, List, Optional import pytest from langchain_core.documents import Document +from langchain_milvus.function import BM25BuiltInFunction from langchain_milvus.utils.sparse import BM25SparseEmbedding from langchain_milvus.vectorstores import Milvus from tests.integration_tests.utils import ( @@ -35,11 +35,15 @@ def temp_milvus_db() -> Any: yield temp_file.name +TEST_URI = "./milvus_demo.db" +# TEST_TOKEN = "" + + def _milvus_from_texts( metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, drop: bool = True, - db_path: str = "./milvus_demo.db", + db_path: str = TEST_URI, **kwargs: Any, ) -> Milvus: return Milvus.from_texts( @@ -61,7 +65,7 @@ def _get_pks(expr: str, docsearch: Milvus) -> List[Any]: def test_milvus(temp_milvus_db: Any) -> None: """Test end to end construction and search.""" - docsearch = _milvus_from_texts(db_path=temp_milvus_db) + docsearch = _milvus_from_texts(db_path=TEST_URI) output = docsearch.similarity_search("foo", k=1) assert_docs_equal_without_pk(output, [Document(page_content="foo")]) @@ -86,7 +90,7 @@ def test_milvus_add_embeddings_search(temp_milvus_db: Any) -> None: def test_milvus_vector_search(temp_milvus_db: Any) -> None: """Test end to end construction and search by vector.""" - docsearch = _milvus_from_texts(db_path=temp_milvus_db) + docsearch = _milvus_from_texts(db_path=TEST_URI) output = docsearch.similarity_search_by_vector( FakeEmbeddings().embed_query("foo"), k=1 ) @@ -96,7 +100,7 @@ def test_milvus_vector_search(temp_milvus_db: Any) -> None: def test_milvus_with_metadata(temp_milvus_db: Any) -> None: """Test with metadata""" docsearch = _milvus_from_texts( - metadatas=[{"label": "test"}] * len(fake_texts), db_path=temp_milvus_db + metadatas=[{"label": "test"}] * len(fake_texts), db_path=TEST_URI ) output = docsearch.similarity_search("foo", k=1) assert_docs_equal_without_pk( @@ -107,7 +111,7 @@ def test_milvus_with_metadata(temp_milvus_db: Any) -> None: def test_milvus_with_id(temp_milvus_db: Any) -> None: """Test with ids""" ids = ["id_" + str(i) for i in range(len(fake_texts))] - docsearch = _milvus_from_texts(ids=ids, db_path=temp_milvus_db) + docsearch = _milvus_from_texts(ids=ids, db_path=TEST_URI) output = docsearch.similarity_search("foo", k=1) assert_docs_equal_without_pk(output, [Document(page_content="foo")]) @@ -116,7 +120,7 @@ def test_milvus_with_id(temp_milvus_db: Any) -> None: try: ids = ["dup_id" for _ in fake_texts] - _milvus_from_texts(ids=ids, db_path=temp_milvus_db) + _milvus_from_texts(ids=ids, db_path=TEST_URI) except Exception as e: assert isinstance(e, AssertionError) @@ -125,7 +129,7 @@ def test_milvus_with_score(temp_milvus_db: Any) -> None: """Test end to end construction and search with scores and IDs.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db) + docsearch = _milvus_from_texts(metadatas=metadatas, db_path=TEST_URI) output = docsearch.similarity_search_with_score("foo", k=3) docs = [o[0] for o in output] scores = [o[1] for o in output] @@ -144,7 +148,7 @@ def test_milvus_max_marginal_relevance_search(temp_milvus_db: Any) -> None: """Test end to end construction and MRR search.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db) + docsearch = _milvus_from_texts(metadatas=metadatas, db_path=TEST_URI) output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3) assert_docs_equal_without_pk( output, @@ -162,7 +166,7 @@ def test_milvus_max_marginal_relevance_search_with_dynamic_field( texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] docsearch = _milvus_from_texts( - metadatas=metadatas, enable_dynamic_field=True, db_path=temp_milvus_db + metadatas=metadatas, enable_dynamic_field=True, db_path=TEST_URI ) output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3) assert_docs_equal_without_pk( @@ -178,7 +182,7 @@ def test_milvus_add_extra(temp_milvus_db: Any) -> None: """Test end to end construction and MRR search.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db) + docsearch = _milvus_from_texts(metadatas=metadatas, db_path=TEST_URI) docsearch.add_texts(texts, metadatas) @@ -190,12 +194,10 @@ def test_milvus_no_drop(temp_milvus_db: Any) -> None: """Test construction without dropping old data.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db) + docsearch = _milvus_from_texts(metadatas=metadatas, db_path=TEST_URI) del docsearch - docsearch = _milvus_from_texts( - metadatas=metadatas, drop=False, db_path=temp_milvus_db - ) + docsearch = _milvus_from_texts(metadatas=metadatas, drop=False, db_path=TEST_URI) output = docsearch.similarity_search("foo", k=10) assert len(output) == 6 @@ -205,7 +207,7 @@ def test_milvus_get_pks(temp_milvus_db: Any) -> None: """Test end to end construction and get pks with expr""" texts = ["foo", "bar", "baz"] metadatas = [{"id": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db) + docsearch = _milvus_from_texts(metadatas=metadatas, db_path=TEST_URI) expr = "id in [1,2]" output = _get_pks(expr, docsearch) assert len(output) == 2 @@ -215,7 +217,7 @@ def test_milvus_delete_entities(temp_milvus_db: Any) -> None: """Test end to end construction and delete entities""" texts = ["foo", "bar", "baz"] metadatas = [{"id": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db) + docsearch = _milvus_from_texts(metadatas=metadatas, db_path=TEST_URI) expr = "id in [1,2]" pks = _get_pks(expr, docsearch) result = docsearch.delete(pks) @@ -226,7 +228,7 @@ def test_milvus_upsert_entities(temp_milvus_db: Any) -> None: """Test end to end construction and upsert entities""" texts = ["foo", "bar", "baz"] metadatas = [{"id": i} for i in range(len(texts))] - docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db) + docsearch = _milvus_from_texts(metadatas=metadatas, db_path=TEST_URI) expr = "id in [1,2]" pks = _get_pks(expr, docsearch) documents = [ @@ -242,7 +244,7 @@ def test_milvus_enable_dynamic_field(temp_milvus_db: Any) -> None: texts = ["foo", "bar", "baz"] metadatas = [{"id": i} for i in range(len(texts))] docsearch = _milvus_from_texts( - metadatas=metadatas, enable_dynamic_field=True, db_path=temp_milvus_db + metadatas=metadatas, enable_dynamic_field=True, db_path=TEST_URI ) output = docsearch.similarity_search("foo", k=10) assert len(output) == 3 @@ -266,7 +268,7 @@ def test_milvus_disable_dynamic_field(temp_milvus_db: Any) -> None: texts = ["foo", "bar", "baz"] metadatas = [{"id": i} for i in range(len(texts))] docsearch = _milvus_from_texts( - metadatas=metadatas, enable_dynamic_field=False, db_path=temp_milvus_db + metadatas=metadatas, enable_dynamic_field=False, db_path=TEST_URI ) output = docsearch.similarity_search("foo", k=10) assert len(output) == 3 @@ -300,7 +302,7 @@ def test_milvus_metadata_field(temp_milvus_db: Any) -> None: texts = ["foo", "bar", "baz"] metadatas = [{"id": i} for i in range(len(texts))] docsearch = _milvus_from_texts( - metadatas=metadatas, metadata_field="metadata", db_path=temp_milvus_db + metadatas=metadatas, metadata_field="metadata", db_path=TEST_URI ) output = docsearch.similarity_search("foo", k=10) assert len(output) == 3 @@ -331,7 +333,7 @@ def test_milvus_enable_dynamic_field_with_partition_key(temp_milvus_db: Any) -> metadatas=metadatas, enable_dynamic_field=True, partition_key_field="namespace", - db_path=temp_milvus_db, + db_path=TEST_URI, ) # filter on a single namespace @@ -408,7 +410,7 @@ def test_milvus_array_field(temp_milvus_db: Any) -> None: # "dtype": DataType.INT64, # } }, - db_path=temp_milvus_db, + db_path=TEST_URI, ) output = docsearch.similarity_search("foo", k=10, expr="array_field[0] < 2") assert len(output) == 2 @@ -422,7 +424,7 @@ def test_milvus_array_field(temp_milvus_db: Any) -> None: docsearch = _milvus_from_texts( enable_dynamic_field=True, metadatas=metadatas, - db_path=temp_milvus_db, + db_path=TEST_URI, ) output = docsearch.similarity_search("foo", k=10, expr="array_field[0] < 2") assert len(output) == 2 @@ -601,6 +603,86 @@ def test_milvus_similarity_search_with_relevance_scores( assert retrival_output[0].page_content == "down" +@pytest.mark.parametrize("enable_dynamic_field", [True, False]) +def test_milvus_builtin_bm25_function(enable_dynamic_field: bool) -> None: + """ + Test builtin BM25 function + + NOTE: The full text search feature is so far not supported in Milvus-Lite and Zilliz + To run this unittest successfully, we can only use Milvus Docker Standalone service. + """ + + def _add_and_assert(docsearch: Milvus) -> None: + if enable_dynamic_field: + metadatas = [{"page": i} for i in range(len(fake_texts))] + else: + metadatas = None + docsearch.add_texts(fake_texts, metadatas=metadatas) + output = docsearch.similarity_search("foo", k=1) + if enable_dynamic_field: + assert_docs_equal_without_pk( + output, [Document(page_content=fake_texts[0], metadata={"page": 0})] + ) + else: + assert_docs_equal_without_pk(output, [Document(page_content=fake_texts[0])]) + + # BM25 only + docsearch1 = Milvus( + embedding_function=[], + builtin_function=[BM25BuiltInFunction()], + connection_args={"uri": TEST_URI}, + auto_id=True, + drop_old=True, + consistency_level="Strong", + vector_field="sparse", + enable_dynamic_field=enable_dynamic_field, + ) + _add_and_assert(docsearch1) + + # Dense embedding + BM25 + docsearch2 = Milvus( + embedding_function=FakeEmbeddings(), + builtin_function=[BM25BuiltInFunction()], + connection_args={"uri": TEST_URI}, + auto_id=True, + drop_old=True, + consistency_level="Strong", + vector_field="sparse", + enable_dynamic_field=enable_dynamic_field, + ) + _add_and_assert(docsearch2) + + # Dense embedding + BM25 + custom index params + index_param_1 = { + "metric_type": "COSINE", + "index_type": "HNSW", + } + index_param_2 = { + "metric_type": "BM25", + "index_type": "AUTOINDEX", + } + docsearch3 = Milvus( + embedding_function=[ + FakeEmbeddings(), + ], + builtin_function=[ + BM25BuiltInFunction( + input_field_names="text00", + output_field_names="sparse00", + ) + ], + index_params=[index_param_1, index_param_2], + connection_args={"uri": TEST_URI}, + auto_id=True, + drop_old=True, + consistency_level="Strong", + text_field="text00", + vector_field=["dense00", "sparse00"], + enable_dynamic_field=enable_dynamic_field, + ) + _add_and_assert(docsearch3) + + # if __name__ == "__main__": # test_milvus() # test_milvus_vector_search() @@ -625,3 +707,4 @@ def test_milvus_similarity_search_with_relevance_scores( # test_milvus_multi_vector_with_index_params() # test_milvus_multi_vector_search_with_ranker() # test_milvus_similarity_search_with_relevance_scores() +# test_milvus_builtin_bm25_function() diff --git a/libs/milvus/tests/unit_tests/test_imports.py b/libs/milvus/tests/unit_tests/test_imports.py index 8be170e..601efa4 100644 --- a/libs/milvus/tests/unit_tests/test_imports.py +++ b/libs/milvus/tests/unit_tests/test_imports.py @@ -5,6 +5,8 @@ "MilvusCollectionHybridSearchRetriever", "Zilliz", "ZillizCloudPipelineRetriever", + "BaseMilvusBuiltInFunction", + "BM25BuiltInFunction", ]