-
Notifications
You must be signed in to change notification settings - Fork 120
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: migrate script with different destination collection name
allowing to rename collections on migration while using the same business logic as the oriignal migrate.py has - refer to: https://discord.com/channels/907569970500743200/1237049014852128858/1237057774274220063 - refer to also: https://discord.com/channels/907569970500743200/1178905971712606258/1215304613314560000
- Loading branch information
1 parent
23ef4c6
commit bd08873
Showing
2 changed files
with
207 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
## How to run migrate_different_name.py | ||
|
||
### Example | ||
```python | ||
from qdrant_client import QdrantClient | ||
from migrate_custom import migrate_new | ||
|
||
# Create Qdrant client(s) with (existing collections) | ||
source_client = QdrantClient('http://qdrant-client', prefer_grpc=True, api_key='...') | ||
|
||
# Destination client (where new collections will be created) | ||
source_client = QdrantClient('http://qdrant-client', prefer_grpc=True, api_key='...') | ||
|
||
# set below if you want to migrate in the same qdrant client. | ||
source_client = dest_client | ||
|
||
# Define your collection mapping | ||
collection_mapping = { | ||
"old_name": "new_name", | ||
# Add more mappings as needed | ||
} | ||
|
||
# print time | ||
import time | ||
print("Start time: ", time.ctime()) | ||
|
||
# Call the migrate function | ||
migrate_new( | ||
source_client=source_client, | ||
dest_client=dest_client, | ||
collection_mapping=collection_mapping, | ||
recreate_on_collision=True, | ||
batch_size=100 | ||
) | ||
|
||
print("Migration completed successfully!") | ||
print("End time: ", time.ctime()) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
from typing import Dict, Optional, Any | ||
from qdrant_client._pydantic_compat import to_dict | ||
from qdrant_client.client_base import QdrantBase | ||
from qdrant_client.http import models | ||
from qdrant_client.http.models import PointStruct | ||
from qdrant_client.migrate.migrate import ( | ||
_has_custom_shards, | ||
_find_collisions, | ||
_recreate_payload_schema, | ||
) | ||
|
||
|
||
def migrate_with_different_name( | ||
source_client: QdrantBase, | ||
dest_client: QdrantBase, | ||
collection_mapping: Dict[str, str], | ||
recreate_on_collision: bool = False, | ||
batch_size: int = 100, | ||
) -> None: | ||
""" | ||
Migrate collections from source client to destination client with optional renaming | ||
Args: | ||
source_client (QdrantBase): Source client | ||
dest_client (QdrantBase): Destination client | ||
collection_mapping (Dict[str, str]): Mapping of source collection names to destination names | ||
recreate_on_collision (bool, optional): If True - recreate collection if it exists, otherwise | ||
raise ValueError. | ||
batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100. | ||
""" | ||
source_collection_names = list(collection_mapping.keys()) | ||
dest_collection_names = list(collection_mapping.values()) | ||
|
||
if any(_has_custom_shards(source_client, name) for name in source_collection_names): | ||
raise ValueError("Migration of collections with custom shards is not supported yet") | ||
|
||
collisions = _find_collisions(dest_client, dest_collection_names) | ||
absent_dest_collections = set(dest_collection_names) - set(collisions) | ||
|
||
if collisions and not recreate_on_collision: | ||
raise ValueError(f"Collections already exist in dest_client: {collisions}") | ||
|
||
for src_name, dest_name in collection_mapping.items(): | ||
if dest_name in absent_dest_collections: | ||
_recreate_collection(source_client, dest_client, src_name, dest_name) | ||
elif recreate_on_collision: | ||
dest_client.delete_collection(dest_name) | ||
_recreate_collection(source_client, dest_client, src_name, dest_name) | ||
_migrate_collection(source_client, dest_client, src_name, dest_name, batch_size) | ||
|
||
|
||
def _recreate_collection( | ||
source_client: QdrantBase, | ||
dest_client: QdrantBase, | ||
src_collection_name: str, | ||
dest_collection_name: str, | ||
) -> None: | ||
src_collection_info = source_client.get_collection(src_collection_name) | ||
src_config = src_collection_info.config | ||
src_payload_schema = src_collection_info.payload_schema | ||
|
||
dest_client.recreate_collection( | ||
dest_collection_name, | ||
vectors_config=src_config.params.vectors, | ||
sparse_vectors_config=src_config.params.sparse_vectors, | ||
shard_number=src_config.params.shard_number, | ||
replication_factor=src_config.params.replication_factor, | ||
write_consistency_factor=src_config.params.write_consistency_factor, | ||
on_disk_payload=src_config.params.on_disk_payload, | ||
hnsw_config=models.HnswConfigDiff(**to_dict(src_config.hnsw_config)), | ||
optimizers_config=models.OptimizersConfigDiff(**to_dict(src_config.optimizer_config)), | ||
wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)), | ||
quantization_config=src_config.quantization_config, | ||
) | ||
|
||
_recreate_payload_schema(dest_client, dest_collection_name, src_payload_schema) | ||
|
||
|
||
def _migrate_collection_old( | ||
source_client: QdrantBase, | ||
dest_client: QdrantBase, | ||
src_collection_name: str, | ||
dest_collection_name: str, | ||
batch_size: int = 100, | ||
) -> None: | ||
"""Migrate collection from source client to destination client | ||
Args: | ||
src_collection_name (str): Source collection name | ||
dest_collection_name (str): Destination collection name | ||
source_client (QdrantBase): Source client | ||
dest_client (QdrantBase): Destination client | ||
batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100. | ||
""" | ||
offset: Optional[Dict[str, Any]] = None | ||
total_migrated = 0 | ||
|
||
while True: | ||
records, offset = source_client.scroll( | ||
src_collection_name, | ||
offset=offset, | ||
limit=batch_size, | ||
with_vectors=True, | ||
with_payload=True, | ||
) | ||
|
||
if not records: | ||
break | ||
|
||
points = [ | ||
PointStruct(id=record.id, vector=record.vector, payload=record.payload) | ||
for record in records | ||
] | ||
dest_client.upsert(dest_collection_name, points=points) | ||
|
||
total_migrated += len(records) | ||
print(f"Migrated {total_migrated} points so far...") | ||
|
||
if offset is None: | ||
break | ||
|
||
source_client_vectors_count = source_client.count(src_collection_name).count | ||
dest_client_vectors_count = dest_client.count(dest_collection_name).count | ||
|
||
print( | ||
f"Migration completed. Source count: {source_client_vectors_count}, Destination count: {dest_client_vectors_count}" | ||
) | ||
|
||
assert ( | ||
source_client_vectors_count == dest_client_vectors_count | ||
), f"Migration failed, vectors count are not equal: source vector count {source_client_vectors_count}, dest vector count {dest_client_vectors_count}" | ||
|
||
print( | ||
f"Successfully migrated {total_migrated} points from {src_collection_name} to {dest_collection_name}" | ||
) | ||
|
||
|
||
def _migrate_collection( | ||
source_client: QdrantBase, | ||
dest_client: QdrantBase, | ||
src_collection_name: str, | ||
dest_collection_name: str, | ||
batch_size: int = 100, | ||
) -> None: | ||
"""Migrate collection from source client to destination client | ||
Args: | ||
src_collection_name (str): Source collection name | ||
dest_collection_name (str): Destination collection name | ||
source_client (QdrantBase): Source client | ||
dest_client (QdrantBase): Destination client | ||
batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100. | ||
""" | ||
records, next_offset = source_client.scroll(src_collection_name, limit=2, with_vectors=True) | ||
dest_client.upload_points(dest_collection_name, records, wait=True) # type: ignore | ||
|
||
# upload_records has been deprecated due to the usage of models.Record; models.Record has been deprecated as a | ||
# structure for uploading due to a `shard_key` field, and now is used only as a result structure. | ||
# since shard_keys are not supported in migration, we can safely type ignore here and use Records for uploading | ||
while next_offset is not None: | ||
records, next_offset = source_client.scroll( | ||
src_collection_name, offset=next_offset, limit=batch_size, with_vectors=True | ||
) | ||
dest_client.upload_points(dest_collection_name, records, wait=True) # type: ignore | ||
source_client_vectors_count = source_client.count(src_collection_name).count | ||
dest_client_vectors_count = dest_client.count(dest_collection_name).count | ||
assert ( | ||
source_client_vectors_count == dest_client_vectors_count | ||
), f"Migration failed, vectors count are not equal: source vector count {source_client_vectors_count}, dest vector count {dest_client_vectors_count}" |