diff --git a/tests/test_annoyindexer.py b/tests/test_annoyindexer.py index 66fe533b..3dde7494 100644 --- a/tests/test_annoyindexer.py +++ b/tests/test_annoyindexer.py @@ -4,6 +4,7 @@ import numpy as np from gnes.indexer.chunk.annoy import AnnoyIndexer +from gnes.indexer.chunk.numpy import NumpyIndexer class TestAnnoyIndexer(unittest.TestCase): @@ -27,3 +28,19 @@ def test_search(self): a.close() a.dump() a.dump_yaml() + + def test_numpy_indexer(self): + a = NumpyIndexer() + a.add(list(zip(list(range(10)), list(range(10)))), self.toy_data, [1.] * 10) + self.assertEqual(a.num_chunks, 10) + self.assertEqual(a.num_docs, 10) + top_1 = [i[0][0] for i in a.query(self.toy_data, top_k=1)] + self.assertEqual(top_1, list(range(10))) + a.close() + a.dump() + a.dump_yaml() + b = NumpyIndexer.load_yaml(a.yaml_full_path) + self.assertEqual(b.num_chunks, 10) + self.assertEqual(b.num_docs, 10) + top_1 = [i[0][0] for i in b.query(self.toy_data, top_k=1)] + self.assertEqual(top_1, list(range(10)))