5-weaviate.py
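# Benchmark script: reads verse embeddings, loads them into a local Weaviate
# collection with batch inserts, then measures near-vector search latency for a
# single query embedding produced with a SentenceTransformer model.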
import hashlib
import struct
import time

import weaviate
from weaviate.classes.init import AdditionalConfig, Timeout
from weaviate.classes.config import Configure, Property, DataType, VectorDistances
from weaviate.classes.query import MetadataQuery
from sentence_transformers import SentenceTransformer

from common import read_verses

md5_hash = hashlib.md5()  # Only used by the commented-out ID-hashing code below
# Connect to a local Weaviate instance (HTTP on port 8080, gRPC on 50051)
client = weaviate.connect_to_local(
    port=8080,
    grpc_port=50051,
    additional_config=AdditionalConfig(
        timeout=Timeout(init=30, query=60, insert=120)  # Values in seconds
    ),
)

collection_name = "collection_768"
# Create the collection on the first run; otherwise reuse the existing one
if client.collections.exists(collection_name):
    collection = client.collections.get(collection_name)
else:
    collection = client.collections.create(
        name=collection_name,
        vector_index_config=Configure.VectorIndex.hnsw(
            distance_metric=VectorDistances.COSINE
        ),
        properties=[
            Property(
                name="text",
                data_type=DataType.TEXT,
            ),
            Property(
                name="meta",
                data_type=DataType.OBJECT,
                index_filterable=True,
                index_searchable=True,
            ),
        ],
    )
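# The collection uses an HNSW index with cosine distance. Vectors are supplied
# by the client at insert time (no server-side vectorizer is configured), and
# verse metadata is stored as a filterable/searchable OBJECT property.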
def weaviate_inserts(chunk):
    # Decode one minibatch of verses and batch-insert it into the collection
    # collection = client.collections.get(collection_name)
    data = []
    for id, text, meta, embedding in chunk:
        # md5_hash.update(id.encode('utf-8'))
        # id = md5_hash.hexdigest()
        # Ensure embedding is a list of floats (unpack raw float32 bytes if needed)
        if isinstance(embedding, bytes):
            embedding = list(struct.unpack(f'{len(embedding) // 4}f', embedding))
        data.append(
            {"embedding": embedding, "text": text, "meta": meta}
        )
    start_time = time.perf_counter()
    with collection.batch.dynamic() as batch:
        for data_row in data:
            batch.add_object(
                properties={"text": data_row["text"], "meta": data_row["meta"]},
                vector=data_row["embedding"],
            )
    end_time = time.perf_counter()
    elapsed_time = end_time - start_time
    print(f"batch insert: {elapsed_time} sec")
    print(collection.batch.failed_objects)
    return elapsed_time
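# Note: collection.batch.dynamic() flushes objects automatically and is meant to
# adapt the batch size to server load; any objects that failed to insert are
# available via collection.batch.failed_objects after the context manager exits.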
def weaviate_search(embeddings):
    # Nearest-neighbour search over stored vectors using the query embedding
    response = collection.query.near_vector(
        near_vector=embeddings,
        limit=10,
        return_metadata=MetadataQuery(distance=True)
    )
    for o in response.objects:
        print(f"Text: {o.properties['text']}; Similarity: {1 - o.metadata.distance}")
# Load the verses and insert them in minibatches of 1000
read_verses(weaviate_inserts, max_items=1400000, minibatch_size=1000)

# Sanity check: count the objects that ended up in the collection
aggregation = collection.aggregate.over_all(total_count=True)
print(aggregation.total_count)

# Embed a query ("raised from the dead") and time five identical searches
model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
embeddings = model.encode("воскресил из мертвых")
start_time = time.perf_counter()
weaviate_search(embeddings)
weaviate_search(embeddings)
weaviate_search(embeddings)
weaviate_search(embeddings)
weaviate_search(embeddings)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"Search time: {elapsed_time / 5} sec")

client.close()