diff --git a/gnes/encoder/numeric/vlad.py b/gnes/encoder/numeric/vlad.py index 39664aeb..433a0c97 100644 --- a/gnes/encoder/numeric/vlad.py +++ b/gnes/encoder/numeric/vlad.py @@ -14,8 +14,6 @@ # limitations under the License. -import copy - import numpy as np from ..base import BaseNumericEncoder @@ -25,51 +23,72 @@ class VladEncoder(BaseNumericEncoder): batch_size = 2048 - def __init__(self, num_clusters: int, *args, **kwargs): + def __init__(self, num_clusters: int, + using_faiss_pred: bool=True, + *args, **kwargs): super().__init__(*args, **kwargs) self.num_clusters = num_clusters + self.using_faiss_pred = using_faiss_pred self.centroids = None + self.index_flat = None def kmeans_train(self, vecs): import faiss - kmeans = faiss.Kmeans(vecs.shape[1], self.num_clusters, niter=5, verbose=False) kmeans.train(vecs) self.centroids = kmeans.centroids + if self.using_faiss_pred: + self.faiss_index() + + def faiss_index(self): + import faiss + self.index_flat = faiss.IndexFlatL2(self.centroids.shape[1]) + self.index_flat.add(self.centroids) def kmeans_pred(self, vecs): - vecs = np.reshape(vecs, [vecs.shape[0], 1, 1, vecs.shape[1]]) - dist = np.sum(np.square(vecs - self.centroids), -1) - return np.argmax(-dist, axis=-1).astype(np.int64) + if self.using_faiss_pred: + _, pred = self.index_flat.search(vecs.astype(np.float32), 1) + return np.reshape(pred, [-1]) + else: + vecs = np.reshape(vecs, [vecs.shape[0], 1, 1, vecs.shape[1]]) + dist = np.sum(np.square(vecs - self.centroids), -1) + return np.argmax(-dist, axis=-1).reshape([-1]).astype(np.int32) @batching def train(self, vecs: np.ndarray, *args, **kwargs): + vecs = vecs.reshape([-1, vecs.shape[-1]]) assert len(vecs) > self.num_clusters, 'number of data should be larger than number of clusters' - vecs_ = copy.deepcopy(vecs) - vecs_ = np.concatenate((list(vecs_[i] for i in range(len(vecs_)))), axis=0) - self.kmeans_train(vecs_) + self.kmeans_train(vecs) @train_required @batching def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray: - vecs_ = copy.deepcopy(vecs) - vecs_ = np.concatenate((list(vecs_[i] for i in range(len(vecs_)))), axis=0) - - knn_output = self.kmeans_pred(vecs_) - knn_output = [knn_output[i:i + vecs.shape[1]] for i in range(0, len(knn_output), vecs.shape[1])] + knn_output = [self.kmeans_pred(vecs_) for vecs_ in vecs] output = [] for chunk_count, chunk in enumerate(vecs): res = np.zeros((self.centroids.shape[0], self.centroids.shape[1])) for frame_count, frame in enumerate(chunk): - center_index = knn_output[chunk_count][frame_count][0] + center_index = knn_output[chunk_count][frame_count] res[center_index] += (frame - self.centroids[center_index]) - output.append(res) + res = res.reshape([-1]) + output.append(res / np.sum(res**2)**0.5) - output = np.array(list(map(lambda x: x.reshape(1, -1), output)), dtype=np.float32) - output = np.squeeze(output, axis=1) - return output + return np.array(output, dtype=np.float32) def _copy_from(self, x: 'VladEncoder') -> None: self.num_clusters = x.num_clusters self.centroids = x.centroids + self.using_faiss_pred = x.using_faiss_pred + if self.using_faiss_pred: + self.faiss_index() + + def __setstate__(self, state): + super().__setstate__(state) + if self.using_faiss_pred: + self.faiss_index() + + def __getstate__(self): + state = super().__getstate__() + del state['index_flat'] + return state diff --git a/tests/test_vlad.py b/tests/test_vlad.py new file mode 100644 index 00000000..66132e6d --- /dev/null +++ b/tests/test_vlad.py @@ -0,0 +1,29 @@ +import os +import unittest +import numpy as np +from gnes.encoder.numeric.vlad import VladEncoder + + +class TestVladEncoder(unittest.TestCase): + def setUp(self): + self.mock_train_data = np.random.random([1, 200, 128]).astype(np.float32) + self.mock_eval_data = np.random.random([2, 2, 128]).astype(np.float32) + self.dump_path = os.path.join(os.path.dirname(__file__), 'vlad.bin') + + def tearDown(self): + if os.path.exists(self.dump_path): + os.remove(self.dump_path) + + def test_vlad_train(self): + model = VladEncoder(20) + model.train(self.mock_train_data) + self.assertEqual(model.centroids.shape, (20, 128)) + v = model.encode(self.mock_eval_data) + self.assertEqual(v.shape, (2, 2560)) + + def test_vlad_dump_load(self): + model = VladEncoder(20) + model.train(self.mock_train_data) + model.dump(self.dump_path) + model_new = VladEncoder.load(self.dump_path) + self.assertEqual(model_new.centroids.shape, (20, 128))