Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
feat(score_fn): use numpy for score fn
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Sep 5, 2019
1 parent a12f8ef commit 00e6280
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions gnes/score_fn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
# limitations under the License.

import json
from functools import reduce
from math import log, log1p, log10, sqrt
from operator import mul, add
from typing import Sequence

import numpy as np

from ..base import TrainableBase
from ..proto import gnes_pb2

Expand Down Expand Up @@ -66,19 +65,19 @@ def __init__(self, score_mode: str = 'multiply', *args, **kwargs):
@property
def supported_ops(self):
return {
'multiply': lambda v: reduce(mul, v),
'sum': lambda v: reduce(add, v),
'max': lambda v: reduce(max, v),
'min': lambda v: reduce(min, v),
'avg': lambda v: reduce(add, v) / len(v),
'multiply': np.prod,
'sum': np.sum,
'max': np.max,
'min': np.min,
'avg': np.mean,
}

def post_init(self):
self.op = self.supported_ops[self.score_mode]

def __call__(self, *last_scores) -> 'gnes_pb2.Response.QueryResponse.ScoredResult.Score':
return self.new_score(
value=self.op(s.value for s in last_scores),
value=self.op([s.value for s in last_scores]),
operands=last_scores,
score_mode=self.score_mode)

Expand All @@ -102,17 +101,17 @@ def __init__(self, modifier: str = 'none', factor: float = 1.0, factor_name: str
def supported_ops(self):
return {
'none': lambda x: x,
'log': log10,
'log1p': lambda x: log(x + 1, 10),
'log2p': lambda x: log(x + 2, 10),
'ln': log,
'ln1p': log1p,
'ln2p': lambda x: log(x + 2),
'square': lambda x: x * x,
'sqrt': sqrt,
'reciprocal': lambda x: 1 / x,
'reciprocal1p': lambda x: 1 / (1 + x),
'abs': abs,
'log': np.log10,
'log1p': lambda x: np.log10(x + 1),
'log2p': lambda x: np.log10(x + 2),
'ln': np.log,
'ln1p': np.log1p,
'ln2p': lambda x: np.log(x + 2),
'square': np.square,
'sqrt': np.sqrt,
'reciprocal': np.reciprocal,
'reciprocal1p': lambda x: np.reciprocal(1 + x),
'abs': np.abs,
'invert': lambda x: - x,
'invert1p': lambda x: 1 - x
}
Expand Down

0 comments on commit 00e6280

Please sign in to comment.