Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

fix(service): make service handler thread-safe #318

Merged
merged 4 commits into from
Oct 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gnes/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def train(self, *args, **kwargs):
def dump(self, filename: str = None) -> None:
"""
Serialize the object to a binary file
:param filename: file path of the serialized file, if not given then `self.dump_full_path` is used
:param filename: file path of the serialized file, if not given then :py:attr:`dump_full_path` is used
"""
f = filename or self.dump_full_path
if not f:
Expand All @@ -260,7 +260,7 @@ def dump(self, filename: str = None) -> None:
def dump_yaml(self, filename: str = None) -> None:
"""
Serialize the object to a yaml file
:param filename: file path of the yaml file, if not given then `self.dump_yaml_path` is used
:param filename: file path of the yaml file, if not given then :py:attr:`yaml_full_path` is used
"""
f = filename or self.yaml_full_path
if not f:
Expand Down
3 changes: 3 additions & 0 deletions gnes/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ def set_indexer_parser(parser=None):
if not parser:
parser = set_base_parser()
_set_sortable_service_parser(parser)
parser.add_argument('--as_response', type=ActionNoYes, default=True,
help='convert the message type from request to response after indexing. '
'turn it off if you want to chain other services after this index service.')

return parser

Expand Down
10 changes: 7 additions & 3 deletions gnes/flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,14 @@ class Flow:
"""
GNES Flow: an intuitive way to build workflow for GNES.

You can use `.add()` then `.build()` to customize your own workflow.
You can use :py:meth:`.add()` then :py:meth:`.build()` to customize your own workflow.
For example:

.. highlight:: python
.. code-block:: python

from gnes.flow import Flow, Service as gfs

f = (Flow(check_version=False, route_table=True)
.add(gfs.Preprocessor, yaml_path='BasePreprocessor')
.add(gfs.Encoder, yaml_path='BaseEncoder')
Expand All @@ -76,9 +78,11 @@ class Flow:
flow.index()
...

You can also use the shortcuts, e.g. :py:meth:`add_encoder`, :py:meth:`add_preprocessor`.

It is recommended to use the flow as a context manager, as shown above.
Note the different default copy behaviors in `.add()` and `.build()`:
`.add()` always copy the flow by default, whereas `.build()` modify the flow in place.
Note the different default copy behaviors in :py:meth:`.add()` and :py:meth:`.build()`:
:py:meth:`.add()` always copies the flow by default, whereas :py:meth:`.build()` modifies the flow in place.
You can change this behavior by giving an argument `copy_flow=False`.

"""
Expand Down
5 changes: 3 additions & 2 deletions gnes/indexer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from functools import wraps
from typing import List, Any, Union, Callable, Tuple
from collections import defaultdict

import numpy as np

