From 1739c7b6c248e5961d67311331f723af7d9aa479 Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Wed, 9 Oct 2019 13:23:56 +0800 Subject: [PATCH] feat(flow): add client to flow --- gnes/cli/parser.py | 20 +++++++++---- gnes/encoder/text/char.py | 45 +++++++++++++++++++++++++++++ gnes/flow/__init__.py | 26 +++++++++-------- gnes/service/base.py | 15 ++++++++-- tests/test_gnes_flow.py | 50 +++++++++++++++++++++++++++++++++ tests/yaml/flow-dictindex.yml | 4 +++ tests/yaml/flow-score.yml | 3 ++ tests/yaml/flow-transformer.yml | 15 ++++++++++ tests/yaml/flow-vecindex.yml | 13 +++++++++ 9 files changed, 172 insertions(+), 19 deletions(-) create mode 100644 gnes/encoder/text/char.py create mode 100644 tests/yaml/flow-dictindex.yml create mode 100644 tests/yaml/flow-score.yml create mode 100644 tests/yaml/flow-transformer.yml create mode 100644 tests/yaml/flow-vecindex.yml diff --git a/gnes/cli/parser.py b/gnes/cli/parser.py index 11604305..65ee0b28 100644 --- a/gnes/cli/parser.py +++ b/gnes/cli/parser.py @@ -47,6 +47,15 @@ def resolve_py_path(path): return path +def random_port(port): + if not port or int(port) <= 0: + import random + min_port, max_port = 49152, 65536 + return random.randrange(min_port, max_port) + else: + return int(port) + + def resolve_yaml_path(path): # priority, filepath > classname > default import os @@ -139,14 +148,14 @@ def set_composer_flask_parser(parser=None): def set_service_parser(parser=None): from ..service.base import SocketType, BaseService, ParallelType - import random + import os if not parser: parser = set_base_parser() - min_port, max_port = 49152, 65536 - parser.add_argument('--port_in', type=int, default=random.randrange(min_port, max_port), + + parser.add_argument('--port_in', type=int, default=random_port(-1), help='port for input data, default a random port between [49152, 65536]') - parser.add_argument('--port_out', type=int, default=random.randrange(min_port, max_port), + parser.add_argument('--port_out', type=int, default=random_port(-1), help='port for output data, default a random port between [49152, 65536]') parser.add_argument('--host_in', type=str, default=BaseService.default_host, help='host address for input') @@ -158,8 +167,7 @@ def set_service_parser(parser=None): parser.add_argument('--socket_out', type=SocketType.from_string, choices=list(SocketType), default=SocketType.PUSH_BIND, help='socket type for output port') - parser.add_argument('--port_ctrl', type=int, - default=int(os.environ.get('GNES_CONTROL_PORT', random.randrange(min_port, max_port))), + parser.add_argument('--port_ctrl', type=int, default=os.environ.get('GNES_CONTROL_PORT', random_port(-1)), help='port for controlling the service, default a random port between [49152, 65536]') parser.add_argument('--timeout', type=int, default=-1, help='timeout (ms) of all communication, -1 for waiting forever') diff --git a/gnes/encoder/text/char.py b/gnes/encoder/text/char.py new file mode 100644 index 00000000..5db73772 --- /dev/null +++ b/gnes/encoder/text/char.py @@ -0,0 +1,45 @@ +# Tencent is pleased to support the open source community by making GNES available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List + +import numpy as np + +from ..base import BaseTextEncoder +from ...helper import batching, as_numpy_array + + +class CharEmbeddingEncoder(BaseTextEncoder): + """A random character embedding model. Only useful for testing""" + is_trained = True + + def __init__(self, dim: int = 128, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dim = dim + self.offset = 32 + self.unknown_idx = 96 + # in total 96 printable chars and 2 special chars = 98 + self._char_embedding = np.random.random([97, dim]) + + @batching + @as_numpy_array + def encode(self, text: List[str], *args, **kwargs) -> List[np.ndarray]: + # tokenize text + sent_embed = [] + for sent in text: + ids = [ord(c) - 32 if 32 <= ord(c) <= 127 else self.unknown_idx for c in sent] + sent_embed.append(np.mean(self._char_embedding[ids], axis=0)) + return sent_embed diff --git a/gnes/flow/__init__.py b/gnes/flow/__init__.py index 5c15f5ce..95c2de55 100644 --- a/gnes/flow/__init__.py +++ b/gnes/flow/__init__.py @@ -45,11 +45,11 @@ def arg_wrapper(self, *args, **kwargs): class Flow: _supported_orch = {'swarm', 'k8s'} _service2parser = { - Service.Encoder: set_encoder_parser(), - Service.Router: set_router_parser(), - Service.Indexer: set_indexer_parser(), - Service.Frontend: set_frontend_parser(), - Service.Preprocessor: set_preprocessor_parser(), + Service.Encoder: set_encoder_parser, + Service.Router: set_router_parser, + Service.Indexer: set_indexer_parser, + Service.Frontend: set_frontend_parser, + Service.Preprocessor: set_preprocessor_parser, } _service2builder = { Service.Encoder: lambda x: ServiceManager(EncoderService, x), @@ -137,14 +137,13 @@ def query(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs): @_build_level(BuildLevel.RUNTIME) def _call_client(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs): - args, p_args = self._get_parsed_args(set_client_cli_parser(), kwargs) + args, p_args = self._get_parsed_args(set_client_cli_parser, kwargs) p_args.grpc_port = self._service_nodes[self._frontend]['parsed_args'].grpc_port p_args.grpc_host = self._service_nodes[self._frontend]['parsed_args'].grpc_host c = CLIClient(p_args, start_at_init=False) if bytes_gen: c.bytes_generator = bytes_gen - with c: - pass + c.start() def add_frontend(self, *args, **kwargs) -> 'Flow': """Add a frontend to the current flow, a shortcut of add(Service.Frontend) @@ -222,6 +221,10 @@ def add(self, service: 'Service', self._last_add_service = name + # graph is now changed so we need to + # reset the build level to the lowest + self._build_level = Flow.BuildLevel.EMPTY + return self def _parse_service_endpoints(self, cur_service_name, service_endpoint, connect_to_last_service=False): @@ -263,7 +266,7 @@ def _get_parsed_args(self, service_arg_parser, kwargs): else: args.extend(['--%s' % k, str(v)]) try: - p_args, unknown_args = service_arg_parser.parse_known_args(args) + p_args, unknown_args = service_arg_parser().parse_known_args(args) if unknown_args: self.logger.warning('not sure what these arguments are: %s' % unknown_args) except SystemExit: @@ -329,7 +332,7 @@ def _build_graph(self) 
-> 'Flow': s_pargs.socket_out = SocketType.PUSH_CONNECT e_pargs.socket_in = SocketType.PULL_BIND else: - e_pargs.socket_in = s_pargs.socket_out.complement + e_pargs.socket_in = s_pargs.socket_out.paired if s_pargs.socket_out.is_bind: s_pargs.host_out = BaseService.default_host @@ -386,4 +389,5 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): if hasattr(self, '_service_stack'): - self._service_stack.close() \ No newline at end of file + self._service_stack.close() + self._build_level = Flow.BuildLevel.GRAPH diff --git a/gnes/service/base.py b/gnes/service/base.py index 82e81206..cb5361a7 100644 --- a/gnes/service/base.py +++ b/gnes/service/base.py @@ -82,8 +82,19 @@ def is_bind(self): return self.value % 2 == 0 @property - def complement(self): - return SocketType(self.value + (1 if (self.value % 2 == 0) else -1)) + def paired(self): + return { + SocketType.PULL_BIND: SocketType.PUSH_CONNECT, + SocketType.PULL_CONNECT: SocketType.PUSH_BIND, + SocketType.SUB_BIND: SocketType.PUB_CONNECT, + SocketType.SUB_CONNECT: SocketType.PUB_BIND, + SocketType.PAIR_BIND: SocketType.PAIR_CONNECT, + SocketType.PUSH_CONNECT: SocketType.PULL_BIND, + SocketType.PUSH_BIND: SocketType.PULL_CONNECT, + SocketType.PUB_CONNECT: SocketType.SUB_BIND, + SocketType.PUB_BIND: SocketType.SUB_CONNECT, + SocketType.PAIR_CONNECT: SocketType.PAIR_BIND + }[self] class BlockMessage(Exception): diff --git a/tests/test_gnes_flow.py b/tests/test_gnes_flow.py index ab797fbf..49ddd4ff 100644 --- a/tests/test_gnes_flow.py +++ b/tests/test_gnes_flow.py @@ -10,6 +10,7 @@ class TestGNESFlow(unittest.TestCase): def setUp(self): self.dirname = os.path.dirname(__file__) self.test_file = os.path.join(self.dirname, 'sonnets_small.txt') + self.yamldir = os.path.join(self.dirname, 'yaml') self.index_args = set_client_cli_parser().parse_args([ '--mode', 'index', '--txt_file', self.test_file, @@ -17,6 +18,20 @@ def setUp(self): ]) os.unsetenv('http_proxy') os.unsetenv('https_proxy') + self.test_dir = os.path.join(self.dirname, 'test_flow') + self.indexer1_bin = os.path.join(self.test_dir, 'my_faiss_indexer.bin') + self.indexer2_bin = os.path.join(self.test_dir, 'my_fulltext_indexer.bin') + self.encoder_bin = os.path.join(self.test_dir, 'my_transformer.bin') + + os.mkdir(self.test_dir) + + os.environ['TEST_WORKDIR'] = self.test_dir + + def tearDown(self): + for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]: + if os.path.exists(k): + os.remove(k) + os.rmdir(self.test_dir) def test_flow1(self): f = (Flow(check_version=False, route_table=True) @@ -80,3 +95,38 @@ def test_flow5(self): .build(backend=None)) print(f._service_edges) print(f.to_mermaid()) + + def _test_index_flow(self): + for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]: + self.assertFalse(os.path.exists(k)) + + flow = (Flow(check_version=False, route_table=True) + .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor') + .add(gfs.Encoder, yaml_path='yaml/flow-transformer.yml') + .add(gfs.Indexer, name='vec_idx', yaml_path='yaml/flow-vecindex.yml') + .add(gfs.Indexer, name='doc_idx', yaml_path='yaml/flow-dictindex.yml', + service_in='prep') + .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter', + num_part=2, service_in=['vec_idx', 'doc_idx'])) + + with flow.build(backend='thread') as f: + f.index(txt_file=self.test_file, batch_size=4) + + for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]: + self.assertTrue(os.path.exists(k)) + + def _test_query_flow(self): + flow = 
(Flow(check_version=False, route_table=True) + .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor') + .add(gfs.Encoder, yaml_path='yaml/flow-transformer.yml') + .add(gfs.Indexer, name='vec_idx', yaml_path='yaml/flow-vecindex.yml') + .add(gfs.Router, name='scorer', yaml_path='yaml/flow-score.yml') + .add(gfs.Indexer, name='doc_idx', yaml_path='yaml/flow-dictindex.yml')) + + with flow.build(backend='thread') as f: + f.query(txt_file=self.test_file) + + def test_index_query_flow(self): + self._test_index_flow() + print('indexing finished') + self._test_query_flow() diff --git a/tests/yaml/flow-dictindex.yml b/tests/yaml/flow-dictindex.yml new file mode 100644 index 00000000..943be51c --- /dev/null +++ b/tests/yaml/flow-dictindex.yml @@ -0,0 +1,4 @@ +!DictIndexer +gnes_config: + name: my_fulltext_indexer # a customized name + work_dir: $TEST_WORKDIR \ No newline at end of file diff --git a/tests/yaml/flow-score.yml b/tests/yaml/flow-score.yml new file mode 100644 index 00000000..9a982db1 --- /dev/null +++ b/tests/yaml/flow-score.yml @@ -0,0 +1,3 @@ +!Chunk2DocTopkReducer +parameters: + reduce_op: avg \ No newline at end of file diff --git a/tests/yaml/flow-transformer.yml b/tests/yaml/flow-transformer.yml new file mode 100644 index 00000000..f32b1d56 --- /dev/null +++ b/tests/yaml/flow-transformer.yml @@ -0,0 +1,15 @@ +!PipelineEncoder +components: + - !PyTorchTransformers + parameters: + model_dir: $TORCH_TRANSFORMERS_MODEL + model_name: bert-base-uncased + - !PoolingEncoder + parameters: + pooling_strategy: REDUCE_MEAN + backend: torch +gnes_config: + name: my_transformer # a customized name + is_trained: true # indicate the model has been trained + work_dir: $TEST_WORKDIR + batch_size: 128 \ No newline at end of file diff --git a/tests/yaml/flow-vecindex.yml b/tests/yaml/flow-vecindex.yml new file mode 100644 index 00000000..f2183ac3 --- /dev/null +++ b/tests/yaml/flow-vecindex.yml @@ -0,0 +1,13 @@ +!NumpyIndexer # just for testing +gnes_config: + name: my_faiss_indexer # a customized name + work_dir: $TEST_WORKDIR + +#!FaissIndexer +#parameters: +# num_dim: -1 # delay the spec on num_dim on first add +# index_key: HNSW32 +# data_path: /workspace/idx.faiss +#gnes_config: +# name: my_faiss_indexer # a customized name +# work_dir: /workspace \ No newline at end of file
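
For reference, a minimal usage sketch of the client entry points this patch wires into Flow, adapted from the index flow in tests/test_gnes_flow.py above. The import line and the corpus path are assumptions not shown in the diff; the yaml_path values refer to the test fixtures added by this patch.

# a minimal sketch, assuming Flow and the Service enum are exposed by gnes.flow
# the way tests/test_gnes_flow.py uses them (the exact import is not shown in the diff)
from gnes.flow import Flow, Service as gfs

flow = (Flow(check_version=False, route_table=True)
        .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor')
        .add(gfs.Encoder, yaml_path='yaml/flow-transformer.yml')
        .add(gfs.Indexer, name='vec_idx', yaml_path='yaml/flow-vecindex.yml')
        .add(gfs.Indexer, name='doc_idx', yaml_path='yaml/flow-dictindex.yml',
             service_in='prep')
        .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter',
             num_part=2, service_in=['vec_idx', 'doc_idx']))

# index() goes through the CLIClient started by _call_client(), which now
# points grpc_host/grpc_port at the frontend's parsed args
with flow.build(backend='thread') as f:
    f.index(txt_file='sonnets_small.txt', batch_size=4)  # corpus path is a test fixture, used here as a placeholder

# querying follows the same pattern: the test builds a second flow that swaps in a
# 'scorer' router (yaml/flow-score.yml) and then calls f.query(txt_file=...)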