Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 8.x] support new dense vector quantization in 8.16 #1951

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ jobs:
"3.12",
"3.13",
]
es-version: [8.0.0, 8.15.0]
es-version: [8.0.0, 8.16.0]

steps:
- name: Remove irrelevant software to free up disk space
Expand Down
14 changes: 12 additions & 2 deletions elasticsearch_dsl/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,13 +389,23 @@ def _deserialize(self, data: Any) -> float:
return float(data)


class DenseVector(Float):
class DenseVector(Field):
name = "dense_vector"
_coerce = True

def __init__(self, **kwargs: Any):
kwargs["multi"] = True
self._element_type = kwargs.get("element_type", "float")
if self._element_type in ["float", "byte"]:
kwargs["multi"] = True
super().__init__(**kwargs)

def _deserialize(self, data: Any) -> Any:
if self._element_type == "float":
return float(data)
elif self._element_type == "byte":
return int(data)
return data


class SparseVector(Field):
name = "sparse_vector"
Expand Down
57 changes: 56 additions & 1 deletion tests/test_integration/_async/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from datetime import datetime
from ipaddress import ip_address
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Union
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Tuple, Union

import pytest
from elasticsearch import AsyncElasticsearch, ConflictError, NotFoundError
Expand All @@ -37,6 +37,7 @@
Binary,
Boolean,
Date,
DenseVector,
Double,
InnerDoc,
Ip,
Expand Down Expand Up @@ -795,3 +796,57 @@ async def gen3() -> AsyncIterator[Union[Doc, Dict[str, Any]]]:
"age": 45,
"languages": ["es"],
}


@pytest.mark.asyncio
async def test_legacy_dense_vector(
async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
) -> None:
if es_version >= (8, 16):
pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")

class Doc(AsyncDocument):
float_vector: List[float] = mapped_field(DenseVector(dims=3))

class Index:
name = "vectors"

await Doc._index.delete(ignore_unavailable=True)
await Doc.init()

doc = Doc(float_vector=[1.0, 1.2, 2.3])
await doc.save(refresh=True)

docs = await Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector


@pytest.mark.asyncio
async def test_dense_vector(
async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
) -> None:
if es_version < (8, 16):
pytest.skip("this test requires Elasticsearch 8.16 or newer")

class Doc(AsyncDocument):
float_vector: List[float] = mapped_field(DenseVector())
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
bit_vector: str = mapped_field(DenseVector(element_type="bit"))

class Index:
name = "vectors"

await Doc._index.delete(ignore_unavailable=True)
await Doc.init()

doc = Doc(
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
)
await doc.save(refresh=True)

docs = await Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector
assert docs[0].byte_vector == doc.byte_vector
assert docs[0].bit_vector == doc.bit_vector
55 changes: 54 additions & 1 deletion tests/test_integration/_sync/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from datetime import datetime
from ipaddress import ip_address
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Union
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple, Union

import pytest
from elasticsearch import ConflictError, Elasticsearch, NotFoundError
Expand All @@ -35,6 +35,7 @@
Binary,
Boolean,
Date,
DenseVector,
Document,
Double,
InnerDoc,
Expand Down Expand Up @@ -789,3 +790,55 @@ def gen3() -> Iterator[Union[Doc, Dict[str, Any]]]:
"age": 45,
"languages": ["es"],
}


@pytest.mark.sync
def test_legacy_dense_vector(
client: Elasticsearch, es_version: Tuple[int, ...]
) -> None:
if es_version >= (8, 16):
pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")

class Doc(Document):
float_vector: List[float] = mapped_field(DenseVector(dims=3))

class Index:
name = "vectors"

Doc._index.delete(ignore_unavailable=True)
Doc.init()

doc = Doc(float_vector=[1.0, 1.2, 2.3])
doc.save(refresh=True)

docs = Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector


@pytest.mark.sync
def test_dense_vector(client: Elasticsearch, es_version: Tuple[int, ...]) -> None:
if es_version < (8, 16):
pytest.skip("this test requires Elasticsearch 8.16 or newer")

class Doc(Document):
float_vector: List[float] = mapped_field(DenseVector())
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
bit_vector: str = mapped_field(DenseVector(element_type="bit"))

class Index:
name = "vectors"

Doc._index.delete(ignore_unavailable=True)
Doc.init()

doc = Doc(
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
)
doc.save(refresh=True)

docs = Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector
assert docs[0].byte_vector == doc.byte_vector
assert docs[0].bit_vector == doc.bit_vector
Loading