Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
refactor(proto): refactor request stream call
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Jul 24, 2019
1 parent 216cecc commit a1a2b02
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 78 deletions.
3 changes: 3 additions & 0 deletions gnes/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,13 @@ def _set_grpc_parser(parser=None):


def set_grpc_frontend_parser(parser=None):
from ..service.base import SocketType
if not parser:
parser = set_base_parser()
_set_client_parser(parser)
_set_grpc_parser(parser)
parser.set_defaults(socket_in=SocketType.PULL_BIND,
socket_out=SocketType.PUSH_BIND)
parser.add_argument('--max_concurrency', type=int, default=10,
help='maximum concurrent client allowed')
parser.add_argument('--max_send_size', type=int, default=100,
Expand Down
12 changes: 5 additions & 7 deletions gnes/client/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,15 @@ def __init__(self, args):
stub = gnes_pb2_grpc.GnesRPCStub(channel)

if args.mode == 'train':
for req in RequestGenerator.train(all_bytes, args.batch_size):
resp = stub._Call(req)
print(resp)
resp = stub.RequestStreamCall(RequestGenerator.train(all_bytes, args.batch_size))
print(resp)
elif args.mode == 'index':
for req in RequestGenerator.index(all_bytes, args.batch_size):
resp = stub._Call(req)
print(resp)
resp = stub.RequestStreamCall(RequestGenerator.train(all_bytes, args.batch_size))
print(resp)
elif args.mode == 'query':
for idx, q in enumerate(all_bytes):
for req in RequestGenerator.query(q, args.top_k):
resp = stub._Call(req)
resp = stub.Call(req)
print(resp)
print('query %d result: %s' % (idx, resp))
input('press any key to continue...')
4 changes: 1 addition & 3 deletions gnes/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,7 @@ async def init(loop):
return srv

def stub_call(req):
res_f = None
for r in req:
res_f = stub._Call(r)
res_f = stub.RequestStreamCall(req)
return json.loads(MessageToJson(res_f))

with grpc.insecure_channel(
Expand Down
14 changes: 11 additions & 3 deletions gnes/proto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,41 @@

class RequestGenerator:
@staticmethod
def index(data: List[bytes], batch_size: int = 0, *args, **kwargs):
def index(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kwargs):

for pi in batch_iterator(data, batch_size):
req = gnes_pb2.Request()
req.request_id = start_id
for raw_bytes in pi:
d = req.index.docs.add()
d.raw_bytes = raw_bytes
d.weight = 1.0
yield req
start_id += 1

@staticmethod
def train(data: List[bytes], batch_size: int = 0, *args, **kwargs):
def train(data: List[bytes], batch_size: int = 0, start_id: int = 0, *args, **kwargs):
for pi in batch_iterator(data, batch_size):
req = gnes_pb2.Request()
req.request_id = str(start_id)
for raw_bytes in pi:
d = req.train.docs.add()
d.raw_bytes = raw_bytes
yield req
start_id += 1
req = gnes_pb2.Request()
req.request_id = str(start_id)
req.train.flush = True
yield req
start_id += 1

@staticmethod
def query(query: bytes, top_k: int, *args, **kwargs):
def query(query: bytes, top_k: int, start_id: int = 0, *args, **kwargs):
if top_k <= 0:
raise ValueError('"top_k: %d" is not a valid number' % top_k)

req = gnes_pb2.Request()
req.request_id = start_id
req.search.query.raw_bytes = query
req.search.top_k = top_k
yield req
Expand Down
6 changes: 2 additions & 4 deletions gnes/proto/gnes.proto
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,9 @@ service GnesRPC {
}
rpc Query (Request) returns (Response) {
}
rpc _Call (Request) returns (Response) {
rpc Call (Request) returns (Response) {
}
rpc TrainStream (stream Request) returns (Response) {
}
rpc IndexStream (stream Request) returns (Response) {
rpc RequestStreamCall (stream Request) returns (Response) {
}
}

21 changes: 6 additions & 15 deletions gnes/proto/gnes_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 10 additions & 27 deletions gnes/proto/gnes_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,13 @@ def __init__(self, channel):
request_serializer=gnes__pb2.Request.SerializeToString,
response_deserializer=gnes__pb2.Response.FromString,
)
self._Call = channel.unary_unary(
'/gnes.GnesRPC/_Call',
self.Call = channel.unary_unary(
'/gnes.GnesRPC/Call',
request_serializer=gnes__pb2.Request.SerializeToString,
response_deserializer=gnes__pb2.Response.FromString,
)
self.TrainStream = channel.stream_unary(
'/gnes.GnesRPC/TrainStream',
request_serializer=gnes__pb2.Request.SerializeToString,
response_deserializer=gnes__pb2.Response.FromString,
)
self.IndexStream = channel.stream_unary(
'/gnes.GnesRPC/IndexStream',
self.RequestStreamCall = channel.stream_unary(
'/gnes.GnesRPC/RequestStreamCall',
request_serializer=gnes__pb2.Request.SerializeToString,
response_deserializer=gnes__pb2.Response.FromString,
)
Expand Down Expand Up @@ -72,21 +67,14 @@ def Query(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def _Call(self, request, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def TrainStream(self, request_iterator, context):
def Call(self, request, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def IndexStream(self, request_iterator, context):
def RequestStreamCall(self, request_iterator, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
Expand All @@ -111,18 +99,13 @@ def add_GnesRPCServicer_to_server(servicer, server):
request_deserializer=gnes__pb2.Request.FromString,
response_serializer=gnes__pb2.Response.SerializeToString,
),
'_Call': grpc.unary_unary_rpc_method_handler(
servicer._Call,
request_deserializer=gnes__pb2.Request.FromString,
response_serializer=gnes__pb2.Response.SerializeToString,
),
'TrainStream': grpc.stream_unary_rpc_method_handler(
servicer.TrainStream,
'Call': grpc.unary_unary_rpc_method_handler(
servicer.Call,
request_deserializer=gnes__pb2.Request.FromString,
response_serializer=gnes__pb2.Response.SerializeToString,
),
'IndexStream': grpc.stream_unary_rpc_method_handler(
servicer.IndexStream,
'RequestStreamCall': grpc.stream_unary_rpc_method_handler(
servicer.RequestStreamCall,
request_deserializer=gnes__pb2.Request.FromString,
response_serializer=gnes__pb2.Response.SerializeToString,
),
Expand Down
31 changes: 14 additions & 17 deletions gnes/service/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
__all__ = ['GRPCFrontend']


class ZmqContext(object):
class ZmqContext:
"""The zmq context class."""

def __init__(self, args):
Expand Down Expand Up @@ -111,32 +111,29 @@ def add_envelope(self, body: 'gnes_pb2.Request', zmq_client: 'ZmqClient'):
msg.request.CopyFrom(body)
return msg

def _Call(self, request, context):
def remove_envelope(self, m: 'gnes_pb2.Message'):
resp = m.response
resp.request_id = m.envelope.request_id
return resp

def Call(self, request, context):
self.logger.info('received a new request: %s' % request.request_id or 'EMPTY_REQUEST_ID')
with self.zmq_context as zmq_client:
msg = self.add_envelope(request, zmq_client)
zmq_client.send_message(msg, self.args.timeout)
resp = zmq_client.recv_message(self.args.timeout)
self.logger.info("received message done!")
return resp.response
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
return self.remove_envelope(zmq_client.recv_message(self.args.timeout))

def Train(self, request, context):
return self._Call(request, context)
return self.Call(request, context)

def Index(self, request, context):
return self._Call(request, context)
return self.Call(request, context)

def Search(self, request, context):
return self._Call(request, context)

def TrainStream(self, request_iterator, context):
for request in request_iterator:
ret = self._Call(request, context)
return ret
return self.Call(request, context)

def IndexStream(self, request_iterator, context):
def RequestStreamCall(self, request_iterator, context):
for request in request_iterator:
ret = self._Call(request, context)
ret = self.Call(request, context)
return ret


Expand Down
4 changes: 2 additions & 2 deletions shell/make-proto.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/usr/bin/env bash

SRC_DIR=../gnes/proto/
#PLUGIN_PATH=/Volumes/TOSHIBA-4T/Documents/grpc/bins/opt/grpc_python_plugin
PLUGIN_PATH=/user/local/grpc/bins/opt/grpc_python_plugin
PLUGIN_PATH=/Volumes/TOSHIBA-4T/Documents/grpc/bins/opt/grpc_python_plugin
#PLUGIN_PATH=/user/local/grpc/bins/opt/grpc_python_plugin

protoc -I ${SRC_DIR} --python_out=${SRC_DIR} --grpc_python_out=${SRC_DIR} --plugin=protoc-gen-grpc_python=${PLUGIN_PATH} ${SRC_DIR}gnes.proto

Expand Down
Loading

0 comments on commit a1a2b02

Please sign in to comment.