diff --git a/tests/test_vlad.py b/tests/test_vlad.py new file mode 100644 index 00000000..732f8db2 --- /dev/null +++ b/tests/test_vlad.py @@ -0,0 +1,26 @@ +import os +import unittest +import numpy as np +from gnes.encoder.numeric.vlad import VladEncoder + + +class TestVladEncoder(unittest.TestCase): + def setUp(self): + self.mock_train_data = np.random.random([200, 128]) + self.mock_eval_data = np.random.random([2, 2, 128]) + self.dump_path = os.path.join(os.path.dirname(__file__), 'vlad.bin') + + def tearDown(self): + if os.path.exists(self.dump_path): + os.remove(self.dump_path) + + def test_vlad_train(self): + model = VladEncoder(20) + model.train(self.mock_train_data) + self.assertEqual(model.centroids.shape, (20, 128)) + model.dump(self.dump_path) + + def test_vlad_encode(self): + model = VladEncoder.load(self.dump_path) + v = model.encode(self.mock_eval_data) + self.assertEqual(v.shape, (2, 2560))