Expand All @@ -30,7 +30,8 @@ def __init__(self,
is_big_score_similar: bool = False,
*args, **kwargs):
"""
Base indexer, a valid indexer must implement `add` and `query` methods
Base indexer, a valid indexer must implement :py:meth:`add` and :py:meth:`query` methods

:param score_fn: advanced score function
:param normalize_fn: normalizing score function
:param is_big_score_similar: when set to true, a larger score means more similar
Expand Down
15 changes: 8 additions & 7 deletions gnes/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ def register(self, msg_type: Union[List, Tuple, type]):
def decorator(f):
if isinstance(msg_type, list) or isinstance(msg_type, tuple):
for m in msg_type:
self.routes[m] = f
self.routes[m] = f.__name__
else:
self.routes[msg_type] = f
self.routes[msg_type] = f.__name__
return f

return decorator
Expand All @@ -187,11 +187,12 @@ def register_hook(self, hook_type: Union[str, Tuple[str]], only_when_verbose: bo

def decorator(f):
if isinstance(hook_type, str) and hook_type in self.hooks:
self.hooks[hook_type].append((f, only_when_verbose))
self.hooks[hook_type].append((f.__name__, only_when_verbose))
return f
elif isinstance(hook_type, list) or isinstance(hook_type, tuple):
for h in set(hook_type):
if h in self.hooks:
self.hooks[h].append((f, only_when_verbose))
self.hooks[h].append((f.__name__, only_when_verbose))
else:
raise AttributeError('hook type: %s is not supported' % h)
return f
Expand Down Expand Up @@ -222,7 +223,7 @@ def call_hooks(self, msg: 'gnes_pb2.Message', hook_type: Union[str, Tuple[str]],
for fn, only_verbose in hooks:
if (only_verbose and self.service_context.args.verbose) or (not only_verbose):
try:
fn(self.service_context, msg, *args, **kwargs)
fn(msg, *args, **kwargs)
except Exception as ex:
self.logger.warning('hook %s throws an exception, '
'this wont affect the server but you may want to pay attention' % fn)
Expand All @@ -249,7 +250,7 @@ def get_default_fn(m_type):
fn = get_default_fn(type(msg))

self.logger.info('handling message with %s' % fn.__name__)
return fn(self.service_context, msg)
return fn(msg)

def call_routes_send_back(self, msg: 'gnes_pb2.Message', out_sock):
try:
Expand Down Expand Up @@ -334,7 +335,7 @@ def __init__(self, args):
check_version=self.args.check_version,
timeout=self.args.timeout,
squeeze_pb=self.args.squeeze_pb)
# self._override_handler()
self._override_handler()

def _override_handler(self):
# replace the function name by the function itself
Expand Down
32 changes: 21 additions & 11 deletions gnes/service/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,23 @@ def post_init(self):
def _handler_index(self, msg: 'gnes_pb2.Message'):
# print('tid: %s, model: %r, self._tmp_a: %r' % (threading.get_ident(), self._model, self._tmp_a))
# if self._tmp_a != threading.get_ident():
# print('tid: %s, tmp_a: %r !!! %r' % (threading.get_ident(), self._tmp_a, self._handler_index))
# print('!!! tid: %s, tmp_a: %r %r' % (threading.get_ident(), self._tmp_a, self._handler_index))
from ..indexer.base import BaseChunkIndexer, BaseDocIndexer
if isinstance(self._model, BaseChunkIndexer):
self._handler_chunk_index(msg)
is_changed = self._handler_chunk_index(msg)
elif isinstance(self._model, BaseDocIndexer):
self._handler_doc_index(msg)
is_changed = self._handler_doc_index(msg)
else:
raise ServiceError(
'unsupported indexer, dont know how to use %s to handle this message' % self._model.__bases__)
msg.response.index.status = gnes_pb2.Response.SUCCESS
self.is_model_changed.set()

def _handler_chunk_index(self, msg: 'gnes_pb2.Message'):
if self.args.as_response:
msg.response.index.status = gnes_pb2.Response.SUCCESS

if is_changed:
self.is_model_changed.set()

def _handler_chunk_index(self, msg: 'gnes_pb2.Message') -> bool:
embed_info = []

for d in msg.request.index.docs:
Expand All @@ -59,13 +63,19 @@ def _handler_chunk_index(self, msg: 'gnes_pb2.Message'):
if embed_info:
vecs, doc_ids, offsets, weights = zip(*embed_info)
self._model.add(list(zip(doc_ids, offsets)), np.stack(vecs), weights)
return True
else:
self.logger.warning('chunks contain no embedded vectors, the indexer will do nothing')

def _handler_doc_index(self, msg: 'gnes_pb2.Message'):
self._model.add([d.doc_id for d in msg.request.index.docs],
[d for d in msg.request.index.docs],
[d.weight for d in msg.request.index.docs])
return False

def _handler_doc_index(self, msg: 'gnes_pb2.Message') -> bool:
if msg.request.index.docs:
self._model.add([d.doc_id for d in msg.request.index.docs],
[d for d in msg.request.index.docs],
[d.weight for d in msg.request.index.docs])
return True
else:
return False

def _put_result_into_message(self, results, msg: 'gnes_pb2.Message'):
msg.response.search.ClearField('topk_results')
Expand Down
21 changes: 12 additions & 9 deletions tests/test_gnes_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_flow5(self):
print(f._service_edges)
print(f.to_mermaid())

def _test_index_flow(self):
def _test_index_flow(self, backend):
for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]:
self.assertFalse(os.path.exists(k))

Expand All @@ -127,25 +127,28 @@ def _test_index_flow(self):
.add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter',
num_part=2, service_in=['vec_idx', 'doc_idx']))

with flow.build(backend='process') as f:
with flow.build(backend=backend) as f:
f.index(txt_file=self.test_file, batch_size=20)

for k in [self.indexer1_bin, self.indexer2_bin]:
self.assertTrue(os.path.exists(k))

def _test_query_flow(self):
def _test_query_flow(self, backend):
flow = (Flow(check_version=False, route_table=False)
.add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor')
.add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'))
.add(gfs.Indexer, name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml'))
.add(gfs.Router, name='scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml'))
.add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml')))

with flow.build(backend='process') as f, open(self.test_file, encoding='utf8') as fp:
f.query(bytes_gen=[v.encode() for v in fp][:10])
with flow.build(backend=backend) as f, open(self.test_file, encoding='utf8') as fp:
f.query(bytes_gen=[v.encode() for v in fp][:3])

@unittest.SkipTest
# @unittest.SkipTest
def test_index_query_flow(self):
self._test_index_flow()
print('indexing finished')
self._test_query_flow()
self._test_index_flow('thread')
self._test_query_flow('thread')

def test_indexe_query_flow_proc(self):
self._test_index_flow('process')
self._test_query_flow('process')
9 changes: 1 addition & 8 deletions tests/yaml/flow-transformer.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
!PipelineEncoder
components:
- !PyTorchTransformers
parameters:
model_dir: $TORCH_TRANSFORMERS_MODEL
model_name: bert-base-uncased
- !PoolingEncoder
parameters:
pooling_strategy: REDUCE_MEAN
backend: torch
- !CharEmbeddingEncoder {}
gnes_config:
name: my_transformer # a customized name
is_trained: true # indicate the model has been trained
Expand Down