From 66aec9c94ae44486a56fbc9c8667e18c24e01c51 Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Thu, 25 Jul 2019 19:11:44 +0800 Subject: [PATCH] feat(grpc): add StreamCall and decouple send and receive --- gnes/client/cli.py | 4 ++-- gnes/client/http.py | 2 +- gnes/proto/gnes.proto | 2 +- gnes/proto/gnes_pb2.py | 8 ++++---- gnes/proto/gnes_pb2_grpc.py | 10 +++++----- gnes/service/grpc.py | 12 ++++++++---- tests/test_stream_grpc.py | 29 +++++++++++++++++------------ 7 files changed, 38 insertions(+), 29 deletions(-) diff --git a/gnes/client/cli.py b/gnes/client/cli.py index 2b6a8b34..1c0d638c 100644 --- a/gnes/client/cli.py +++ b/gnes/client/cli.py @@ -41,10 +41,10 @@ def __init__(self, args): stub = gnes_pb2_grpc.GnesRPCStub(channel) if args.mode == 'train': - resp = stub.RequestStreamCall(RequestGenerator.train(all_bytes, args.batch_size)) + resp = list(stub.StreamCall(RequestGenerator.train(all_bytes, args.batch_size)))[-1] print(resp) elif args.mode == 'index': - resp = stub.RequestStreamCall(RequestGenerator.train(all_bytes, args.batch_size)) + resp = list(stub.StreamCall(RequestGenerator.train(all_bytes, args.batch_size)))[-1] print(resp) elif args.mode == 'query': for idx, q in enumerate(all_bytes): diff --git a/gnes/client/http.py b/gnes/client/http.py index a66c05a0..d7059d23 100644 --- a/gnes/client/http.py +++ b/gnes/client/http.py @@ -91,7 +91,7 @@ async def init(loop): return srv def stub_call(req): - res_f = stub.RequestStreamCall(req) + res_f = list(stub.StreamCall(req))[-1] return json.loads(MessageToJson(res_f)) with grpc.insecure_channel( diff --git a/gnes/proto/gnes.proto b/gnes/proto/gnes.proto index 1ee72b89..94475df0 100644 --- a/gnes/proto/gnes.proto +++ b/gnes/proto/gnes.proto @@ -207,7 +207,7 @@ service GnesRPC { } rpc Call (Request) returns (Response) { } - rpc RequestStreamCall (stream Request) returns (Response) { + rpc StreamCall (stream Request) returns (stream Response) { } } diff --git a/gnes/proto/gnes_pb2.py b/gnes/proto/gnes_pb2.py index b50bfb3d..f8af7170 100644 --- a/gnes/proto/gnes_pb2.py +++ b/gnes/proto/gnes_pb2.py @@ -21,7 +21,7 @@ package='gnes', syntax='proto3', serialized_options=None, - serialized_pb=_b('\n\ngnes.proto\x12\x04gnes\x1a\x1fgoogle/protobuf/timestamp.proto\"9\n\x07NdArray\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x11\n\x05shape\x18\x02 \x03(\rB\x02\x10\x01\x12\r\n\x05\x64type\x18\x03 \x01(\t\"\xbc\x01\n\x05\x43hunk\x12\x0e\n\x06\x64oc_id\x18\x01 \x01(\x04\x12\x0e\n\x04text\x18\x02 \x01(\tH\x00\x12\x1d\n\x04\x62lob\x18\x03 \x01(\x0b\x32\r.gnes.NdArrayH\x00\x12\x11\n\toffset_1d\x18\x04 \x01(\r\x12)\n\toffset_nd\x18\x05 \x01(\x0b\x32\x16.gnes.Chunk.Coordinate\x12\x0e\n\x06weight\x18\x06 \x01(\x02\x1a\x1b\n\nCoordinate\x12\r\n\x01x\x18\x01 \x03(\rB\x02\x10\x01\x42\t\n\x07\x63ontent\"\xe2\x02\n\x08\x44ocument\x12\x0e\n\x06\x64oc_id\x18\x01 \x01(\x04\x12\x1b\n\x06\x63hunks\x18\x02 \x03(\x0b\x32\x0b.gnes.Chunk\x12\'\n\x10\x63hunk_embeddings\x18\x03 \x01(\x0b\x32\r.gnes.NdArray\x12(\n\x08\x64oc_type\x18\x04 \x01(\x0e\x32\x16.gnes.Document.DocType\x12\x11\n\tmeta_info\x18\x05 \x01(\x0c\x12\x12\n\x08raw_text\x18\x06 \x01(\tH\x00\x12\"\n\traw_image\x18\x07 \x01(\x0b\x32\r.gnes.NdArrayH\x00\x12\"\n\traw_video\x18\x08 \x01(\x0b\x32\r.gnes.NdArrayH\x00\x12\x13\n\traw_bytes\x18\t \x01(\x0cH\x00\x12\x0e\n\x06weight\x18\n \x01(\x02\"6\n\x07\x44ocType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x08\n\x04TEXT\x10\x01\x12\t\n\x05IMAGE\x10\x02\x12\t\n\x05VIDEO\x10\x03\x42\n\n\x08raw_data\"\xd4\x01\n\x08\x45nvelope\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x12\n\nrequest_id\x18\x02 \x01(\t\x12\x0f\n\x07part_id\x18\x03 \x01(\r\x12\x10\n\x08num_part\x18\x04 \x01(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\x12$\n\x06routes\x18\x06 \x03(\x0b\x32\x14.gnes.Envelope.route\x1aG\n\x05route\x12\x0f\n\x07service\x18\x01 \x01(\t\x12-\n\ttimestamp\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"y\n\x07Message\x12 \n\x08\x65nvelope\x18\x01 \x01(\x0b\x32\x0e.gnes.Envelope\x12 \n\x07request\x18\x02 \x01(\x0b\x32\r.gnes.RequestH\x00\x12\"\n\x08response\x18\x03 \x01(\x0b\x32\x0e.gnes.ResponseH\x00\x42\x06\n\x04\x62ody\"\xf6\x03\n\x07Request\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12+\n\x05train\x18\x02 \x01(\x0b\x32\x1a.gnes.Request.TrainRequestH\x00\x12+\n\x05index\x18\x03 \x01(\x0b\x32\x1a.gnes.Request.IndexRequestH\x00\x12,\n\x06search\x18\x04 \x01(\x0b\x32\x1a.gnes.Request.QueryRequestH\x00\x12/\n\x07\x63ontrol\x18\x05 \x01(\x0b\x32\x1c.gnes.Request.ControlRequestH\x00\x1a;\n\x0cTrainRequest\x12\x1c\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x0e.gnes.Document\x12\r\n\x05\x66lush\x18\x02 \x01(\x08\x1a,\n\x0cIndexRequest\x12\x1c\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x0e.gnes.Document\x1a<\n\x0cQueryRequest\x12\x1d\n\x05query\x18\x01 \x01(\x0b\x32\x0e.gnes.Document\x12\r\n\x05top_k\x18\x02 \x01(\r\x1am\n\x0e\x43ontrolRequest\x12\x35\n\x07\x63ommand\x18\x01 \x01(\x0e\x32$.gnes.Request.ControlRequest.Command\"$\n\x07\x43ommand\x12\r\n\tTERMINATE\x10\x00\x12\n\n\x06STATUS\x10\x01\x42\x06\n\x04\x62ody\"\xc4\x06\n\x08Response\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12-\n\x05train\x18\x02 \x01(\x0b\x32\x1c.gnes.Response.TrainResponseH\x00\x12-\n\x05index\x18\x03 \x01(\x0b\x32\x1c.gnes.Response.IndexResponseH\x00\x12.\n\x06search\x18\x04 \x01(\x0b\x32\x1c.gnes.Response.QueryResponseH\x00\x12\x31\n\x07\x63ontrol\x18\x05 \x01(\x0b\x32\x1e.gnes.Response.ControlResponseH\x00\x1a\x36\n\rTrainResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x1a\x36\n\rIndexResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x1a\x38\n\x0f\x43ontrolResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x1a\x81\x03\n\rQueryResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x12\r\n\x05top_k\x18\x02 \x01(\r\x12?\n\x0ctopk_results\x18\x03 \x03(\x0b\x32).gnes.Response.QueryResponse.ScoredResult\x12\x39\n\x05level\x18\x04 \x01(\x0e\x32*.gnes.Response.QueryResponse.ResponseLevel\x1a{\n\x0cScoredResult\x12\x1c\n\x05\x63hunk\x18\x01 \x01(\x0b\x32\x0b.gnes.ChunkH\x00\x12\x1d\n\x03\x64oc\x18\x02 \x01(\x0b\x32\x0e.gnes.DocumentH\x00\x12\r\n\x05score\x18\x03 \x01(\x02\x12\x17\n\x0fscore_explained\x18\x04 \x01(\tB\x06\n\x04\x62ody\"A\n\rResponseLevel\x12\t\n\x05\x43HUNK\x10\x00\x12\x17\n\x13\x44OCUMENT_NOT_FILLED\x10\x01\x12\x0c\n\x08\x44OCUMENT\x10\x02\"-\n\x06Status\x12\x0b\n\x07SUCCESS\x10\x00\x12\t\n\x05\x45RROR\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x06\n\x04\x62ody2\xe8\x01\n\x07GnesRPC\x12(\n\x05Train\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12(\n\x05Index\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12(\n\x05Query\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12\'\n\x04\x43\x61ll\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12\x36\n\x11RequestStreamCall\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00(\x01\x62\x06proto3') + serialized_pb=_b('\n\ngnes.proto\x12\x04gnes\x1a\x1fgoogle/protobuf/timestamp.proto\"9\n\x07NdArray\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x11\n\x05shape\x18\x02 \x03(\rB\x02\x10\x01\x12\r\n\x05\x64type\x18\x03 \x01(\t\"\xbc\x01\n\x05\x43hunk\x12\x0e\n\x06\x64oc_id\x18\x01 \x01(\x04\x12\x0e\n\x04text\x18\x02 \x01(\tH\x00\x12\x1d\n\x04\x62lob\x18\x03 \x01(\x0b\x32\r.gnes.NdArrayH\x00\x12\x11\n\toffset_1d\x18\x04 \x01(\r\x12)\n\toffset_nd\x18\x05 \x01(\x0b\x32\x16.gnes.Chunk.Coordinate\x12\x0e\n\x06weight\x18\x06 \x01(\x02\x1a\x1b\n\nCoordinate\x12\r\n\x01x\x18\x01 \x03(\rB\x02\x10\x01\x42\t\n\x07\x63ontent\"\xe2\x02\n\x08\x44ocument\x12\x0e\n\x06\x64oc_id\x18\x01 \x01(\x04\x12\x1b\n\x06\x63hunks\x18\x02 \x03(\x0b\x32\x0b.gnes.Chunk\x12\'\n\x10\x63hunk_embeddings\x18\x03 \x01(\x0b\x32\r.gnes.NdArray\x12(\n\x08\x64oc_type\x18\x04 \x01(\x0e\x32\x16.gnes.Document.DocType\x12\x11\n\tmeta_info\x18\x05 \x01(\x0c\x12\x12\n\x08raw_text\x18\x06 \x01(\tH\x00\x12\"\n\traw_image\x18\x07 \x01(\x0b\x32\r.gnes.NdArrayH\x00\x12\"\n\traw_video\x18\x08 \x01(\x0b\x32\r.gnes.NdArrayH\x00\x12\x13\n\traw_bytes\x18\t \x01(\x0cH\x00\x12\x0e\n\x06weight\x18\n \x01(\x02\"6\n\x07\x44ocType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x08\n\x04TEXT\x10\x01\x12\t\n\x05IMAGE\x10\x02\x12\t\n\x05VIDEO\x10\x03\x42\n\n\x08raw_data\"\xd4\x01\n\x08\x45nvelope\x12\x11\n\tclient_id\x18\x01 \x01(\t\x12\x12\n\nrequest_id\x18\x02 \x01(\t\x12\x0f\n\x07part_id\x18\x03 \x01(\r\x12\x10\n\x08num_part\x18\x04 \x01(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\x12$\n\x06routes\x18\x06 \x03(\x0b\x32\x14.gnes.Envelope.route\x1aG\n\x05route\x12\x0f\n\x07service\x18\x01 \x01(\t\x12-\n\ttimestamp\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"y\n\x07Message\x12 \n\x08\x65nvelope\x18\x01 \x01(\x0b\x32\x0e.gnes.Envelope\x12 \n\x07request\x18\x02 \x01(\x0b\x32\r.gnes.RequestH\x00\x12\"\n\x08response\x18\x03 \x01(\x0b\x32\x0e.gnes.ResponseH\x00\x42\x06\n\x04\x62ody\"\xf6\x03\n\x07Request\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12+\n\x05train\x18\x02 \x01(\x0b\x32\x1a.gnes.Request.TrainRequestH\x00\x12+\n\x05index\x18\x03 \x01(\x0b\x32\x1a.gnes.Request.IndexRequestH\x00\x12,\n\x06search\x18\x04 \x01(\x0b\x32\x1a.gnes.Request.QueryRequestH\x00\x12/\n\x07\x63ontrol\x18\x05 \x01(\x0b\x32\x1c.gnes.Request.ControlRequestH\x00\x1a;\n\x0cTrainRequest\x12\x1c\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x0e.gnes.Document\x12\r\n\x05\x66lush\x18\x02 \x01(\x08\x1a,\n\x0cIndexRequest\x12\x1c\n\x04\x64ocs\x18\x01 \x03(\x0b\x32\x0e.gnes.Document\x1a<\n\x0cQueryRequest\x12\x1d\n\x05query\x18\x01 \x01(\x0b\x32\x0e.gnes.Document\x12\r\n\x05top_k\x18\x02 \x01(\r\x1am\n\x0e\x43ontrolRequest\x12\x35\n\x07\x63ommand\x18\x01 \x01(\x0e\x32$.gnes.Request.ControlRequest.Command\"$\n\x07\x43ommand\x12\r\n\tTERMINATE\x10\x00\x12\n\n\x06STATUS\x10\x01\x42\x06\n\x04\x62ody\"\xc4\x06\n\x08Response\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12-\n\x05train\x18\x02 \x01(\x0b\x32\x1c.gnes.Response.TrainResponseH\x00\x12-\n\x05index\x18\x03 \x01(\x0b\x32\x1c.gnes.Response.IndexResponseH\x00\x12.\n\x06search\x18\x04 \x01(\x0b\x32\x1c.gnes.Response.QueryResponseH\x00\x12\x31\n\x07\x63ontrol\x18\x05 \x01(\x0b\x32\x1e.gnes.Response.ControlResponseH\x00\x1a\x36\n\rTrainResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x1a\x36\n\rIndexResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x1a\x38\n\x0f\x43ontrolResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x1a\x81\x03\n\rQueryResponse\x12%\n\x06status\x18\x01 \x01(\x0e\x32\x15.gnes.Response.Status\x12\r\n\x05top_k\x18\x02 \x01(\r\x12?\n\x0ctopk_results\x18\x03 \x03(\x0b\x32).gnes.Response.QueryResponse.ScoredResult\x12\x39\n\x05level\x18\x04 \x01(\x0e\x32*.gnes.Response.QueryResponse.ResponseLevel\x1a{\n\x0cScoredResult\x12\x1c\n\x05\x63hunk\x18\x01 \x01(\x0b\x32\x0b.gnes.ChunkH\x00\x12\x1d\n\x03\x64oc\x18\x02 \x01(\x0b\x32\x0e.gnes.DocumentH\x00\x12\r\n\x05score\x18\x03 \x01(\x02\x12\x17\n\x0fscore_explained\x18\x04 \x01(\tB\x06\n\x04\x62ody\"A\n\rResponseLevel\x12\t\n\x05\x43HUNK\x10\x00\x12\x17\n\x13\x44OCUMENT_NOT_FILLED\x10\x01\x12\x0c\n\x08\x44OCUMENT\x10\x02\"-\n\x06Status\x12\x0b\n\x07SUCCESS\x10\x00\x12\t\n\x05\x45RROR\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x06\n\x04\x62ody2\xe3\x01\n\x07GnesRPC\x12(\n\x05Train\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12(\n\x05Index\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12(\n\x05Query\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12\'\n\x04\x43\x61ll\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00\x12\x31\n\nStreamCall\x12\r.gnes.Request\x1a\x0e.gnes.Response\"\x00(\x01\x30\x01\x62\x06proto3') , dependencies=[google_dot_protobuf_dot_timestamp__pb2.DESCRIPTOR,]) @@ -1238,7 +1238,7 @@ index=0, serialized_options=None, serialized_start=2343, - serialized_end=2575, + serialized_end=2570, methods=[ _descriptor.MethodDescriptor( name='Train', @@ -1277,8 +1277,8 @@ serialized_options=None, ), _descriptor.MethodDescriptor( - name='RequestStreamCall', - full_name='gnes.GnesRPC.RequestStreamCall', + name='StreamCall', + full_name='gnes.GnesRPC.StreamCall', index=4, containing_service=None, input_type=_REQUEST, diff --git a/gnes/proto/gnes_pb2_grpc.py b/gnes/proto/gnes_pb2_grpc.py index 5d6b519e..613c646a 100644 --- a/gnes/proto/gnes_pb2_grpc.py +++ b/gnes/proto/gnes_pb2_grpc.py @@ -34,8 +34,8 @@ def __init__(self, channel): request_serializer=gnes__pb2.Request.SerializeToString, response_deserializer=gnes__pb2.Response.FromString, ) - self.RequestStreamCall = channel.stream_unary( - '/gnes.GnesRPC/RequestStreamCall', + self.StreamCall = channel.stream_stream( + '/gnes.GnesRPC/StreamCall', request_serializer=gnes__pb2.Request.SerializeToString, response_deserializer=gnes__pb2.Response.FromString, ) @@ -74,7 +74,7 @@ def Call(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def RequestStreamCall(self, request_iterator, context): + def StreamCall(self, request_iterator, context): # missing associated documentation comment in .proto file pass context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -104,8 +104,8 @@ def add_GnesRPCServicer_to_server(servicer, server): request_deserializer=gnes__pb2.Request.FromString, response_serializer=gnes__pb2.Response.SerializeToString, ), - 'RequestStreamCall': grpc.stream_unary_rpc_method_handler( - servicer.RequestStreamCall, + 'StreamCall': grpc.stream_stream_rpc_method_handler( + servicer.StreamCall, request_deserializer=gnes__pb2.Request.FromString, response_serializer=gnes__pb2.Response.SerializeToString, ), diff --git a/gnes/service/grpc.py b/gnes/service/grpc.py index d5df3001..8f234501 100644 --- a/gnes/service/grpc.py +++ b/gnes/service/grpc.py @@ -131,10 +131,14 @@ def Index(self, request, context): def Search(self, request, context): return self.Call(request, context) - def RequestStreamCall(self, request_iterator, context): - for request in request_iterator: - ret = self.Call(request, context) - return ret + def StreamCall(self, request_iterator, context): + num_result = 0 + with self.zmq_context as zmq_client: + for request in request_iterator: + zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout) + num_result += 1 + for _ in range(num_result): + yield self.remove_envelope(zmq_client.recv_message(self.args.timeout)) class GRPCFrontend: diff --git a/tests/test_stream_grpc.py b/tests/test_stream_grpc.py index 949df74a..1278534c 100644 --- a/tests/test_stream_grpc.py +++ b/tests/test_stream_grpc.py @@ -57,15 +57,10 @@ def test_grpc_frontend(self): ('grpc.max_receive_message_length', 70 * 1024 * 1024)]) as channel: stub = gnes_pb2_grpc.GnesRPCStub(channel) with TimeContext('sync call'): # about 5s - resp = stub.RequestStreamCall(RequestGenerator.train(self.all_bytes, 1)) + resp = list(stub.StreamCall(RequestGenerator.train(self.all_bytes, 1)))[-1] self.assertEqual(resp.request_id, str(len(self.all_bytes))) # idx start with 0, but +1 for final FLUSH - # test async calls - with TimeContext('async call'): # immeidiately returns 0.001 s - resp = stub.RequestStreamCall.future(RequestGenerator.train(self.all_bytes, 1)) - self.assertEqual(resp.result().request_id, str(len(self.all_bytes))) - @unittest.mock.patch.dict(os.environ, {'http_proxy': '', 'https_proxy': ''}) def test_async_block(self): args = set_grpc_frontend_parser().parse_args([ @@ -91,9 +86,19 @@ def test_async_block(self): options=[('grpc.max_send_message_length', 70 * 1024 * 1024), ('grpc.max_receive_message_length', 70 * 1024 * 1024)]) as channel: stub = gnes_pb2_grpc.GnesRPCStub(channel) - with TimeContext('sync call'): # about 5s - resp = stub.RequestStreamCall.future(RequestGenerator.train(self.all_bytes, 1)) - - self.assertEqual(resp.result().request_id, str(len(self.all_bytes))) - - self.assertEqual(resp.request_id, str(len(self.all_bytes2))) # idx start with 0, but +1 for final FLUSH + id = 0 + with TimeContext('non-blocking call'): # about 26s = 32s (total) - 3*2s (overlap) + resp = stub.StreamCall(RequestGenerator.train(self.all_bytes2, 1)) + for r in resp: + self.assertEqual(r.request_id, str(id)) + id += 1 + + id = 0 + with TimeContext('blocking call'): # should be 32 s + for r in RequestGenerator.train(self.all_bytes2, 1): + resp = stub.Call(r) + self.assertEqual(resp.request_id, str(id)) + id += 1 + # self.assertEqual(resp.result().request_id, str(len(self.all_bytes))) + + # self.assertEqual(resp.request_id, str(len(self.all_bytes2))) # idx start with 0, but +1 for final FLUSH