From 00e6280d4b44689f6f7706c925612abe87fe9cd6 Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Thu, 5 Sep 2019 11:51:39 +0800 Subject: [PATCH] feat(score_fn): use numpy for score fn --- gnes/score_fn/base.py | 39 +++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/gnes/score_fn/base.py b/gnes/score_fn/base.py index 424c519d..6a4fddbb 100644 --- a/gnes/score_fn/base.py +++ b/gnes/score_fn/base.py @@ -14,11 +14,10 @@ # limitations under the License. import json -from functools import reduce -from math import log, log1p, log10, sqrt -from operator import mul, add from typing import Sequence +import numpy as np + from ..base import TrainableBase from ..proto import gnes_pb2 @@ -66,11 +65,11 @@ def __init__(self, score_mode: str = 'multiply', *args, **kwargs): @property def supported_ops(self): return { - 'multiply': lambda v: reduce(mul, v), - 'sum': lambda v: reduce(add, v), - 'max': lambda v: reduce(max, v), - 'min': lambda v: reduce(min, v), - 'avg': lambda v: reduce(add, v) / len(v), + 'multiply': np.prod, + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + 'avg': np.mean, } def post_init(self): @@ -78,7 +77,7 @@ def post_init(self): def __call__(self, *last_scores) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score': return self.new_score( - value=self.op(s.value for s in last_scores), + value=self.op([s.value for s in last_scores]), operands=last_scores, score_mode=self.score_mode) @@ -102,17 +101,17 @@ def __init__(self, modifier: str = 'none', factor: float = 1.0, factor_name: str def supported_ops(self): return { 'none': lambda x: x, - 'log': log10, - 'log1p': lambda x: log(x + 1, 10), - 'log2p': lambda x: log(x + 2, 10), - 'ln': log, - 'ln1p': log1p, - 'ln2p': lambda x: log(x + 2), - 'square': lambda x: x * x, - 'sqrt': sqrt, - 'reciprocal': lambda x: 1 / x, - 'reciprocal1p': lambda x: 1 / (1 + x), - 'abs': abs, + 'log': np.log10, + 'log1p': lambda x: np.log10(x + 1), + 'log2p': lambda x: np.log10(x + 2), + 'ln': np.log, + 'ln1p': np.log1p, + 'ln2p': lambda x: np.log(x + 2), + 'square': np.square, + 'sqrt': np.sqrt, + 'reciprocal': np.reciprocal, + 'reciprocal1p': lambda x: np.reciprocal(1 + x), + 'abs': np.abs, 'invert': lambda x: - x, 'invert1p': lambda x: 1 - x }