From 654a5ba40a30ef51d57ab6ff0942c77d68d5a102 Mon Sep 17 00:00:00 2001 From: Larry Yan Date: Mon, 9 Sep 2019 16:53:15 +0800 Subject: [PATCH] fix(encoder): fix vlad to speed up centroids calculation --- gnes/encoder/numeric/vlad.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/gnes/encoder/numeric/vlad.py b/gnes/encoder/numeric/vlad.py index 433a0c97..772c380b 100644 --- a/gnes/encoder/numeric/vlad.py +++ b/gnes/encoder/numeric/vlad.py @@ -24,7 +24,7 @@ class VladEncoder(BaseNumericEncoder): batch_size = 2048 def __init__(self, num_clusters: int, - using_faiss_pred: bool=True, + using_faiss_pred: bool = False, *args, **kwargs): super().__init__(*args, **kwargs) self.num_clusters = num_clusters @@ -37,6 +37,8 @@ def kmeans_train(self, vecs): kmeans = faiss.Kmeans(vecs.shape[1], self.num_clusters, niter=5, verbose=False) kmeans.train(vecs) self.centroids = kmeans.centroids + self.centroids_l2 = np.sum(self.centroids**2, axis=1).reshape([1, -1]) + self.centroids_trans = np.transpose(self.centroids) if self.using_faiss_pred: self.faiss_index() @@ -50,9 +52,9 @@ def kmeans_pred(self, vecs): _, pred = self.index_flat.search(vecs.astype(np.float32), 1) return np.reshape(pred, [-1]) else: - vecs = np.reshape(vecs, [vecs.shape[0], 1, 1, vecs.shape[1]]) - dist = np.sum(np.square(vecs - self.centroids), -1) - return np.argmax(-dist, axis=-1).reshape([-1]).astype(np.int32) + vecs_l2 = np.sum(vecs**2, axis=1).reshape([-1, 1]) + dist = vecs_l2 + self.centroids_l2 - 2 * np.matmul(vecs, self.centroids_trans) + return np.argmax(dist, axis=-1).reshape([-1]).astype(np.int32) @batching def train(self, vecs: np.ndarray, *args, **kwargs): @@ -79,6 +81,8 @@ def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray: def _copy_from(self, x: 'VladEncoder') -> None: self.num_clusters = x.num_clusters self.centroids = x.centroids + self.centroids_l2 = x.centroids_l2 + self.centroids_trans = np.transpose(self.centroids) self.using_faiss_pred = x.using_faiss_pred if self.using_faiss_pred: self.faiss_index()