Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge pull request #227 from gnes-ai/feat-pre-post-hooks
Browse files Browse the repository at this point in the history
feat(service): add pre and post hooks to baseservice
  • Loading branch information
mergify[bot] authored Sep 6, 2019
2 parents fa6ab54 + a9d8279 commit 1b70fc7
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 47 deletions.
149 changes: 102 additions & 47 deletions gnes/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ def build_socket(ctx: 'zmq.Context', host: str, port: int, socket_type: 'SocketT
class MessageHandler:
def __init__(self, mh: 'MessageHandler' = None):
    """
    Create a handler, optionally inheriting routes and hooks from ``mh``.

    :param mh: an existing handler whose routes and hooks are copied into
        this one; when omitted the handler starts with no routes and with
        empty 'pre' and 'post' hook lists
    """
    self.routes = dict(mh.routes) if mh else {}
    # legacy flat hook list, still created for code that reads ``hook_fns``
    self.hook_fns = []
    # copy every hook list and not only the dict: sharing the list objects
    # would make a hook registered on this handler show up in ``mh`` as well
    self.hooks = {k: list(v) for k, v in mh.hooks.items()} if mh else {'pre': [], 'post': []}
    self.logger = set_logger(self.__class__.__name__)
    # object passed as first argument to every route and hook function;
    # stays None until a service assigns itself here
    self.service_context = None

def register(self, msg_type: Union[List, Tuple, type]):
def decorator(f):
Expand All @@ -149,7 +150,53 @@ def decorator(f):

return decorator

def get_serve_fn(self, msg: 'gnes_pb2.Message'):
def register_hook(self, hook_type: Union[str, Tuple[str]], only_when_verbose: bool = False):
    """
    Register a function as a pre/post hook

    The errors below are raised when the returned decorator is applied to a
    function, not when ``register_hook`` itself is called.

    :param hook_type: possible values 'pre' or 'post' or ('pre', 'post')
    :param only_when_verbose: only call the hook when verbose is true
    :return: a decorator that records the function and returns it unchanged
    :raises TypeError: if ``hook_type`` is neither a supported string nor a list/tuple
    :raises AttributeError: if a list/tuple ``hook_type`` contains an unsupported name
    """

    def decorator(f):
        if isinstance(hook_type, str) and hook_type in self.hooks:
            targets = [hook_type]
        elif isinstance(hook_type, (list, tuple)):
            # drop duplicates but keep the order in which the types were given
            targets = list(dict.fromkeys(hook_type))
            # validate everything first so a bad name never leaves the
            # function registered for only some of the requested types
            for h in targets:
                if h not in self.hooks:
                    raise AttributeError('hook type: %s is not supported' % h)
        else:
            raise TypeError('hook_type is in bad type: %s' % type(hook_type))

        for h in targets:
            self.hooks[h].append((f, only_when_verbose))
        # always hand the function back, otherwise the decorated name is
        # rebound to None
        return f

    return decorator

def call_hooks(self, msg: 'gnes_pb2.Message', hook_type: Union[str, Tuple[str]], *args, **kwargs):
    """
    All post handler hooks are called after the handler is done but before
    sending out the message to the next service.
    All pre handler hooks are called after the service received a message
    and before calling the message handler

    :param msg: the message handed to every hook
    :param hook_type: 'pre', 'post', or a list/tuple of them; for a
        list/tuple the hook groups run in the order given, duplicates ignored
    :param args: extra positional arguments forwarded to every hook
    :param kwargs: extra keyword arguments forwarded to every hook
    :raises TypeError: if ``hook_type`` is neither a supported string nor a list/tuple
    :raises AttributeError: if a list/tuple ``hook_type`` contains an unsupported name
    """
    if isinstance(hook_type, str) and hook_type in self.hooks:
        selected = list(self.hooks[hook_type])
    elif isinstance(hook_type, (list, tuple)):
        selected = []
        # dict.fromkeys de-duplicates while keeping the caller's order, so
        # the groups run in a deterministic sequence (a set would not)
        for h in dict.fromkeys(hook_type):
            if h not in self.hooks:
                raise AttributeError('hook type: %s is not supported' % h)
            selected.extend(self.hooks[h])
    else:
        raise TypeError('hook_type is in bad type: %s' % type(hook_type))

    for fn, only_verbose in selected:
        # the verbose flag of the service is consulted only for hooks that
        # were registered as verbose-only
        if not only_verbose or self.service_context.args.verbose:
            fn(self.service_context, msg, *args, **kwargs)

def call_routes(self, msg: 'gnes_pb2.Message'):
def get_default_fn(m_type):
self.logger.warning('cant find handler for message type: %s, fall back to the default handler' % m_type)
f = self.routes.get(m_type, self.routes[NotImplementedError])
Expand All @@ -168,7 +215,32 @@ def get_default_fn(m_type):
fn = get_default_fn(type(body))
else:
fn = get_default_fn(type(msg))
return fn

self.logger.info('handling message with %s' % fn.__name__)
return fn(self.service_context, msg)

def call_routes_send_back(self, msg: 'gnes_pb2.Message', out_sock):
    """
    Run the route handler for ``msg`` and send the outcome to ``out_sock``.

    A handler returning ``None`` is assumed to have modified ``msg`` in
    place, so ``msg`` itself is sent. A handler returning a generator has
    every yielded message sent in turn. The 'post' hooks are called on each
    outgoing message right before it is sent.

    :param msg: the received message
    :param out_sock: the socket the result is sent to
    :raises EventLoopEnd: re-raised after ``msg`` has been forwarded
    """

    def _post_and_send(out_msg):
        ctx_args = self.service_context.args
        self.call_hooks(out_msg, hook_type='post', verbose=ctx_args.verbose)
        send_message(out_sock, out_msg, timeout=ctx_args.timeout)

    try:
        # NOTE that msg is mutable object, it may be modified in fn()
        ret = self.call_routes(msg)
        if ret is None:
            # assume 'msg' is modified inside fn()
            _post_and_send(msg)
        elif isinstance(ret, types.GeneratorType):
            for r_msg in ret:
                # run the post hooks on the message that is actually sent,
                # not on the incoming one
                _post_and_send(r_msg)
        else:
            raise ServiceError('unknown return type from the handler')

    except BlockMessage:
        # nothing is sent for this message
        pass
    except EventLoopEnd:
        send_message(out_sock, msg, timeout=self.service_context.args.timeout)
        # bare raise keeps the original exception instance and traceback
        raise
    except ServiceError as ex:
        self.logger.error(ex, exc_info=True)


class ConcurrentService(type):
Expand Down Expand Up @@ -225,7 +297,6 @@ def __init__(self, args):
self.identity = args.identity if 'identity' in args else None
self.use_event_loop = True
self.ctrl_addr = 'tcp://%s:%d' % (self.default_host, self.args.port_ctrl)
self.handler.hook_fns.extend([self._hook_warn_body_type_change, self._hook_sort_response])

def run(self):
try:
Expand Down Expand Up @@ -259,16 +330,13 @@ def dump(self):
else:
self.logger.info('no dumping as "read_only" set to true.')

def post_handler(self, msg: 'gnes_pb2.Message', *args, **kwargs):
for fn in self.handler.hook_fns:
fn(msg, *args, **kwargs)
self.logger.info('hook handler %s is done' % fn.__name__)

def _hook_warn_body_type_change(self, msg: 'gnes_pb2.Message', old_body_type: str, *args, **kwargs):
@handler.register_hook(hook_type='post')
def _hook_warn_body_type_change(self, msg: 'gnes_pb2.Message', *args, **kwargs):
new_type = msg.WhichOneof('body')
if new_type != old_body_type:
self.logger.warning('message body is changed from %s to %s' % (old_body_type, new_type))
if new_type != self._msg_old_type:
self.logger.warning('message body type has changed from %s to %s' % (self._msg_old_type, new_type))

@handler.register_hook(hook_type='post')
def _hook_sort_response(self, msg: 'gnes_pb2.Message', *args, **kwargs):
if 'sorted_response' in self.args and self.args.sorted_response and msg.response.search.topk_results:
msg.response.search.topk_results.sort(key=lambda x: x.score.value,
Expand All @@ -279,44 +347,16 @@ def _hook_sort_response(self, msg: 'gnes_pb2.Message', *args, **kwargs):
(len(msg.response.search.topk_results),
'descending' if msg.response.search.is_big_score_similar else 'ascending'))

def message_handler(self, msg: 'gnes_pb2.Message', out_sck, ctrl_sck):
try:
fn = self.handler.get_serve_fn(msg)
if fn:
add_route(msg.envelope, self._model.__class__.__name__)
self.logger.info('handling a message with route: %s using handler %s' % (router2str(msg), fn.__name__))
old_type = msg.WhichOneof('body')
if msg.request and msg.request.WhichOneof('body') and \
type(getattr(msg.request, msg.request.WhichOneof('body'))) == gnes_pb2.Request.ControlRequest:
out_sock = ctrl_sck
else:
out_sock = out_sck
try:
# NOTE that msg is mutable object, it may be modified in fn()
ret = fn(self, msg)
self.logger.info('handler %s is done' % fn.__name__)
if ret is None:
# assume 'msg' is modified inside fn()
self.post_handler(msg, old_body_type=old_type)
send_message(out_sock, msg, timeout=self.args.timeout)
elif isinstance(ret, types.GeneratorType):
for r_msg in ret:
self.post_handler(msg, old_body_type=old_type)
send_message(out_sock, r_msg, timeout=self.args.timeout)
else:
raise ServiceError('unknown return type from the handler: %s' % fn)

except BlockMessage:
pass
except EventLoopEnd:
send_message(out_sock, msg, timeout=self.args.timeout)
raise EventLoopEnd
except ServiceError as ex:
self.logger.error(ex, exc_info=True)
@handler.register_hook(hook_type='pre')
def _hook_add_route(self, msg: 'gnes_pb2.Message', *args, **kwargs):
    """
    Pre hook: add the model's class name to the message route, remember the
    current body type in ``self._msg_old_type`` and log both.

    :param msg: the received message; its envelope is passed to ``add_route``
    """
    model_name = self._model.__class__.__name__
    add_route(msg.envelope, model_name)
    # kept on the instance so that it can be compared with the body type
    # found after the message has been handled
    body_type = msg.WhichOneof('body')
    self._msg_old_type = body_type
    self.logger.info('a message in type: %s with route: %s' % (body_type, router2str(msg)))

@zmqd.context()
def _run(self, ctx):
ctx.setsockopt(zmq.LINGER, 0)
self.handler.service_context = self
self.logger.info('bind sockets...')
in_sock, _ = build_socket(ctx, self.args.host_in, self.args.port_in, self.args.socket_in,
getattr(self, 'identity', None))
Expand Down Expand Up @@ -352,8 +392,23 @@ def _run(self, ctx):
if self.use_event_loop or pull_sock == ctrl_sock:
with TimeContext('handling message', self.logger):
self.is_handler_done.clear()

# receive message
msg = recv_message(pull_sock)
self.message_handler(msg, out_sock, ctrl_sock)

# choose output sock
if msg.request and msg.request.WhichOneof('body') and \
isinstance(getattr(msg.request, msg.request.WhichOneof('body')),
gnes_pb2.Request.ControlRequest):
o_sock = ctrl_sock
else:
o_sock = out_sock

# call pre-hooks
self.handler.call_hooks(msg, hook_type='pre')
# call main handler and send result back
self.handler.call_routes_send_back(msg, o_sock)

self.is_handler_done.set()
else:
self.logger.warning(
Expand Down
22 changes: 22 additions & 0 deletions gnes/service/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,25 @@ def _handler_train(self, msg: 'gnes_pb2.Message'):
@handler.register(gnes_pb2.Request.QueryRequest)
def _handler_search(self, msg: 'gnes_pb2.Message'):
self.embed_chunks_in_docs(msg.request.search.query, is_input_list=False)

@handler.register_hook(hook_type=('pre', 'post'), only_when_verbose=True)
def _hook_debug_msg(self, msg: 'gnes_pb2.Message', *args, **kwargs):
    """
    Verbose-only pre/post hook that logs a summary of an index request.

    Every field is read on its own; a field that cannot be read is replaced
    in the summary by a text stating the reason, so the hook itself does not
    fail on messages that lack docs or chunks.

    :param msg: the message to summarise
    """
    from pprint import pformat

    def first_doc():
        return msg.request.index.docs[0]

    def first_chunk():
        return first_doc().chunks[0]

    # (label, getter) pairs; getters are evaluated lazily, one at a time
    probes = (
        ('envelope', lambda: msg.envelope),
        ('num_docs', lambda: len(msg.request.index.docs)),
        ('num_chunks in doc[0]', lambda: len(first_doc().chunks)),
        ('docs[0].chunks[0].content_type', lambda: first_chunk().WhichOneof('content')),
        ('docs[0].chunks[0].weight', lambda: first_chunk().weight),
        ('docs[0].chunks[0].embedding', lambda: blob2array(first_chunk().embedding)),
        ('docs[0].chunks[0].embedding[0]', lambda: blob2array(first_chunk().embedding)[0]),
    )

    report = {}
    for label, probe in probes:
        try:
            report[label] = probe()
        except Exception as ex:
            # best effort on purpose: a missing field must not break the service
            report[label] = 'fail to get the value, reason: %s' % ex
    self.logger.info(pformat(report))

0 comments on commit 1b70fc7

Please sign in to comment.