From e9feaa6174ada9242b42a66564607d86671c6bab Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Wed, 4 Sep 2019 11:16:34 +0800 Subject: [PATCH] refactor(score_fn): use post_init instead of property --- gnes/base/__init__.py | 2 +- gnes/score_fn/__init__.py | 32 ++++++++++++++++++++ gnes/score_fn/base.py | 62 +++++++++++++++++++++++++------------- gnes/score_fn/chunk.py | 15 +++++++++ gnes/score_fn/doc.py | 15 +++++++++ gnes/score_fn/normalize.py | 31 +++++++++++-------- 6 files changed, 123 insertions(+), 34 deletions(-) diff --git a/gnes/base/__init__.py b/gnes/base/__init__.py index e8681d9b..b930fe84 100644 --- a/gnes/base/__init__.py +++ b/gnes/base/__init__.py @@ -50,7 +50,7 @@ def _import(module_name, class_name): if class_name in cls2file: return getattr(importlib.import_module('gnes.%s.%s' % (module_name, cls2file[class_name])), class_name) - search_modules = ['encoder', 'indexer', 'preprocessor', 'router'] + search_modules = ['encoder', 'indexer', 'preprocessor', 'router', 'score_fn'] for m in search_modules: r = _import(m, name) diff --git a/gnes/score_fn/__init__.py b/gnes/score_fn/__init__.py index e69de29b..9bbcffbd 100644 --- a/gnes/score_fn/__init__.py +++ b/gnes/score_fn/__init__.py @@ -0,0 +1,32 @@ +# Tencent is pleased to support the open source community by making GNES available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# A key-value map for Class to the (module)file it located in +from ..base import register_all_class + +_cls2file_map = { + 'BaseScoreFn': 'base', + 'ScoreCombinedFn': 'base', + 'ModifierFn': 'base', + 'WeightedChunkScoreFn': 'chunk', + 'WeightedDocScoreFn': 'doc', + 'Normalizer1': 'normalize', + 'Normalizer2': 'normalize', + 'Normalizer3': 'normalize', + 'Normalizer4': 'normalize', + 'Normalizer5': 'normalize', +} + +register_all_class(_cls2file_map, 'score_fn') diff --git a/gnes/score_fn/base.py b/gnes/score_fn/base.py index cc281208..26a84905 100644 --- a/gnes/score_fn/base.py +++ b/gnes/score_fn/base.py @@ -1,3 +1,18 @@ +# Tencent is pleased to support the open source community by making GNES available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json from functools import reduce from math import log, log1p, log10, sqrt @@ -18,6 +33,8 @@ def get_unary_score(value: float, **kwargs): class BaseScoreFn(TrainableBase): + """Base score function. A score function must implement __call__ method""" + warn_unnamed = False def __call__(self, *args, **kwargs) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score': @@ -32,9 +49,6 @@ def new_score(self, *, operands: Sequence['gnes_pb2.Response.QueryResponse.Score operands=[json.loads(s.explained) for s in operands], **kwargs) - def op(self, *args, **kwargs) -> float: - raise NotImplementedError - class ScoreCombinedFn(BaseScoreFn): """Combine multiple scores into one score, defaults to 'multiply'""" @@ -44,24 +58,29 @@ def __init__(self, score_mode: str = 'multiply', *args, **kwargs): :param score_mode: specifies how the computed scores are combined """ super().__init__(*args, **kwargs) - if score_mode not in {'multiply', 'sum', 'avg', 'max', 'min'}: - raise AttributeError('score_mode=%s is not supported!' % score_mode) + if score_mode not in self.supported_ops: + raise AttributeError( + 'score_mode=%s is not supported! must be one of %s' % (score_mode, self.supported_ops.keys())) self.score_mode = score_mode - def __call__(self, *last_scores) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score': - return self.new_score( - value=self.op(s.value for s in last_scores), - operands=last_scores, - score_mode=self.score_mode) - - def op(self, *args, **kwargs) -> float: + @property + def supported_ops(self): return { 'multiply': lambda v: reduce(mul, v), 'sum': lambda v: reduce(add, v), 'max': lambda v: reduce(max, v), 'min': lambda v: reduce(min, v), 'avg': lambda v: reduce(add, v) / len(v), - }[self.score_mode](*args, **kwargs) + } + + def post_init(self): + self.op = self.supported_ops[self.score_mode] + + def __call__(self, *last_scores) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score': + return self.new_score( + value=self.op(s.value for s in last_scores), + operands=last_scores, + score_mode=self.score_mode) class ModifierFn(BaseScoreFn): @@ -72,18 +91,15 @@ class ModifierFn(BaseScoreFn): def __init__(self, modifier: str = 'none', factor: float = 1.0, factor_name: str = 'GivenConstant', *args, **kwargs): super().__init__(*args, **kwargs) - if modifier not in {'none', 'log', 'log1p', 'log2p', 'ln', 'ln1p', 'ln2p', 'square', 'sqrt', 'reciprocal', - 'reciprocal1p', 'abs'}: - raise AttributeError('modifier=%s is not supported!' % modifier) + if modifier not in self.supported_ops: + raise AttributeError( + 'modifier=%s is not supported! must be one of %s' % (modifier, self.supported_ops.keys())) self._modifier = modifier self._factor = factor self._factor_name = factor_name @property - def factor(self): - return get_unary_score(value=self._factor, name=self._factor_name) - - def op(self, *args, **kwargs) -> float: + def supported_ops(self): return { 'none': lambda x: x, 'log': log10, @@ -99,7 +115,11 @@ def op(self, *args, **kwargs) -> float: 'abs': abs, 'invert': lambda x: - x, 'invert1p': lambda x: 1 - x - }[self._modifier](*args, **kwargs) + } + + def post_init(self): + self.factor = get_unary_score(value=self._factor, name=self._factor_name) + self.op = self.supported_ops[self._modifier] def __call__(self, last_score: 'gnes_pb2.Response.QueryResponse.ScoredResult.Score', diff --git a/gnes/score_fn/chunk.py b/gnes/score_fn/chunk.py index 744e98e4..13992bde 100644 --- a/gnes/score_fn/chunk.py +++ b/gnes/score_fn/chunk.py @@ -1,3 +1,18 @@ +# Tencent is pleased to support the open source community by making GNES available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .base import get_unary_score, ScoreCombinedFn diff --git a/gnes/score_fn/doc.py b/gnes/score_fn/doc.py index 751aff38..d1915819 100644 --- a/gnes/score_fn/doc.py +++ b/gnes/score_fn/doc.py @@ -1,3 +1,18 @@ +# Tencent is pleased to support the open source community by making GNES available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .base import get_unary_score, ScoreCombinedFn diff --git a/gnes/score_fn/normalize.py b/gnes/score_fn/normalize.py index e3f156f2..449931ec 100644 --- a/gnes/score_fn/normalize.py +++ b/gnes/score_fn/normalize.py @@ -1,3 +1,18 @@ +# Tencent is pleased to support the open source community by making GNES available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .base import ModifierFn, ScoreOps as so @@ -5,8 +20,7 @@ class Normalizer1(ModifierFn): """Do normalizing: score = 1 / (1 + sqrt(score))""" def __init__(self): - super().__init__() - self._modifier = 'reciprocal1p' + super().__init__(modifier='reciprocal1p') def __call__(self, last_score, *args, **kwargs): return super().__call__(so.sqrt(last_score)) @@ -16,10 +30,7 @@ class Normalizer2(ModifierFn): """Do normalizing: score = 1 / (1 + score / num_dim)""" def __init__(self, num_dim: int): - super().__init__() - self._modifier = 'reciprocal1p' - self._factor = 1.0 / num_dim - self._factor_name = '1/num_dim' + super().__init__(modifier='reciprocal1p', factor=1.0 / num_dim, factor_name='1/num_dim') class Normalizer3(Normalizer2): @@ -33,18 +44,14 @@ class Normalizer4(ModifierFn): """Do normalizing: score = 1 - score / num_bytes """ def __init__(self, num_bytes: int): - super().__init__() - self._modifier = 'invert1p' - self._factor = 1.0 / num_bytes - self._factor_name = '1/num_bytes' + super().__init__(modifier='invert1p', factor=1.0 / num_bytes, factor_name='1/num_bytes') class Normalizer5(ModifierFn): """Do normalizing: score = 1 / (1 + sqrt(abs(score)))""" def __init__(self): - super().__init__() - self._modifier = 'reciprocal1p' + super().__init__(modifier='reciprocal1p') def __call__(self, last_score, *args, **kwargs): return super().__call__(so.sqrt(so.abs(last_score)))