Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
feat(flow): add client to flow
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Oct 9, 2019
1 parent 990593b commit 1739c7b
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 19 deletions.
20 changes: 14 additions & 6 deletions gnes/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ def resolve_py_path(path):
return path


def random_port(port):
    """Coerce *port* to an int, or pick a random ephemeral port when it is
    unset (falsy) or non-positive.
    """
    if port and int(port) > 0:
        return int(port)
    import random
    # IANA dynamic/private port range is 49152..65535
    # (randrange's upper bound is exclusive, so 65536 is never returned)
    return random.randrange(49152, 65536)


def resolve_yaml_path(path):
# priority, filepath > classname > default
import os
Expand Down Expand Up @@ -139,14 +148,14 @@ def set_composer_flask_parser(parser=None):

def set_service_parser(parser=None):
from ..service.base import SocketType, BaseService, ParallelType
import random

import os
if not parser:
parser = set_base_parser()
min_port, max_port = 49152, 65536
parser.add_argument('--port_in', type=int, default=random.randrange(min_port, max_port),

parser.add_argument('--port_in', type=int, default=random_port(-1),
help='port for input data, default a random port between [49152, 65536]')
parser.add_argument('--port_out', type=int, default=random.randrange(min_port, max_port),
parser.add_argument('--port_out', type=int, default=random_port(-1),
help='port for output data, default a random port between [49152, 65536]')
parser.add_argument('--host_in', type=str, default=BaseService.default_host,
help='host address for input')
Expand All @@ -158,8 +167,7 @@ def set_service_parser(parser=None):
parser.add_argument('--socket_out', type=SocketType.from_string, choices=list(SocketType),
default=SocketType.PUSH_BIND,
help='socket type for output port')
parser.add_argument('--port_ctrl', type=int,
default=int(os.environ.get('GNES_CONTROL_PORT', random.randrange(min_port, max_port))),
parser.add_argument('--port_ctrl', type=int, default=os.environ.get('GNES_CONTROL_PORT', random_port(-1)),
help='port for controlling the service, default a random port between [49152, 65536]')
parser.add_argument('--timeout', type=int, default=-1,
help='timeout (ms) of all communication, -1 for waiting forever')
Expand Down
45 changes: 45 additions & 0 deletions gnes/encoder/text/char.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List

import numpy as np

from ..base import BaseTextEncoder
from ...helper import batching, as_numpy_array


class CharEmbeddingEncoder(BaseTextEncoder):
    """A random character embedding model. Only useful for testing.

    Each character with code point in [32, 127] is mapped to a row of a
    randomly-initialized embedding table; all other characters share one
    "unknown" row. A sentence is encoded as the mean of its character
    embeddings.
    """
    is_trained = True

    def __init__(self, dim: int = 128, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dim = dim  # dimensionality of each character embedding
        self.offset = 32  # first mapped code point (space)
        self.unknown_idx = 96  # row reserved for out-of-range characters
        # 96 mapped code points (32..127 inclusive) + 1 unknown bucket = 97 rows
        # (the original comment claimed 98, which did not match the table size)
        self._char_embedding = np.random.random([97, dim])

    @batching
    @as_numpy_array
    def encode(self, text: List[str], *args, **kwargs) -> List[np.ndarray]:
        """Encode each sentence in *text* as the mean of its char embeddings."""
        sent_embed = []
        for sent in text:
            ids = [ord(c) - self.offset if 32 <= ord(c) <= 127 else self.unknown_idx
                   for c in sent]
            # guard: np.mean over an empty selection yields NaN (with a
            # RuntimeWarning); represent an empty sentence by the unknown row
            if not ids:
                ids = [self.unknown_idx]
            sent_embed.append(np.mean(self._char_embedding[ids], axis=0))
        return sent_embed
26 changes: 15 additions & 11 deletions gnes/flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def arg_wrapper(self, *args, **kwargs):
class Flow:
_supported_orch = {'swarm', 'k8s'}
_service2parser = {
Service.Encoder: set_encoder_parser(),
Service.Router: set_router_parser(),
Service.Indexer: set_indexer_parser(),
Service.Frontend: set_frontend_parser(),
Service.Preprocessor: set_preprocessor_parser(),
Service.Encoder: set_encoder_parser,
Service.Router: set_router_parser,
Service.Indexer: set_indexer_parser,
Service.Frontend: set_frontend_parser,
Service.Preprocessor: set_preprocessor_parser,
}
_service2builder = {
Service.Encoder: lambda x: ServiceManager(EncoderService, x),
Expand Down Expand Up @@ -137,14 +137,13 @@ def query(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs):

@_build_level(BuildLevel.RUNTIME)
def _call_client(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs):
args, p_args = self._get_parsed_args(set_client_cli_parser(), kwargs)
args, p_args = self._get_parsed_args(set_client_cli_parser, kwargs)
p_args.grpc_port = self._service_nodes[self._frontend]['parsed_args'].grpc_port
p_args.grpc_host = self._service_nodes[self._frontend]['parsed_args'].grpc_host
c = CLIClient(p_args, start_at_init=False)
if bytes_gen:
c.bytes_generator = bytes_gen
with c:
pass
c.start()

def add_frontend(self, *args, **kwargs) -> 'Flow':
"""Add a frontend to the current flow, a shortcut of add(Service.Frontend)
Expand Down Expand Up @@ -222,6 +221,10 @@ def add(self, service: 'Service',

self._last_add_service = name

# graph is now changed so we need to
# reset the build level to the lowest
self._build_level = Flow.BuildLevel.EMPTY

return self

def _parse_service_endpoints(self, cur_service_name, service_endpoint, connect_to_last_service=False):
Expand Down Expand Up @@ -263,7 +266,7 @@ def _get_parsed_args(self, service_arg_parser, kwargs):
else:
args.extend(['--%s' % k, str(v)])
try:
p_args, unknown_args = service_arg_parser.parse_known_args(args)
p_args, unknown_args = service_arg_parser().parse_known_args(args)
if unknown_args:
self.logger.warning('not sure what these arguments are: %s' % unknown_args)
except SystemExit:
Expand Down Expand Up @@ -329,7 +332,7 @@ def _build_graph(self) -> 'Flow':
s_pargs.socket_out = SocketType.PUSH_CONNECT
e_pargs.socket_in = SocketType.PULL_BIND
else:
e_pargs.socket_in = s_pargs.socket_out.complement
e_pargs.socket_in = s_pargs.socket_out.paired

if s_pargs.socket_out.is_bind:
s_pargs.host_out = BaseService.default_host
Expand Down Expand Up @@ -386,4 +389,5 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb):
if hasattr(self, '_service_stack'):
self._service_stack.close()
self._service_stack.close()
self._build_level = Flow.BuildLevel.GRAPH
15 changes: 13 additions & 2 deletions gnes/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,19 @@ def is_bind(self):
return self.value % 2 == 0

@property
def complement(self):
return SocketType(self.value + (1 if (self.value % 2 == 0) else -1))
def paired(self):
    """The complementary socket type that this one connects with."""
    forward = {
        SocketType.PULL_BIND: SocketType.PUSH_CONNECT,
        SocketType.PULL_CONNECT: SocketType.PUSH_BIND,
        SocketType.SUB_BIND: SocketType.PUB_CONNECT,
        SocketType.SUB_CONNECT: SocketType.PUB_BIND,
        SocketType.PAIR_BIND: SocketType.PAIR_CONNECT,
    }
    # the pairing relation is symmetric, so the other half of the table
    # is just the inverse of the half spelled out above
    backward = {v: k for k, v in forward.items()}
    return {**forward, **backward}[self]


class BlockMessage(Exception):
Expand Down
50 changes: 50 additions & 0 deletions tests/test_gnes_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,28 @@ class TestGNESFlow(unittest.TestCase):
def setUp(self):
    """Prepare CLI args, a scratch work dir, and a proxy-free environment."""
    self.dirname = os.path.dirname(__file__)
    self.test_file = os.path.join(self.dirname, 'sonnets_small.txt')
    self.yamldir = os.path.join(self.dirname, 'yaml')
    self.index_args = set_client_cli_parser().parse_args([
        '--mode', 'index',
        '--txt_file', self.test_file,
        '--batch_size', '4'
    ])
    # os.unsetenv does NOT update os.environ, so the proxy settings would
    # still be visible to this process; pop them from os.environ instead
    # (which also unsets them in the underlying environment)
    os.environ.pop('http_proxy', None)
    os.environ.pop('https_proxy', None)
    self.test_dir = os.path.join(self.dirname, 'test_flow')
    self.indexer1_bin = os.path.join(self.test_dir, 'my_faiss_indexer.bin')
    self.indexer2_bin = os.path.join(self.test_dir, 'my_fulltext_indexer.bin')
    self.encoder_bin = os.path.join(self.test_dir, 'my_transformer.bin')

    # exist_ok tolerates a leftover directory from a previously crashed run,
    # which would otherwise make every test fail with FileExistsError
    os.makedirs(self.test_dir, exist_ok=True)

    # consumed by the yaml fixtures via their $TEST_WORKDIR placeholders
    os.environ['TEST_WORKDIR'] = self.test_dir

def tearDown(self):
for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]:
if os.path.exists(k):
os.remove(k)
os.rmdir(self.test_dir)

def test_flow1(self):
f = (Flow(check_version=False, route_table=True)
Expand Down Expand Up @@ -80,3 +95,38 @@ def test_flow5(self):
.build(backend=None))
print(f._service_edges)
print(f.to_mermaid())

def _test_index_flow(self):
    """Build a two-branch index flow (vector + full-text) and run indexing.

    Verifies that the serialized artifacts do not exist beforehand and do
    exist afterwards, i.e. that every service dumped its state on exit.
    """
    # precondition: a previous run must not have left artifacts behind
    for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]:
        self.assertFalse(os.path.exists(k))

    # prep fans out to the encoder->vec_idx branch and directly to doc_idx;
    # sync_barrier waits for both branches (num_part=2) before replying
    flow = (Flow(check_version=False, route_table=True)
            .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor')
            .add(gfs.Encoder, yaml_path='yaml/flow-transformer.yml')
            .add(gfs.Indexer, name='vec_idx', yaml_path='yaml/flow-vecindex.yml')
            .add(gfs.Indexer, name='doc_idx', yaml_path='yaml/flow-dictindex.yml',
                 service_in='prep')
            .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter',
                 num_part=2, service_in=['vec_idx', 'doc_idx']))

    # thread backend keeps everything in-process for the test
    with flow.build(backend='thread') as f:
        f.index(txt_file=self.test_file, batch_size=4)

    # postcondition: exiting the context presumably triggers each service's
    # dump into $TEST_WORKDIR — confirmed indirectly by these existence checks
    for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]:
        self.assertTrue(os.path.exists(k))

def _test_query_flow(self):
    """Build a linear query flow and run a query over the test file.

    Reuses the artifacts written by ``_test_index_flow`` (same yaml configs,
    same $TEST_WORKDIR), so it must run after indexing has completed.
    """
    # linear pipeline: preprocess -> encode -> vector search -> score ->
    # fetch full documents from the full-text index
    flow = (Flow(check_version=False, route_table=True)
            .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor')
            .add(gfs.Encoder, yaml_path='yaml/flow-transformer.yml')
            .add(gfs.Indexer, name='vec_idx', yaml_path='yaml/flow-vecindex.yml')
            .add(gfs.Router, name='scorer', yaml_path='yaml/flow-score.yml')
            .add(gfs.Indexer, name='doc_idx', yaml_path='yaml/flow-dictindex.yml'))

    # no asserts here: success means the flow builds and the query completes
    with flow.build(backend='thread') as f:
        f.query(txt_file=self.test_file)

def test_index_query_flow(self):
    """Index first, then query — order matters, so both run in one test."""
    self._test_index_flow()
    print('indexing finished')
    self._test_query_flow()
4 changes: 4 additions & 0 deletions tests/yaml/flow-dictindex.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# full-text (document-level) indexer used by the flow tests;
# serialized as my_fulltext_indexer.bin under the test work dir
!DictIndexer
gnes_config:
  name: my_fulltext_indexer # a customized name
  work_dir: $TEST_WORKDIR # presumably expanded from the env var set by the test — confirm against the yaml loader
3 changes: 3 additions & 0 deletions tests/yaml/flow-score.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# scorer for the query flow: reduces chunk-level top-k results to
# document-level scores, averaging across chunks
!Chunk2DocTopkReducer
parameters:
  reduce_op: avg
15 changes: 15 additions & 0 deletions tests/yaml/flow-transformer.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# encoder pipeline for the flow tests: transformer token embeddings
# followed by mean-pooling into a single sentence vector
!PipelineEncoder
components:
  - !PyTorchTransformers
    parameters:
      model_dir: $TORCH_TRANSFORMERS_MODEL # model location comes from the environment
      model_name: bert-base-uncased
  - !PoolingEncoder
    parameters:
      pooling_strategy: REDUCE_MEAN # average over token embeddings
      backend: torch
gnes_config:
  name: my_transformer # a customized name
  is_trained: true # indicate the model has been trained
  work_dir: $TEST_WORKDIR
  batch_size: 128
13 changes: 13 additions & 0 deletions tests/yaml/flow-vecindex.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# vector indexer for the flow tests: brute-force numpy search,
# serialized as my_faiss_indexer.bin under the test work dir
!NumpyIndexer # just for testing
gnes_config:
  name: my_faiss_indexer # a customized name
  work_dir: $TEST_WORKDIR

# alternative config kept for reference: a real FaissIndexer (HNSW),
# not used by the tests
#!FaissIndexer
#parameters:
#  num_dim: -1 # delay the spec on num_dim on first add
#  index_key: HNSW32
#  data_path: /workspace/idx.faiss
#gnes_config:
#  name: my_faiss_indexer # a customized name
#  work_dir: /workspace

0 comments on commit 1739c7b

Please sign in to comment.