Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(encoder): fix vlad unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
Larryjianfeng committed Sep 9, 2019
1 parent ddf13ff commit ffc822b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
25 changes: 17 additions & 8 deletions gnes/encoder/numeric/vlad.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# limitations under the License.


import copy

import numpy as np

from ..base import BaseNumericEncoder
Expand All @@ -39,6 +37,11 @@ def kmeans_train(self, vecs):
kmeans = faiss.Kmeans(vecs.shape[1], self.num_clusters, niter=5, verbose=False)
kmeans.train(vecs)
self.centroids = kmeans.centroids
if self.using_faiss_pred:
self.faiss_index()

def faiss_index(self):
    """Build a flat L2 faiss index over the current centroids.

    Creates a ``faiss.IndexFlatL2`` whose dimensionality is the size of the
    second axis of ``self.centroids``, stores it on ``self.index_flat`` and
    adds the centroids to it.
    """
    # Function-scope import: faiss is loaded only when this method is called.
    from faiss import IndexFlatL2

    dim = self.centroids.shape[1]
    self.index_flat = IndexFlatL2(dim)
    self.index_flat.add(self.centroids)

Expand All @@ -53,10 +56,9 @@ def kmeans_pred(self, vecs):

@batching
def train(self, vecs: np.ndarray, *args, **kwargs):
    """Fit the k-means centroids on the given vectors.

    :param vecs: array whose last axis is the feature dimension; all
        leading axes are collapsed so that every row is one sample.
    :raises AssertionError: if the number of samples is not larger than
        ``self.num_clusters``.
    """
    # Collapse leading axes into a 2-D (samples, features) array.
    vecs = vecs.reshape([-1, vecs.shape[-1]])
    assert len(vecs) > self.num_clusters, 'number of data should be larger than number of clusters'
    # Train once on the reshaped data. The former deepcopy + concatenate
    # step flattened the 2-D array to 1-D and triggered a second training
    # call, so it is dropped.
    self.kmeans_train(vecs)

@train_required
@batching
Expand All @@ -79,7 +81,14 @@ def _copy_from(self, x: 'VladEncoder') -> None:
self.centroids = x.centroids
self.using_faiss_pred = x.using_faiss_pred
if self.using_faiss_pred:
import faiss
self.index_flat = faiss.IndexFlatL2(self.centroids.shape[1])
self.index_flat.add(self.centroids)
self.faiss_index()

def __setstate__(self, state):
    """Restore the instance from ``state`` and rebuild the faiss index if enabled.

    The parent class restores the attributes first. When
    ``self.using_faiss_pred`` is truthy, ``self.faiss_index()`` is then
    called to recreate ``self.index_flat``, which ``__getstate__`` removes
    from the saved state.

    :param state: the state object handed to the parent ``__setstate__``.
    """
    super().__setstate__(state)
    if not self.using_faiss_pred:
        return
    self.faiss_index()

def __getstate__(self):
    """Return the parent's state with the faiss index removed.

    ``index_flat`` is left out of the saved state; ``__setstate__`` rebuilds
    it from the centroids when ``using_faiss_pred`` is set.
    NOTE(review): presumably the faiss index object cannot be pickled
    directly -- confirm.

    :return: the state produced by the parent class, without ``index_flat``.
    """
    state = super().__getstate__()
    # The index only exists once it has been built (faiss prediction enabled
    # and the encoder trained); pop with a default so that saving an encoder
    # without an index does not raise KeyError.
    state.pop('index_flat', None)
    return state
4 changes: 2 additions & 2 deletions tests/test_vlad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

class TestVladEncoder(unittest.TestCase):
def setUp(self):
    """Create random float32 fixtures and the dump path used by the tests.

    ``mock_train_data`` has shape (1, 200, 128) and ``mock_eval_data`` has
    shape (2, 2, 128). ``dump_path`` points at ``vlad.bin`` next to this
    test file.
    """
    # The earlier float64 assignments were dead stores that were overwritten
    # immediately; only the float32 fixtures are kept.
    self.mock_train_data = np.random.random([1, 200, 128]).astype(np.float32)
    self.mock_eval_data = np.random.random([2, 2, 128]).astype(np.float32)
    self.dump_path = os.path.join(os.path.dirname(__file__), 'vlad.bin')

def tearDown(self):
Expand Down

0 comments on commit ffc822b

Please sign in to comment.