diff --git a/gnes/service/grpc.py b/gnes/service/grpc.py index 35fab6b9..8b5b1eff 100644 --- a/gnes/service/grpc.py +++ b/gnes/service/grpc.py @@ -28,33 +28,6 @@ __all__ = ['GRPCFrontend'] -class ZmqContext: - """The zmq context class.""" - - def __init__(self, args): - """Database connection context. - - Args: - servers: a list of config dicts for connecting to database - dbapi_name: the name of database engine - """ - self.args = args - - self.tlocal = threading.local() - self.tlocal.client = None - - def __enter__(self): - """Enter the context.""" - client = ZmqClient(self.args) - self.tlocal.client = client - return client - - def __exit__(self, exc_type, exc_value, exc_traceback): - """Exit the context.""" - self.tlocal.client.close() - self.tlocal.client = None - - class ZmqClient: def __init__(self, args): @@ -87,61 +60,8 @@ def recv_message(self, timeout: int = -1) -> gnes_pb2.Message: return recv_message(self.receiver, timeout=timeout) -class GNESServicer(gnes_pb2_grpc.GnesRPCServicer): - - def __init__(self, args): - self.args = args - self.logger = set_logger(self.__class__.__name__, args.verbose) - self.zmq_context = ZmqContext(args) - - def add_envelope(self, body: 'gnes_pb2.Request', zmq_client: 'ZmqClient'): - msg = gnes_pb2.Message() - msg.envelope.client_id = zmq_client.identity if zmq_client.identity else '' - if body.request_id: - msg.envelope.request_id = body.request_id - else: - msg.envelope.request_id = str(uuid.uuid4()) - self.logger.warning('request_id is missing, filled it with a random uuid!') - msg.envelope.part_id = 1 - msg.envelope.num_part.append(1) - msg.envelope.timeout = 5000 - r = msg.envelope.routes.add() - r.service = GRPCFrontend.__name__ - r.timestamp.GetCurrentTime() - msg.request.CopyFrom(body) - return msg - - def remove_envelope(self, m: 'gnes_pb2.Message'): - resp = m.response - resp.request_id = m.envelope.request_id - return resp - - def Call(self, request, context): - self.logger.info('received a new request: %s' % request.request_id or 'EMPTY_REQUEST_ID') - with self.zmq_context as zmq_client: - zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout) - return self.remove_envelope(zmq_client.recv_message(self.args.timeout)) - - def Train(self, request, context): - return self.Call(request, context) - - def Index(self, request, context): - return self.Call(request, context) - - def Search(self, request, context): - return self.Call(request, context) - - def StreamCall(self, request_iterator, context): - num_result = 0 - with self.zmq_context as zmq_client: - for request in request_iterator: - zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout) - num_result += 1 - for _ in range(num_result): - yield self.remove_envelope(zmq_client.recv_message(self.args.timeout)) - - class GRPCFrontend: + def __init__(self, args): self.logger = set_logger(self.__class__.__name__, args.verbose) self.server = grpc.server( @@ -149,9 +69,8 @@ def __init__(self, args): options=[('grpc.max_send_message_length', args.max_message_size * 1024 * 1024), ('grpc.max_receive_message_length', args.max_message_size * 1024 * 1024)]) self.logger.info('start a grpc server with %d workers' % args.max_concurrency) - gnes_pb2_grpc.add_GnesRPCServicer_to_server(GNESServicer(args), self.server) + gnes_pb2_grpc.add_GnesRPCServicer_to_server(self.GNESServicer(args), self.server) - # Start GRPC Server self.bind_address = '{0}:{1}'.format(args.grpc_host, args.grpc_port) self.server.add_insecure_port(self.bind_address) @@ -162,3 +81,75 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.server.stop(None) + + class GNESServicer(gnes_pb2_grpc.GnesRPCServicer): + + def __init__(self, args): + self.args = args + self.logger = set_logger(self.__class__.__name__, args.verbose) + self.zmq_context = self.ZmqContext(args) + + def add_envelope(self, body: 'gnes_pb2.Request', zmq_client: 'ZmqClient'): + msg = gnes_pb2.Message() + msg.envelope.client_id = zmq_client.identity if zmq_client.identity else '' + if body.request_id: + msg.envelope.request_id = body.request_id + else: + msg.envelope.request_id = str(uuid.uuid4()) + self.logger.warning('request_id is missing, filled it with a random uuid!') + msg.envelope.part_id = 1 + msg.envelope.num_part.append(1) + msg.envelope.timeout = 5000 + r = msg.envelope.routes.add() + r.service = GRPCFrontend.__name__ + r.timestamp.GetCurrentTime() + msg.request.CopyFrom(body) + return msg + + def remove_envelope(self, m: 'gnes_pb2.Message'): + resp = m.response + resp.request_id = m.envelope.request_id + return resp + + def Call(self, request, context): + self.logger.info('received a new request: %s' % request.request_id or 'EMPTY_REQUEST_ID') + with self.zmq_context as zmq_client: + zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout) + return self.remove_envelope(zmq_client.recv_message(self.args.timeout)) + + def Train(self, request, context): + return self.Call(request, context) + + def Index(self, request, context): + return self.Call(request, context) + + def Search(self, request, context): + return self.Call(request, context) + + def StreamCall(self, request_iterator, context): + num_result = 0 + with self.zmq_context as zmq_client: + for request in request_iterator: + zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout) + num_result += 1 + for _ in range(num_result): + yield self.remove_envelope(zmq_client.recv_message(self.args.timeout)) + + class ZmqContext: + """The zmq context class.""" + + def __init__(self, args): + self.args = args + self.tlocal = threading.local() + self.tlocal.client = None + + def __enter__(self): + """Enter the context.""" + client = ZmqClient(self.args) + self.tlocal.client = client + return client + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Exit the context.""" + self.tlocal.client.close() + self.tlocal.client = None