generated from langchain-ai/integration-repo-template
-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: ChengZi <[email protected]>
- Loading branch information
1 parent
b925dac
commit 8845a95
Showing
7 changed files
with
790 additions
and
349 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import uuid | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from pymilvus import Function, FunctionType | ||
|
||
from langchain_milvus.utils.constant import SPARSE_VECTOR_FIELD, TEXT_FIELD | ||
|
||
|
||
class BaseMilvusBuiltInFunction: | ||
""" | ||
Base class for Milvus built-in functions. | ||
See: | ||
https://milvus.io/docs/manage-collections.md#Function | ||
""" | ||
|
||
def __init__(self) -> None: | ||
self._function: Optional[Function] = None | ||
|
||
@property | ||
def function(self) -> Function: | ||
return self._function | ||
|
||
@property | ||
def input_field_names(self) -> Union[str, List[str]]: | ||
return self.function.input_field_names | ||
|
||
@property | ||
def output_field_names(self) -> Union[str, List[str]]: | ||
return self.function.output_field_names | ||
|
||
@property | ||
def type(self) -> FunctionType: | ||
return self.function.type | ||
|
||
|
||
class Bm25BuiltInFunction(BaseMilvusBuiltInFunction): | ||
""" | ||
Milvus BM25 built-in function. | ||
See: | ||
https://milvus.io/docs/full-text-search.md | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
input_field_names: str = TEXT_FIELD, | ||
output_field_names: str = SPARSE_VECTOR_FIELD, | ||
analyzer_params: Optional[Dict[Any, Any]] = None, | ||
enable_match: bool = False, | ||
function_name: Optional[str] = None, | ||
): | ||
super().__init__() | ||
if not function_name: | ||
function_name = f"bm25_function_{str(uuid.uuid4())[:8]}" | ||
self._function = Function( | ||
name=function_name, | ||
input_field_names=input_field_names, | ||
output_field_names=output_field_names, | ||
function_type=FunctionType.BM25, | ||
) | ||
self.analyzer_params: Optional[Dict[Any, Any]] = analyzer_params | ||
self.enable_match = enable_match | ||
|
||
def get_input_field_schema_kwargs(self) -> dict: | ||
field_schema_kwargs: Dict[Any, Any] = { | ||
"enable_analyzer": True, | ||
"enable_match": self.enable_match, | ||
} | ||
if self.analyzer_params is not None: | ||
field_schema_kwargs["analyzer_params"] = self.analyzer_params | ||
return field_schema_kwargs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
VECTOR_FIELD = "vector" | ||
SPARSE_VECTOR_FIELD = "sparse" | ||
TEXT_FIELD = "text" | ||
PRIMARY_FIELD = "pk" |
Oops, something went wrong.