Skip to content

Commit

Permalink
support new dense vector quantization in 8.16 (#1948) (#1951)
Browse files Browse the repository at this point in the history
* support new dense vector quantization in 8.16

* use 8.16 in CI builds

(cherry picked from commit 5de355e)

Co-authored-by: Miguel Grinberg <[email protected]>
  • Loading branch information
github-actions[bot] and miguelgrinberg authored Dec 12, 2024
1 parent c0a4871 commit ea260bd
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ jobs:
"3.12",
"3.13",
]
es-version: [8.0.0, 8.15.0]
es-version: [8.0.0, 8.16.0]

steps:
- name: Remove irrelevant software to free up disk space
Expand Down
14 changes: 12 additions & 2 deletions elasticsearch_dsl/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,13 +389,23 @@ def _deserialize(self, data: Any) -> float:
return float(data)


class DenseVector(Float):
class DenseVector(Field):
name = "dense_vector"
_coerce = True

def __init__(self, **kwargs: Any):
kwargs["multi"] = True
self._element_type = kwargs.get("element_type", "float")
if self._element_type in ["float", "byte"]:
kwargs["multi"] = True
super().__init__(**kwargs)

def _deserialize(self, data: Any) -> Any:
if self._element_type == "float":
return float(data)
elif self._element_type == "byte":
return int(data)
return data


class SparseVector(Field):
name = "sparse_vector"
Expand Down
57 changes: 56 additions & 1 deletion tests/test_integration/_async/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from datetime import datetime
from ipaddress import ip_address
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Union
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Tuple, Union

import pytest
from elasticsearch import AsyncElasticsearch, ConflictError, NotFoundError
Expand All @@ -37,6 +37,7 @@
Binary,
Boolean,
Date,
DenseVector,
Double,
InnerDoc,
Ip,
Expand Down Expand Up @@ -795,3 +796,57 @@ async def gen3() -> AsyncIterator[Union[Doc, Dict[str, Any]]]:
"age": 45,
"languages": ["es"],
}


@pytest.mark.asyncio
async def test_legacy_dense_vector(
async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
) -> None:
if es_version >= (8, 16):
pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")

class Doc(AsyncDocument):
float_vector: List[float] = mapped_field(DenseVector(dims=3))

class Index:
name = "vectors"

await Doc._index.delete(ignore_unavailable=True)
await Doc.init()

doc = Doc(float_vector=[1.0, 1.2, 2.3])
await doc.save(refresh=True)

docs = await Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector


@pytest.mark.asyncio
async def test_dense_vector(
async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
) -> None:
if es_version < (8, 16):
pytest.skip("this test requires Elasticsearch 8.16 or newer")

class Doc(AsyncDocument):
float_vector: List[float] = mapped_field(DenseVector())
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
bit_vector: str = mapped_field(DenseVector(element_type="bit"))

class Index:
name = "vectors"

await Doc._index.delete(ignore_unavailable=True)
await Doc.init()

doc = Doc(
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
)
await doc.save(refresh=True)

docs = await Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector
assert docs[0].byte_vector == doc.byte_vector
assert docs[0].bit_vector == doc.bit_vector
55 changes: 54 additions & 1 deletion tests/test_integration/_sync/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from datetime import datetime
from ipaddress import ip_address
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Union
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple, Union

import pytest
from elasticsearch import ConflictError, Elasticsearch, NotFoundError
Expand All @@ -35,6 +35,7 @@
Binary,
Boolean,
Date,
DenseVector,
Document,
Double,
InnerDoc,
Expand Down Expand Up @@ -789,3 +790,55 @@ def gen3() -> Iterator[Union[Doc, Dict[str, Any]]]:
"age": 45,
"languages": ["es"],
}


@pytest.mark.sync
def test_legacy_dense_vector(
client: Elasticsearch, es_version: Tuple[int, ...]
) -> None:
if es_version >= (8, 16):
pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")

class Doc(Document):
float_vector: List[float] = mapped_field(DenseVector(dims=3))

class Index:
name = "vectors"

Doc._index.delete(ignore_unavailable=True)
Doc.init()

doc = Doc(float_vector=[1.0, 1.2, 2.3])
doc.save(refresh=True)

docs = Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector


@pytest.mark.sync
def test_dense_vector(client: Elasticsearch, es_version: Tuple[int, ...]) -> None:
if es_version < (8, 16):
pytest.skip("this test requires Elasticsearch 8.16 or newer")

class Doc(Document):
float_vector: List[float] = mapped_field(DenseVector())
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
bit_vector: str = mapped_field(DenseVector(element_type="bit"))

class Index:
name = "vectors"

Doc._index.delete(ignore_unavailable=True)
Doc.init()

doc = Doc(
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
)
doc.save(refresh=True)

docs = Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector
assert docs[0].byte_vector == doc.byte_vector
assert docs[0].bit_vector == doc.bit_vector

0 comments on commit ea260bd

Please sign in to comment.