diff --git a/gnes/service/base.py b/gnes/service/base.py index 34d696ea..144f32fc 100644 --- a/gnes/service/base.py +++ b/gnes/service/base.py @@ -135,8 +135,9 @@ def build_socket(ctx: 'zmq.Context', host: str, port: int, socket_type: 'SocketT class MessageHandler: def __init__(self, mh: 'MessageHandler' = None): self.routes = {k: v for k, v in mh.routes.items()} if mh else {} - self.hook_fns = [] + self.hooks = {k: v for k, v in mh.hooks.items()} if mh else {'pre': [], 'post': []} self.logger = set_logger(self.__class__.__name__) + self.service_context = None def register(self, msg_type: Union[List, Tuple, type]): def decorator(f): @@ -149,6 +150,53 @@ def decorator(f): return decorator + def register_hook(self, hook_type: Union[str, Tuple[str]], only_when_verbose: bool = False): + """ + Register a function as a pre/post hook + + :param only_when_verbose: only call the hook when verbose is true + :param hook_type: possible values 'pre' or 'post' or ('pre', 'post') + """ + + def decorator(f): + if isinstance(hook_type, str) and hook_type in self.hooks: + self.hooks[hook_type].append((f, only_when_verbose)) + elif isinstance(hook_type, list) or isinstance(hook_type, tuple): + for h in set(hook_type): + if h in self.hooks: + self.hooks[h].append((f, only_when_verbose)) + else: + raise AttributeError('hook type: %s is not supported' % h) + return f + else: + raise TypeError('hook_type is in bad type: %s' % type(hook_type)) + + return decorator + + def call_hooks(self, msg: 'gnes_pb2.Message', hook_type: Union[str, Tuple[str]], verbose: bool, + *args, **kwargs): + """ + All post handler hooks are called after the handler is done but before + sending out the message to the next service. + All pre handler hooks are called after the service received a message + and before calling the message handler + """ + hooks = [] + if isinstance(hook_type, str) and hook_type in self.hooks: + hooks.extend(self.hooks[hook_type]) + elif isinstance(hook_type, list) or isinstance(hook_type, tuple): + for h in set(hook_type): + if h in self.hooks: + hooks.extend(self.hooks[h]) + else: + raise AttributeError('hook type: %s is not supported' % h) + else: + raise TypeError('hook_type is in bad type: %s' % type(hook_type)) + + for fn, only_verbose in hooks: + if (only_verbose and verbose) or (not only_verbose): + fn(self.service_context, msg, *args, **kwargs) + def get_serve_fn(self, msg: 'gnes_pb2.Message'): def get_default_fn(m_type): self.logger.warning('cant find handler for message type: %s, fall back to the default handler' % m_type) @@ -225,7 +273,7 @@ def __init__(self, args): self.identity = args.identity if 'identity' in args else None self.use_event_loop = True self.ctrl_addr = 'tcp://%s:%d' % (self.default_host, self.args.port_ctrl) - self.handler.hook_fns.extend([self._hook_warn_body_type_change, self._hook_sort_response]) + self.handler.service_context = self def run(self): try: @@ -259,16 +307,13 @@ def dump(self): else: self.logger.info('no dumping as "read_only" set to true.') - def post_handler(self, msg: 'gnes_pb2.Message', *args, **kwargs): - for fn in self.handler.hook_fns: - fn(msg, *args, **kwargs) - self.logger.info('hook handler %s is done' % fn.__name__) - + @handler.register_hook(hook_type='post') def _hook_warn_body_type_change(self, msg: 'gnes_pb2.Message', old_body_type: str, *args, **kwargs): new_type = msg.WhichOneof('body') if new_type != old_body_type: - self.logger.warning('message body is changed from %s to %s' % (old_body_type, new_type)) + self.logger.warning('message body type has changed from %s to %s' % (old_body_type, new_type)) + @handler.register_hook(hook_type='post') def _hook_sort_response(self, msg: 'gnes_pb2.Message', *args, **kwargs): if 'sorted_response' in self.args and self.args.sorted_response and msg.response.search.topk_results: msg.response.search.topk_results.sort(key=lambda x: x.score.value, @@ -279,6 +324,10 @@ def _hook_sort_response(self, msg: 'gnes_pb2.Message', *args, **kwargs): (len(msg.response.search.topk_results), 'descending' if msg.response.search.is_big_score_similar else 'ascending')) + @handler.register_hook(hook_type=('pre', 'post'), only_when_verbose=True) + def _hook_logging_msg(self, msg: 'gnes_pb2.Message', *args, **kwargs): + pass + def message_handler(self, msg: 'gnes_pb2.Message', out_sck, ctrl_sck): try: fn = self.handler.get_serve_fn(msg) @@ -297,11 +346,13 @@ def message_handler(self, msg: 'gnes_pb2.Message', out_sck, ctrl_sck): self.logger.info('handler %s is done' % fn.__name__) if ret is None: # assume 'msg' is modified inside fn() - self.post_handler(msg, old_body_type=old_type) + self.handler.call_hooks(msg, hook_type='post', verbose=self.args.verbose, + old_body_type=old_type) send_message(out_sock, msg, timeout=self.args.timeout) elif isinstance(ret, types.GeneratorType): for r_msg in ret: - self.post_handler(msg, old_body_type=old_type) + self.handler.call_hooks(msg, hook_type='post', verbose=self.args.verbose, + old_body_type=old_type) send_message(out_sock, r_msg, timeout=self.args.timeout) else: raise ServiceError('unknown return type from the handler: %s' % fn) @@ -353,6 +404,7 @@ def _run(self, ctx): with TimeContext('handling message', self.logger): self.is_handler_done.clear() msg = recv_message(pull_sock) + self.handler.call_hooks(msg, hook_type='pre', verbose=self.args.verbose) self.message_handler(msg, out_sock, ctrl_sock) self.is_handler_done.set() else: