From d2d578c400ca61820795e15678b788f60fafa03a Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Wed, 4 Sep 2019 11:30:06 +0800 Subject: [PATCH] refactor(score_fn): rename score functions --- gnes/component.py | 21 ++++++++++++++++++ gnes/composer/http.py | 15 +++++++++++++ gnes/encoder/numeric/pooling.py | 15 +++++++++++++ gnes/indexer/base.py | 6 +++--- gnes/indexer/doc/dict.py | 15 +++++++++++++ gnes/router/base.py | 4 ++-- gnes/score_fn/base.py | 38 ++++++++++++++++----------------- gnes/score_fn/chunk.py | 4 ++-- gnes/score_fn/doc.py | 4 ++-- gnes/score_fn/normalize.py | 10 ++++----- tests/test_score_fn.py | 8 +++---- 11 files changed, 103 insertions(+), 37 deletions(-) diff --git a/gnes/component.py b/gnes/component.py index a4645591..4be8c12d 100644 --- a/gnes/component.py +++ b/gnes/component.py @@ -1,7 +1,23 @@ +# Tencent is pleased to support the open source community by making GNES available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .encoder import base as encoder_base from .indexer import base as indexer_base from .preprocessor import base as prep_base from .router import base as router_base +from .score_fn import base as score_base # Encoder BaseEncoder = encoder_base.BaseEncoder @@ -35,3 +51,8 @@ BaseTopkReduceRouter = router_base.BaseTopkReduceRouter BaseMapRouter = router_base.BaseMapRouter PipelineRouter = router_base.PipelineRouter + +# Score_Fn +BaseScoreFn = score_base.BaseScoreFn +ModifierScoreFn = score_base.ModifierScoreFn +CombinedScoreFn = score_base.CombinedScoreFn diff --git a/gnes/composer/http.py b/gnes/composer/http.py index b9a39c7e..c96bf794 100644 --- a/gnes/composer/http.py +++ b/gnes/composer/http.py @@ -1,3 +1,18 @@ +# Tencent is pleased to support the open source community by making GNES available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from http.server import BaseHTTPRequestHandler, HTTPServer from urllib.parse import parse_qs diff --git a/gnes/encoder/numeric/pooling.py b/gnes/encoder/numeric/pooling.py index bcfe1ecf..69e7bfe7 100644 --- a/gnes/encoder/numeric/pooling.py +++ b/gnes/encoder/numeric/pooling.py @@ -1,3 +1,18 @@ +# Tencent is pleased to support the open source community by making GNES available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from typing import Tuple diff --git a/gnes/indexer/base.py b/gnes/indexer/base.py index c7c94ec1..c00760e5 100644 --- a/gnes/indexer/base.py +++ b/gnes/indexer/base.py @@ -18,13 +18,13 @@ from ..base import TrainableBase, CompositionalTrainableBase from ..proto import gnes_pb2, blob2array -from ..score_fn.base import get_unary_score, ModifierFn +from ..score_fn.base import get_unary_score, ModifierScoreFn class BaseIndexer(TrainableBase): def __init__(self, - normalize_fn: 'BaseScoreFn' = ModifierFn(), - score_fn: 'BaseScoreFn' = ModifierFn(), *args, **kwargs): + normalize_fn: 'BaseScoreFn' = ModifierScoreFn(), + score_fn: 'BaseScoreFn' = ModifierScoreFn(), *args, **kwargs): super().__init__(*args, **kwargs) self.normalize_fn = normalize_fn self.score_fn = score_fn diff --git a/gnes/indexer/doc/dict.py b/gnes/indexer/doc/dict.py index 1a80252d..dac0256a 100644 --- a/gnes/indexer/doc/dict.py +++ b/gnes/indexer/doc/dict.py @@ -1,3 +1,18 @@ +# Tencent is pleased to support the open source community by making GNES available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import List from google.protobuf.json_format import MessageToJson, Parse diff --git a/gnes/router/base.py b/gnes/router/base.py index 6a9b9691..a9a84827 100644 --- a/gnes/router/base.py +++ b/gnes/router/base.py @@ -15,7 +15,7 @@ from collections import defaultdict from typing import List, Generator -from gnes.score_fn.base import ScoreCombinedFn +from gnes.score_fn.base import CombinedScoreFn from ..base import TrainableBase, CompositionalTrainableBase from ..proto import gnes_pb2, merge_routes @@ -65,7 +65,7 @@ def __init__(self, reduce_op: str = 'sum', descending: bool = True, *args, **kwa self.descending = descending def post_init(self): - self.reduce_op = ScoreCombinedFn(score_mode=self._reduce_op) + self.reduce_op = CombinedScoreFn(score_mode=self._reduce_op) def get_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult') -> str: raise NotImplementedError diff --git a/gnes/score_fn/base.py b/gnes/score_fn/base.py index 26a84905..424c519d 100644 --- a/gnes/score_fn/base.py +++ b/gnes/score_fn/base.py @@ -50,7 +50,7 @@ def new_score(self, *, operands: Sequence['gnes_pb2.Response.QueryResponse.Score **kwargs) -class ScoreCombinedFn(BaseScoreFn): +class CombinedScoreFn(BaseScoreFn): """Combine multiple scores into one score, defaults to 'multiply'""" def __init__(self, score_mode: str = 'multiply', *args, **kwargs): @@ -83,7 +83,7 @@ def __call__(self, *last_scores) -> 'gnes_pb2.Response.QueryResponse.ScoredResul score_mode=self.score_mode) -class ModifierFn(BaseScoreFn): +class ModifierScoreFn(BaseScoreFn): """Modifier to apply to the value score = modifier(factor * value) """ @@ -136,20 +136,20 @@ def __call__(self, class ScoreOps: - multiply = ScoreCombinedFn('multiply') - sum = ScoreCombinedFn('sum') - max = ScoreCombinedFn('max') - min = ScoreCombinedFn('min') - avg = ScoreCombinedFn('avg') - none = ModifierFn('none') - log = ModifierFn('log') - log1p = ModifierFn('log1p') - log2p = ModifierFn('log2p') - ln = ModifierFn('ln') - ln1p = ModifierFn('ln1p') - ln2p = ModifierFn('ln2p') - square = ModifierFn('square') - sqrt = ModifierFn('sqrt') - abs = ModifierFn('abs') - reciprocal = ModifierFn('reciprocal') - reciprocal1p = ModifierFn('reciprocal1p') + multiply = CombinedScoreFn('multiply') + sum = CombinedScoreFn('sum') + max = CombinedScoreFn('max') + min = CombinedScoreFn('min') + avg = CombinedScoreFn('avg') + none = ModifierScoreFn('none') + log = ModifierScoreFn('log') + log1p = ModifierScoreFn('log1p') + log2p = ModifierScoreFn('log2p') + ln = ModifierScoreFn('ln') + ln1p = ModifierScoreFn('ln1p') + ln2p = ModifierScoreFn('ln2p') + square = ModifierScoreFn('square') + sqrt = ModifierScoreFn('sqrt') + abs = ModifierScoreFn('abs') + reciprocal = ModifierScoreFn('reciprocal') + reciprocal1p = ModifierScoreFn('reciprocal1p') diff --git a/gnes/score_fn/chunk.py b/gnes/score_fn/chunk.py index 13992bde..de1d4c6f 100644 --- a/gnes/score_fn/chunk.py +++ b/gnes/score_fn/chunk.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import get_unary_score, ScoreCombinedFn +from .base import get_unary_score, CombinedScoreFn -class WeightedChunkScoreFn(ScoreCombinedFn): +class WeightedChunkScoreFn(CombinedScoreFn): """score = d_chunk.weight * relevance * q_chunk.weight""" def __call__(self, last_score: 'gnes_pb2.Response.QueryResponse.ScoredResult.Score', diff --git a/gnes/score_fn/doc.py b/gnes/score_fn/doc.py index d1915819..2dcffe69 100644 --- a/gnes/score_fn/doc.py +++ b/gnes/score_fn/doc.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import get_unary_score, ScoreCombinedFn +from .base import get_unary_score, CombinedScoreFn -class WeightedDocScoreFn(ScoreCombinedFn): +class WeightedDocScoreFn(CombinedScoreFn): def __call__(self, last_score: 'gnes_pb2.Response.QueryResponse.ScoredResult.Score', doc: 'gnes_pb2.Document', *args, **kwargs): d_weight = get_unary_score(value=doc.weight, diff --git a/gnes/score_fn/normalize.py b/gnes/score_fn/normalize.py index 449931ec..65a0a96b 100644 --- a/gnes/score_fn/normalize.py +++ b/gnes/score_fn/normalize.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import ModifierFn, ScoreOps as so +from .base import ModifierScoreFn, ScoreOps as so -class Normalizer1(ModifierFn): +class Normalizer1(ModifierScoreFn): """Do normalizing: score = 1 / (1 + sqrt(score))""" def __init__(self): @@ -26,7 +26,7 @@ def __call__(self, last_score, *args, **kwargs): return super().__call__(so.sqrt(last_score)) -class Normalizer2(ModifierFn): +class Normalizer2(ModifierScoreFn): """Do normalizing: score = 1 / (1 + score / num_dim)""" def __init__(self, num_dim: int): @@ -40,14 +40,14 @@ def __call__(self, last_score, *args, **kwargs): return super().__call__(so.sqrt(last_score)) -class Normalizer4(ModifierFn): +class Normalizer4(ModifierScoreFn): """Do normalizing: score = 1 - score / num_bytes """ def __init__(self, num_bytes: int): super().__init__(modifier='invert1p', factor=1.0 / num_bytes, factor_name='1/num_bytes') -class Normalizer5(ModifierFn): +class Normalizer5(ModifierScoreFn): """Do normalizing: score = 1 / (1 + sqrt(abs(score)))""" def __init__(self): diff --git a/tests/test_score_fn.py b/tests/test_score_fn.py index 99526a1d..f91ba3ea 100644 --- a/tests/test_score_fn.py +++ b/tests/test_score_fn.py @@ -3,7 +3,7 @@ from pprint import pprint from gnes.proto import gnes_pb2 -from gnes.score_fn.base import get_unary_score, ScoreCombinedFn, ModifierFn +from gnes.score_fn.base import get_unary_score, CombinedScoreFn, ModifierScoreFn from gnes.score_fn.chunk import WeightedChunkScoreFn from gnes.score_fn.normalize import Normalizer1, Normalizer2, Normalizer3, Normalizer4 @@ -18,11 +18,11 @@ def test_basic(self): def test_op(self): a = get_unary_score(0.5) b = get_unary_score(0.7) - sum_op = ScoreCombinedFn(score_mode='sum') + sum_op = CombinedScoreFn(score_mode='sum') c = sum_op(a, b) self.assertAlmostEqual(c.value, 1.2) - sq_op = ModifierFn(modifier='square') + sq_op = ModifierScoreFn(modifier='square') c = sum_op(a, sq_op(b)) self.assertAlmostEqual(c.value, 0.99) print(c) @@ -51,7 +51,7 @@ def test_normalizer(self): pprint(json.loads(b.explained)) self.assertEqual(b.value, 0.75) - norm_op = ModifierFn('none') + norm_op = ModifierScoreFn('none') b = norm_op(a) pprint(json.loads(b.explained)) self.assertEqual(b.value, 0.5)