diff --git a/gnes/encoder/numeric/vlad.py b/gnes/encoder/numeric/vlad.py index d64aad5b..d8d1cb0b 100644 --- a/gnes/encoder/numeric/vlad.py +++ b/gnes/encoder/numeric/vlad.py @@ -14,8 +14,6 @@ # limitations under the License. -import copy - import numpy as np from ..base import BaseNumericEncoder @@ -39,6 +37,11 @@ def kmeans_train(self, vecs): kmeans = faiss.Kmeans(vecs.shape[1], self.num_clusters, niter=5, verbose=False) kmeans.train(vecs) self.centroids = kmeans.centroids + if self.using_faiss_pred: + self.faiss_index() + + def faiss_index(self): + import faiss self.index_flat = faiss.IndexFlatL2(self.centroids.shape[1]) self.index_flat.add(self.centroids) @@ -53,10 +56,9 @@ def kmeans_pred(self, vecs): @batching def train(self, vecs: np.ndarray, *args, **kwargs): + vecs = vecs.reshape([-1, vecs.shape[-1]]) assert len(vecs) > self.num_clusters, 'number of data should be larger than number of clusters' - vecs_ = copy.deepcopy(vecs) - vecs_ = np.concatenate((list(vecs_[i] for i in range(len(vecs_)))), axis=0) - self.kmeans_train(vecs_) + self.kmeans_train(vecs) @train_required @batching @@ -79,7 +81,14 @@ def _copy_from(self, x: 'VladEncoder') -> None: self.centroids = x.centroids self.using_faiss_pred = x.using_faiss_pred if self.using_faiss_pred: - import faiss - self.index_flat = faiss.IndexFlatL2(self.centroids.shape[1]) - self.index_flat.add(self.centroids) + self.faiss_index() + + def __setstate__(self, state): + super().__setstate__(state) + if self.using_faiss_pred: + self.faiss_index() + def __getstate__(self): + state = super().__getstate__() + del state['index_flat'] + return state diff --git a/tests/test_vlad.py b/tests/test_vlad.py index 814ea866..66132e6d 100644 --- a/tests/test_vlad.py +++ b/tests/test_vlad.py @@ -6,8 +6,8 @@ class TestVladEncoder(unittest.TestCase): def setUp(self): - self.mock_train_data = np.random.random([200, 128]) - self.mock_eval_data = np.random.random([2, 2, 128]) + self.mock_train_data = np.random.random([1, 200, 128]).astype(np.float32) + self.mock_eval_data = np.random.random([2, 2, 128]).astype(np.float32) self.dump_path = os.path.join(os.path.dirname(__file__), 'vlad.bin') def tearDown(self):