Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge pull request #284 from gnes-ai/fix-frontend-3
Browse files Browse the repository at this point in the history
fix(frontend): use poll for better efficiency
  • Loading branch information
mergify[bot] authored Sep 25, 2019
2 parents 1ce6350 + 7fdbbb4 commit c954e25
Show file tree
Hide file tree
Showing 6 changed files with 355 additions and 70 deletions.
9 changes: 6 additions & 3 deletions gnes/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def __init__(self, option_strings, dest, default=None, required=False, help=None
raise ValueError('yes/no arguments must be prefixed with --')

opt = opt[2:]
opts = ['--' + opt, '--no-' + opt]
opts = ['--' + opt, '--no-' + opt, '--no_' + opt]
super(ActionNoYes, self).__init__(opts, dest, nargs=0, const=None,
default=default, required=required, help=help)

def __call__(self, parser, namespace, values, option_strings=None):
if option_strings.startswith('--no-'):
if option_strings.startswith('--no-') or option_strings.startswith('--no_'):
setattr(namespace, self.dest, False)
else:
setattr(namespace, self.dest, True)
Expand Down Expand Up @@ -180,8 +180,11 @@ def set_service_parser(parser=None):
'mismatch raise an exception')
parser.add_argument('--identity', type=str, default='',
help='identity of the service, empty by default')
parser.add_argument('--route_table', action=ActionNoYes, default=True,
parser.add_argument('--route_table', action=ActionNoYes, default=False,
help='showing a route table with time cost after receiving the result')
parser.add_argument('--squeeze_pb', action=ActionNoYes, default=True,
help='sending bytes and ndarray separately apart from the protobuf message, '
'usually yields better network efficiency')
return parser


Expand Down
16 changes: 7 additions & 9 deletions gnes/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, List, Union

import grpc
import zmq
from termcolor import colored
from typing import Tuple, List, Union, Type

from ..helper import set_logger
from ..proto import gnes_pb2_grpc
from ..proto import send_message, gnes_pb2, recv_message
from ..proto import send_message as _send_message, gnes_pb2, recv_message as _recv_message
from ..service.base import build_socket


Expand Down Expand Up @@ -97,15 +98,12 @@ def close(self):
self.receiver.close()
self.ctx.term()

def send_message(self, message: "gnes_pb2.Message", timeout: int = -1):
def send_message(self, message: "gnes_pb2.Message", **kwargs):
self.logger.debug('send message: %s' % message.envelope)
send_message(self.sender, message, timeout=timeout)
_send_message(self.sender, message, **kwargs)

def recv_message(self, timeout: int = -1) -> gnes_pb2.Message:
r = recv_message(
self.receiver,
timeout=timeout,
check_version=self.args.check_version)
def recv_message(self, **kwargs) -> gnes_pb2.Message:
r = _recv_message(self.receiver, **kwargs)
self.logger.debug('recv a message: %s' % r.envelope)
return r

Expand Down
190 changes: 157 additions & 33 deletions gnes/proto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import ctypes
import random
from typing import List, Iterator
from typing import List, Iterator, Tuple
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -121,14 +121,156 @@ def merge_routes(msg: 'gnes_pb2.Message', prev_msgs: List['gnes_pb2.Message']):
msg.envelope.routes.extend(sorted(routes.values(), key=lambda x: (x.start_time.seconds, x.start_time.nanos)))


def send_message(sock: 'zmq.Socket', msg: 'gnes_pb2.Message', timeout: int = -1) -> None:
def check_msg_version(msg: 'gnes_pb2.Message'):
from .. import __version__, __proto_version__
if hasattr(msg.envelope, 'gnes_version'):
if not msg.envelope.gnes_version:
# only happen in unittest
default_logger.warning('incoming message contains empty "gnes_version", '
'you may ignore it in debug/unittest mode. '
'otherwise please check if frontend service set correct version')
elif __version__ != msg.envelope.gnes_version:
raise AttributeError('mismatched GNES version! '
'incoming message has GNES version %s, whereas local GNES version %s' % (
msg.envelope.gnes_version, __version__))

if hasattr(msg.envelope, 'proto_version'):
if not msg.envelope.proto_version:
# only happen in unittest
default_logger.warning('incoming message contains empty "proto_version", '
'you may ignore it in debug/unittest mode. '
'otherwise please check if frontend service set correct version')
elif __proto_version__ != msg.envelope.proto_version:
raise AttributeError('mismatched protobuf version! '
'incoming message has protobuf version %s, whereas local protobuf version %s' % (
msg.envelope.proto_version, __proto_version__))

if not hasattr(msg.envelope, 'proto_version') and not hasattr(msg.envelope, 'gnes_version'):
raise AttributeError('version_check=True locally, '
'but incoming message contains no version info in its envelope. '
'the message is probably sent from a very outdated GNES version')


def extract_bytes_from_msg(msg: 'gnes_pb2.Message') -> Tuple:
doc_bytes = []
chunk_bytes = []
doc_byte_type = b''
chunk_byte_type = b''

docs = msg.request.train.docs or msg.request.index.docs or [msg.request.search.query]
# for train request
for d in docs:
# oneof raw_data {
# string raw_text = 5;
# NdArray raw_image = 6;
# NdArray raw_video = 7;
# bytes raw_bytes = 8; // for other types
# }
dtype = d.WhichOneof('raw_data') or ''
doc_byte_type = dtype.encode()
if dtype == 'raw_bytes':
doc_bytes.append(d.raw_bytes)
d.ClearField('raw_bytes')
elif dtype == 'raw_image':
doc_bytes.append(d.raw_image.data)
d.raw_image.ClearField('data')
elif dtype == 'raw_video':
doc_bytes.append(d.raw_video.data)
d.raw_video.ClearField('data')
elif dtype == 'raw_text':
doc_bytes.append(d.raw_text.encode())
d.ClearField('raw_text')

for c in d.chunks:
# oneof content {
# string text = 2;
# NdArray blob = 3;
# bytes raw = 7;
# }
chunk_bytes.append(c.embedding.data)
c.embedding.ClearField('data')

ctype = c.WhichOneof('content') or ''
chunk_byte_type = ctype.encode()
if ctype == 'raw':
chunk_bytes.append(c.raw)
c.ClearField('raw')
elif ctype == 'blob':
chunk_bytes.append(c.blob.data)
c.blob.ClearField('data')
elif ctype == 'text':
chunk_bytes.append(c.text.encode())
c.ClearField('text')

return doc_bytes, doc_byte_type, chunk_bytes, chunk_byte_type


def fill_raw_bytes_to_msg(msg: 'gnes_pb2.Message', msg_data: List[bytes]):
doc_byte_type = msg_data[3].decode()
chunk_byte_type = msg_data[4].decode()
doc_bytes_len = int(msg_data[5])
chunk_bytes_len = int(msg_data[6])

doc_bytes = msg_data[7:(7 + doc_bytes_len)]
chunk_bytes = msg_data[(7 + doc_bytes_len):]

if len(chunk_bytes) != chunk_bytes_len:
raise ValueError('"chunk_bytes_len"=%d in message, but the actual length is %d' % (
chunk_bytes_len, len(chunk_bytes)))

c_idx = 0
d_idx = 0
docs = msg.request.train.docs or msg.request.index.docs or [msg.request.search.query]
for d in docs:
if doc_bytes and doc_bytes[d_idx]:
if doc_byte_type == 'raw':
d.raw_bytes = doc_bytes[d_idx]
d_idx += 1
elif doc_byte_type == 'raw_image':
d.raw_image.data = doc_bytes[d_idx]
d_idx += 1
elif doc_byte_type == 'raw_video':
d.raw_video.data = doc_bytes[d_idx]
d_idx += 1
elif doc_byte_type == 'raw_text':
d.raw_text = doc_bytes[d_idx].decode()
d_idx += 1

for c in d.chunks:
if chunk_bytes and chunk_bytes[c_idx]:
c.embedding.data = chunk_bytes[c_idx]
c_idx += 1

if chunk_byte_type == 'raw':
c.raw = chunk_bytes[c_idx]
c_idx += 1
elif chunk_byte_type == 'blob':
c.blob.data = chunk_bytes[c_idx]
c_idx += 1
elif chunk_byte_type == 'text':
c.text = chunk_bytes[c_idx].decode()
c_idx += 1


def send_message(sock: 'zmq.Socket', msg: 'gnes_pb2.Message', timeout: int = -1,
squeeze_pb: bool = False, **kwargs) -> None:
try:
if timeout > 0:
sock.setsockopt(zmq.SNDTIMEO, timeout)
else:
sock.setsockopt(zmq.SNDTIMEO, -1)

sock.send_multipart([msg.envelope.client_id.encode(), msg.SerializeToString()])
if not squeeze_pb:
sock.send_multipart([msg.envelope.client_id.encode(), b'0', msg.SerializeToString()])
else:
doc_bytes, doc_byte_type, chunk_bytes, chunk_byte_type = extract_bytes_from_msg(msg)
# now raw_bytes are removed from message, hoping for faster de/serialization
sock.send_multipart(
[msg.envelope.client_id.encode(), # 0
b'1', msg.SerializeToString(), # 1, 2
doc_byte_type, chunk_byte_type, # 3, 4
b'%d' % len(doc_bytes), b'%d' % len(chunk_bytes), # 5, 6
*doc_bytes, *chunk_bytes]) # 7, 8
except zmq.error.Again:
raise TimeoutError(
'cannot send message to sock %s after timeout=%dms, please check the following:'
Expand All @@ -140,44 +282,26 @@ def send_message(sock: 'zmq.Socket', msg: 'gnes_pb2.Message', timeout: int = -1)
sock.setsockopt(zmq.SNDTIMEO, -1)


def recv_message(sock: 'zmq.Socket', timeout: int = -1, check_version: bool = False) -> Optional['gnes_pb2.Message']:
def recv_message(sock: 'zmq.Socket', timeout: int = -1, check_version: bool = False, **kwargs) -> Optional[
'gnes_pb2.Message']:
response = []
try:
if timeout > 0:
sock.setsockopt(zmq.RCVTIMEO, timeout)
else:
sock.setsockopt(zmq.RCVTIMEO, -1)

_, msg_data = sock.recv_multipart()
msg = gnes_pb2.Message()
msg.ParseFromString(msg_data)

if check_version and msg.envelope:
from .. import __version__, __proto_version__
if hasattr(msg.envelope, 'gnes_version'):
if not msg.envelope.gnes_version:
# only happen in unittest
default_logger.warning('incoming message contains empty "gnes_version", '
'you may ignore it in debug/unittest mode. '
'otherwise please check if frontend service set correct version')
elif __version__ != msg.envelope.gnes_version:
raise AttributeError('mismatched GNES version! '
'incoming message has GNES version %s, whereas local GNES version %s' % (
msg.envelope.gnes_version, __version__))
if hasattr(msg.envelope, 'proto_version'):
if not msg.envelope.proto_version:
# only happen in unittest
default_logger.warning('incoming message contains empty "proto_version", '
'you may ignore it in debug/unittest mode. '
'otherwise please check if frontend service set correct version')
elif __proto_version__ != msg.envelope.proto_version:
raise AttributeError('mismatched protobuf version! '
'incoming message has protobuf version %s, whereas local protobuf version %s' % (
msg.envelope.proto_version, __proto_version__))
if not hasattr(msg.envelope, 'proto_version') and not hasattr(msg.envelope, 'gnes_version'):
raise AttributeError('version_check=True locally, '
'but incoming message contains no version info in its envelope. '
'the message is probably sent from a very outdated GNES version')
msg_data = sock.recv_multipart()
squeeze_pb = (msg_data[1] == b'1')
msg.ParseFromString(msg_data[2])

if check_version:
check_msg_version(msg)

# now we have a barebone msg, we need to fill in data
if squeeze_pb:
fill_raw_bytes_to_msg(msg, msg_data)
return msg

except ValueError:
Expand Down
12 changes: 8 additions & 4 deletions gnes/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,18 +238,18 @@ def call_routes_send_back(self, msg: 'gnes_pb2.Message', out_sock):
if ret is None:
# assume 'msg' is modified inside fn()
self.call_hooks(msg, hook_type='post', verbose=self.service_context.args.verbose)
send_message(out_sock, msg, timeout=self.service_context.args.timeout)
send_message(out_sock, msg, **self.service_context.send_recv_kwargs)
elif isinstance(ret, types.GeneratorType):
for r_msg in ret:
self.call_hooks(msg, hook_type='post', verbose=self.service_context.args.verbose)
send_message(out_sock, r_msg, timeout=self.service_context.args.timeout)
send_message(out_sock, r_msg, **self.service_context.send_recv_kwargs)
else:
raise ServiceError('unknown return type from the handler')

except BlockMessage:
pass
except EventLoopEnd:
send_message(out_sock, msg, timeout=self.service_context.args.timeout)
send_message(out_sock, msg, **self.service_context.send_recv_kwargs)
raise EventLoopEnd
except ServiceError as ex:
self.logger.error(ex, exc_info=True)
Expand Down Expand Up @@ -308,6 +308,10 @@ def __init__(self, args):
self._model = None
self.use_event_loop = True
self.ctrl_addr = 'tcp://%s:%d' % (self.default_host, self.args.port_ctrl)
self.send_recv_kwargs = dict(
check_version=self.args.check_version,
timeout=self.args.timeout,
squeeze_pb=self.args.squeeze_pb)

def run(self):
try:
Expand Down Expand Up @@ -410,7 +414,7 @@ def _run(self, ctx):
self.is_handler_done.clear()

# receive message
msg = recv_message(pull_sock, check_version=self.args.check_version)
msg = recv_message(pull_sock, **self.send_recv_kwargs)

# choose output sock
if msg.request and msg.request.WhichOneof('body') and \
Expand Down
34 changes: 13 additions & 21 deletions gnes/service/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def __init__(self, args):
self.logger = set_logger(FrontendService.__name__, args.verbose)
self.zmq_context = self.ZmqContext(args)
self.request_id_cnt = 0
self.send_recv_kwargs = dict(
check_version=self.args.check_version,
timeout=self.args.timeout,
squeeze_pb=self.args.squeeze_pb)

def add_envelope(self, body: 'gnes_pb2.Request', zmq_client: 'ZmqClient'):
msg = gnes_pb2.Message()
Expand Down Expand Up @@ -88,8 +92,8 @@ def remove_envelope(self, m: 'gnes_pb2.Message'):

def Call(self, request, context):
with self.zmq_context as zmq_client:
zmq_client.send_message(self.add_envelope(request, zmq_client), self.args.timeout)
return self.remove_envelope(zmq_client.recv_message(self.args.timeout))
zmq_client.send_message(self.add_envelope(request, zmq_client), **self.send_recv_kwargs)
return self.remove_envelope(zmq_client.recv_message(**self.send_recv_kwargs))

def Train(self, request, context):
return self.Call(request, context)
Expand All @@ -102,31 +106,19 @@ def Search(self, request, context):

def StreamCall(self, request_iterator, context):
with self.zmq_context as zmq_client:
# network traffic control
num_request = 0
max_outstanding = 500

for request in request_iterator:
timeout = 25
if self.args.timeout > 0:
timeout = min(0.5 * self.args.timeout, 50)

while num_request > 10:
try:
msg = zmq_client.recv_message(timeout)
yield self.remove_envelope(msg)
num_request -= 1
except TimeoutError:
if num_request > max_outstanding:
self.logger.warning("the network traffic exceed max outstanding (%d > %d)" % (
num_request, max_outstanding))
continue
break
zmq_client.send_message(self.add_envelope(request, zmq_client), -1)
zmq_client.send_message(self.add_envelope(request, zmq_client), **self.send_recv_kwargs)
num_request += 1

if zmq_client.receiver.poll(1):
msg = zmq_client.recv_message(**self.send_recv_kwargs)
num_request -= 1
yield self.remove_envelope(msg)

for _ in range(num_request):
msg = zmq_client.recv_message(self.args.timeout)
msg = zmq_client.recv_message(**self.send_recv_kwargs)
yield self.remove_envelope(msg)

class ZmqContext:
Expand Down
Loading

0 comments on commit c954e25

Please sign in to comment.