Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
feat(grpc): add StreamCall and decouple send and receive
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Jul 25, 2019
1 parent 9973f60 commit 66aec9c
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 29 deletions.
4 changes: 2 additions & 2 deletions gnes/client/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def __init__(self, args):
stub = gnes_pb2_grpc.GnesRPCStub(channel)

if args.mode == 'train':
resp = stub.RequestStreamCall(RequestGenerator.train(all_bytes, args.batch_size))
resp = list(stub.StreamCall(RequestGenerator.train(all_bytes, args.batch_size)))[-1]
print(resp)
elif args.mode == 'index':
resp = stub.RequestStreamCall(RequestGenerator.train(all_bytes, args.batch_size))
resp = list(stub.StreamCall(RequestGenerator.train(all_bytes, args.batch_size)))[-1]
print(resp)
elif args.mode == 'query':
for idx, q in enumerate(all_bytes):
Expand Down
2 changes: 1 addition & 1 deletion gnes/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def init(loop):
return srv

def stub_call(req):
res_f = stub.RequestStreamCall(req)
res_f = list(stub.StreamCall(req))[-1]
return json.loads(MessageToJson(res_f))

with grpc.insecure_channel(
Expand Down
2 changes: 1 addition & 1 deletion gnes/proto/gnes.proto
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ service GnesRPC {
}
rpc Call (Request) returns (Response) {
}
rpc RequestStreamCall (stream Request) returns (Response) {
rpc StreamCall (stream Request) returns (stream Response) {
}
}

8 changes: 4 additions & 4 deletions gnes/proto/gnes_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions gnes/proto/gnes_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def __init__(self, channel):
request_serializer=gnes__pb2.Request.SerializeToString,
response_deserializer=gnes__pb2.Response.FromString,
)
self.RequestStreamCall = channel.stream_unary(
'/gnes.GnesRPC/RequestStreamCall',
self.StreamCall = channel.stream_stream(
'/gnes.GnesRPC/StreamCall',
request_serializer=gnes__pb2.Request.SerializeToString,
response_deserializer=gnes__pb2.Response.FromString,
)
Expand Down Expand Up @@ -74,7 +74,7 @@ def Call(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def RequestStreamCall(self, request_iterator, context):
def StreamCall(self, request_iterator, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
Expand Down Expand Up @@ -104,8 +104,8 @@ def add_GnesRPCServicer_to_server(servicer, server):
request_deserializer=gnes__pb2.Request.FromString,
response_serializer=gnes__pb2.Response.SerializeToString,
),
'RequestStreamCall': grpc.stream_unary_rpc_method_handler(
servicer.RequestStreamCall,
'StreamCall': grpc.stream_stream_rpc_method_handler(
servicer.StreamCall,
request_deserializer=gnes__pb2.Request.FromString,
response_serializer=gnes__pb2.Response.SerializeToString,
),
Expand Down
12 changes: 8 additions & 4 deletions gnes/service/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,14 @@ def Index(self, request, context):
def Search(self, request, context):
return self.Call(request, context)

def RequestStreamCall(self, request_iterator, context):
for request in request_iterator:
ret = self.Call(request, context)
return ret
def StreamCall(self, request_iterator, context):
num_result = 0
with self.zmq_context as zmq_client:
for request in request_iterator:
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
num_result += 1
for _ in range(num_result):
yield self.remove_envelope(zmq_client.recv_message(self.args.timeout))


class GRPCFrontend:
Expand Down
29 changes: 17 additions & 12 deletions tests/test_stream_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,10 @@ def test_grpc_frontend(self):
('grpc.max_receive_message_length', 70 * 1024 * 1024)]) as channel:
stub = gnes_pb2_grpc.GnesRPCStub(channel)
with TimeContext('sync call'): # about 5s
resp = stub.RequestStreamCall(RequestGenerator.train(self.all_bytes, 1))
resp = list(stub.StreamCall(RequestGenerator.train(self.all_bytes, 1)))[-1]

self.assertEqual(resp.request_id, str(len(self.all_bytes))) # idx start with 0, but +1 for final FLUSH

# test async calls
with TimeContext('async call'): # immeidiately returns 0.001 s
resp = stub.RequestStreamCall.future(RequestGenerator.train(self.all_bytes, 1))
self.assertEqual(resp.result().request_id, str(len(self.all_bytes)))

@unittest.mock.patch.dict(os.environ, {'http_proxy': '', 'https_proxy': ''})
def test_async_block(self):
args = set_grpc_frontend_parser().parse_args([
Expand All @@ -91,9 +86,19 @@ def test_async_block(self):
options=[('grpc.max_send_message_length', 70 * 1024 * 1024),
('grpc.max_receive_message_length', 70 * 1024 * 1024)]) as channel:
stub = gnes_pb2_grpc.GnesRPCStub(channel)
with TimeContext('sync call'): # about 5s
resp = stub.RequestStreamCall.future(RequestGenerator.train(self.all_bytes, 1))

self.assertEqual(resp.result().request_id, str(len(self.all_bytes)))

self.assertEqual(resp.request_id, str(len(self.all_bytes2))) # idx start with 0, but +1 for final FLUSH
id = 0
with TimeContext('non-blocking call'): # about 26s = 32s (total) - 3*2s (overlap)
resp = stub.StreamCall(RequestGenerator.train(self.all_bytes2, 1))
for r in resp:
self.assertEqual(r.request_id, str(id))
id += 1

id = 0
with TimeContext('blocking call'): # should be 32 s
for r in RequestGenerator.train(self.all_bytes2, 1):
resp = stub.Call(r)
self.assertEqual(resp.request_id, str(id))
id += 1
# self.assertEqual(resp.result().request_id, str(len(self.all_bytes)))

# self.assertEqual(resp.request_id, str(len(self.all_bytes2))) # idx start with 0, but +1 for final FLUSH

0 comments on commit 66aec9c

Please sign in to comment.