Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(encoder): fix bug in vlad encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Larryjianfeng committed Sep 9, 2019
1 parent 1ba4e11 commit ddf13ff
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion gnes/encoder/numeric/vlad.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class VladEncoder(BaseNumericEncoder):
batch_size = 2048

def __init__(self, num_clusters: int,
using_faiss_pred: True,
using_faiss_pred: bool=True,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.num_clusters = num_clusters
Expand Down
11 changes: 7 additions & 4 deletions tests/test_vlad.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ def test_vlad_train(self):
model = VladEncoder(20)
model.train(self.mock_train_data)
self.assertEqual(model.centroids.shape, (20, 128))
model.dump(self.dump_path)

def test_vlad_encode(self):
model = VladEncoder.load(self.dump_path)
v = model.encode(self.mock_eval_data)
self.assertEqual(v.shape, (2, 2560))

def test_vlad_dump_load(self):
model = VladEncoder(20)
model.train(self.mock_train_data)
model.dump(self.dump_path)
model_new = VladEncoder.load(self.dump_path)
self.assertEqual(model_new.centroids.shape, (20, 128))

0 comments on commit ddf13ff

Please sign in to comment.