Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge pull request #302 from gnes-ai/rm-benchmark-client
Browse files Browse the repository at this point in the history
refactor(client): remove benchmark client
  • Loading branch information
mergify[bot] authored Sep 29, 2019
2 parents a087626 + e588c94 commit cb4e46a
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 96 deletions.
7 changes: 0 additions & 7 deletions gnes/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ def client(args):
return _client_http(args)
elif args.client == 'cli':
return _client_cli(args)
elif args.client == 'benchmark':
return _client_bm(args)
else:
raise ValueError('gnes client must follow with a client type from {http, cli, benchmark...}\n'
'see "gnes client --help" for details')
Expand Down Expand Up @@ -94,11 +92,6 @@ def _client_cli(args):
CLIClient(args)


def _client_bm(args):
from ..client.benchmark import BenchmarkClient
BenchmarkClient(args)


def compose(args):
from ..composer.base import YamlComposer
from ..composer.flask import YamlComposerFlask
Expand Down
17 changes: 0 additions & 17 deletions gnes/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,21 +365,6 @@ def set_client_cli_parser(parser=None):
return parser


def set_client_benchmark_parser(parser=None):
if not parser:
parser = set_base_parser()
_set_grpc_parser(parser)
parser.add_argument('--batch_size', type=int, default=64,
help='the size of the request to split')
parser.add_argument('--request_length', type=int,
default=1024,
help='binary string length of each request')
parser.add_argument('--num_requests', type=int,
default=128,
help='number of total requests')
return parser


def set_client_http_parser(parser=None):
if not parser:
parser = set_base_parser()
Expand Down Expand Up @@ -422,8 +407,6 @@ def get_main_parser():
set_client_http_parser(
spp.add_parser('http', help='start a client that allows HTTP requests as input', formatter_class=adf))
set_client_cli_parser(spp.add_parser('cli', help='start a client that allows stdin as input', formatter_class=adf))
set_client_benchmark_parser(
spp.add_parser('benchmark', help='start a client for benchmark and unittest', formatter_class=adf))

# others
set_composer_flask_parser(
Expand Down
52 changes: 0 additions & 52 deletions gnes/client/benchmark.py

This file was deleted.

1 change: 1 addition & 0 deletions gnes/encoder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def encode(self, text: List[str], *args, **kwargs) -> Union[Tuple, np.ndarray]:


class BaseNumericEncoder(BaseEncoder):
"""Note that all NumericEncoder can not be used as the first encoder of the pipeline"""

def encode(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
pass
Expand Down
16 changes: 11 additions & 5 deletions gnes/encoder/text/flair.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.


from typing import List
from typing import List, Tuple

import numpy as np

Expand All @@ -25,16 +25,22 @@
class FlairEncoder(BaseTextEncoder):
is_trained = True

def __init__(self, pooling_strategy: str = 'mean', *args, **kwargs):
def __init__(self,
word_embedding: str = 'glove',
flair_embeddings: Tuple[str] = ('news-forward', 'news-backward'),
pooling_strategy: str = 'mean', *args, **kwargs):
super().__init__(*args, **kwargs)

self.word_embedding = word_embedding
self.flair_embeddings = flair_embeddings
self.pooling_strategy = pooling_strategy

def post_init(self):
from flair.embeddings import DocumentPoolEmbeddings, WordEmbeddings, FlairEmbeddings
self._flair = DocumentPoolEmbeddings(
[WordEmbeddings('glove'),
FlairEmbeddings('news-forward'),
FlairEmbeddings('news-backward')],
[WordEmbeddings(self.word_embedding),
FlairEmbeddings(self.flair_embeddings[0]),
FlairEmbeddings(self.flair_embeddings[1])],
pooling=self.pooling_strategy)

@batching
Expand Down
11 changes: 5 additions & 6 deletions tests/test_flair_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@ def setUp(self):
if line:
self.test_str.append(line)

self.flair_encoder = FlairEncoder(
model_name=os.environ.get('FLAIR_CI_MODEL'),
pooling_strategy="REDUCE_MEAN")
self.flair_encoder = FlairEncoder(model_name=os.environ.get('FLAIR_CI_MODEL'))

@unittest.SkipTest
def test_encoding(self):
vec = self.flair_encoder.encode(self.test_str)
self.assertEqual(vec.shape[0], len(self.test_str))
self.assertEqual(vec.shape[1], 512)
vec = self.flair_encoder.encode(self.test_str[:2])
print(vec.shape)
self.assertEqual(vec.shape[0], 2)
self.assertEqual(vec.shape[1], 4196)

@unittest.SkipTest
def test_dump_load(self):
Expand Down
10 changes: 1 addition & 9 deletions tests/test_stream_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

import grpc

from gnes.cli.parser import set_frontend_parser, set_router_parser, set_client_benchmark_parser
from gnes.client.benchmark import BenchmarkClient
from gnes.cli.parser import set_frontend_parser, set_router_parser
from gnes.helper import TimeContext
from gnes.proto import RequestGenerator, gnes_pb2_grpc
from gnes.service.base import SocketType, MessageHandler, BaseService as BS
Expand Down Expand Up @@ -55,13 +54,6 @@ def test_bm_frontend(self):
'--yaml_path', 'BaseRouter'
])

b_args = set_client_benchmark_parser().parse_args([
'--num_requests', '10',
'--request_length', '65536'
])
with RouterService(p_args), FrontendService(args):
BenchmarkClient(b_args)

def test_grpc_frontend(self):
args = set_frontend_parser().parse_args([
'--grpc_host', '127.0.0.1',
Expand Down

0 comments on commit cb4e46a

Please sign in to comment.