Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(encoder): fix vlad to speed up centroids calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
Larryjianfeng committed Sep 9, 2019
1 parent c62fa3f commit 654a5ba
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions gnes/encoder/numeric/vlad.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class VladEncoder(BaseNumericEncoder):
batch_size = 2048

def __init__(self, num_clusters: int,
using_faiss_pred: bool=True,
using_faiss_pred: bool = False,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.num_clusters = num_clusters
Expand All @@ -37,6 +37,8 @@ def kmeans_train(self, vecs):
kmeans = faiss.Kmeans(vecs.shape[1], self.num_clusters, niter=5, verbose=False)
kmeans.train(vecs)
self.centroids = kmeans.centroids
self.centroids_l2 = np.sum(self.centroids**2, axis=1).reshape([1, -1])
self.centroids_trans = np.transpose(self.centroids)
if self.using_faiss_pred:
self.faiss_index()

Expand All @@ -50,9 +52,9 @@ def kmeans_pred(self, vecs):
_, pred = self.index_flat.search(vecs.astype(np.float32), 1)
return np.reshape(pred, [-1])
else:
vecs = np.reshape(vecs, [vecs.shape[0], 1, 1, vecs.shape[1]])
dist = np.sum(np.square(vecs - self.centroids), -1)
return np.argmax(-dist, axis=-1).reshape([-1]).astype(np.int32)
vecs_l2 = np.sum(vecs**2, axis=1).reshape([-1, 1])
dist = vecs_l2 + self.centroids_l2 - 2 * np.matmul(vecs, self.centroids_trans)
return np.argmax(dist, axis=-1).reshape([-1]).astype(np.int32)

@batching
def train(self, vecs: np.ndarray, *args, **kwargs):
Expand All @@ -79,6 +81,8 @@ def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray:
def _copy_from(self, x: 'VladEncoder') -> None:
self.num_clusters = x.num_clusters
self.centroids = x.centroids
self.centroids_l2 = x.centroids_l2
self.centroids_trans = np.transpose(self.centroids)
self.using_faiss_pred = x.using_faiss_pred
if self.using_faiss_pred:
self.faiss_index()
Expand Down

0 comments on commit 654a5ba

Please sign in to comment.