diff --git a/gnes/indexer/base.py b/gnes/indexer/base.py index 8fd89876..b7dac4c4 100644 --- a/gnes/indexer/base.py +++ b/gnes/indexer/base.py @@ -36,8 +36,10 @@ def __init__(self, :type is_big_score_similar: when set to true, then larger score means more similar """ super().__init__(*args, **kwargs) - self.normalize_fn = normalize_fn(context=self) if normalize_fn else ModifierScoreFn(context=self) - self.score_fn = score_fn(context=self) if score_fn else ModifierScoreFn(context=self) + self.normalize_fn = normalize_fn if normalize_fn else ModifierScoreFn(context=self) + self.score_fn = score_fn if score_fn else ModifierScoreFn(context=self) + self.normalize_fn._context = self + self.score_fn._context = self self.is_big_score_similar = is_big_score_similar self._num_docs = 0 self._num_chunks = 0 diff --git a/tests/test_score_fn.py b/tests/test_score_fn.py index adb7697f..936bf244 100644 --- a/tests/test_score_fn.py +++ b/tests/test_score_fn.py @@ -40,7 +40,7 @@ def test_combine_score_fn(self): q_chunk.embedding.CopyFrom(array2blob(np.array([3, 3, 3]))) for _fn in [WeightedChunkOffsetScoreFn, CoordChunkScoreFn, TFIDFChunkScoreFn, BM25ChunkScoreFn]: - indexer = NumpyIndexer(helper_indexer=ListKeyIndexer(), score_fn=_fn) + indexer = NumpyIndexer(helper_indexer=ListKeyIndexer(), score_fn=_fn()) indexer.add(keys=[(0, 1), (1, 2)], vectors=np.array([[1, 1, 1], [2, 2, 2]]), weights=[0.5, 0.8]) queried_result = indexer.query_and_score(q_chunks=[q_chunk], top_k=2)