Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(encoder): fix vald in numeric encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Larryjianfeng committed Sep 9, 2019
1 parent 2c9d5ce commit f8e18d0
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/test_vlad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import unittest
import numpy as np
from gnes.encoder.numeric.vlad import VladEncoder


class TestVladEncoder(unittest.TestCase):
def setUp(self):
self.mock_train_data = np.random.random([200, 128])
self.mock_eval_data = np.random.random([2, 2, 128])
self.dump_path = os.path.join(os.path.dirname(__file__), 'vlad.bin')

def tearDown(self):
if os.path.exists(self.dump_path):
os.remove(self.dump_path)

def test_vlad_train(self):
model = VladEncoder(20)
model.train(self.mock_train_data)
self.assertEqual(model.centroids.shape, (20, 128))
model.dump(self.dump_path)

def test_vlad_encode(self):
model = VladEncoder.load(self.dump_path)
v = model.encode(self.mock_eval_data)
self.assertEqual(v.shape, (2, 2560))

0 comments on commit f8e18d0

Please sign in to comment.