From 1ba4e11cb7f18b97cb35faed61b7d82fb512cd84 Mon Sep 17 00:00:00 2001 From: Larry Yan Date: Mon, 9 Sep 2019 12:25:20 +0800 Subject: [PATCH] fix(encoder): fix vlad encoder and add unittest --- gnes/encoder/numeric/vlad.py | 40 ++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/gnes/encoder/numeric/vlad.py b/gnes/encoder/numeric/vlad.py index 39664aeb..9a6cf63c 100644 --- a/gnes/encoder/numeric/vlad.py +++ b/gnes/encoder/numeric/vlad.py @@ -25,22 +25,31 @@ class VladEncoder(BaseNumericEncoder): batch_size = 2048 - def __init__(self, num_clusters: int, *args, **kwargs): + def __init__(self, num_clusters: int, + using_faiss_pred: True, + *args, **kwargs): super().__init__(*args, **kwargs) self.num_clusters = num_clusters + self.using_faiss_pred = using_faiss_pred self.centroids = None + self.index_flat = None def kmeans_train(self, vecs): import faiss - kmeans = faiss.Kmeans(vecs.shape[1], self.num_clusters, niter=5, verbose=False) kmeans.train(vecs) self.centroids = kmeans.centroids + self.index_flat = faiss.IndexFlatL2(self.centroids.shape[1]) + self.index_flat.add(self.centroids) def kmeans_pred(self, vecs): - vecs = np.reshape(vecs, [vecs.shape[0], 1, 1, vecs.shape[1]]) - dist = np.sum(np.square(vecs - self.centroids), -1) - return np.argmax(-dist, axis=-1).astype(np.int64) + if self.using_faiss_pred: + D, I = self.index_flat.search(vecs.astype(np.float32), 1) + return np.reshape(I, [-1]) + else: + vecs = np.reshape(vecs, [vecs.shape[0], 1, 1, vecs.shape[1]]) + dist = np.sum(np.square(vecs - self.centroids), -1) + return np.argmax(-dist, axis=-1).reshape([-1]).astype(np.int32) @batching def train(self, vecs: np.ndarray, *args, **kwargs): @@ -52,24 +61,25 @@ def train(self, vecs: np.ndarray, *args, **kwargs): @train_required @batching def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray: - vecs_ = copy.deepcopy(vecs) - vecs_ = np.concatenate((list(vecs_[i] for i in range(len(vecs_)))), axis=0) - - 
knn_output = self.kmeans_pred(vecs_) - knn_output = [knn_output[i:i + vecs.shape[1]] for i in range(0, len(knn_output), vecs.shape[1])] + knn_output = [self.kmeans_pred(vecs_) for vecs_ in vecs] output = [] for chunk_count, chunk in enumerate(vecs): res = np.zeros((self.centroids.shape[0], self.centroids.shape[1])) for frame_count, frame in enumerate(chunk): - center_index = knn_output[chunk_count][frame_count][0] + center_index = knn_output[chunk_count][frame_count] res[center_index] += (frame - self.centroids[center_index]) - output.append(res) + res = res.reshape([-1]) + output.append(res / np.sum(res**2)**0.5) - output = np.array(list(map(lambda x: x.reshape(1, -1), output)), dtype=np.float32) - output = np.squeeze(output, axis=1) - return output + return np.array(output, dtype=np.float32) def _copy_from(self, x: 'VladEncoder') -> None: self.num_clusters = x.num_clusters self.centroids = x.centroids + self.using_faiss_pred = x.using_faiss_pred + if self.using_faiss_pred: + import faiss + self.index_flat = faiss.IndexFlatL2(self.centroids.shape[1]) + self.index_flat.add(self.centroids) +