diff --git a/gnes/client/base.py b/gnes/client/base.py index 5a09e7e9..4826e714 100644 --- a/gnes/client/base.py +++ b/gnes/client/base.py @@ -17,6 +17,7 @@ import grpc import zmq from termcolor import colored +from typing import Tuple, List, Union, Type from ..helper import set_logger from ..proto import gnes_pb2_grpc @@ -24,6 +25,43 @@ from ..service.base import build_socket +class ResponseHandler: + def __init__(self, h: 'ResponseHandler' = None): + self.routes = {k: v for k, v in h.routes.items()} if h else {} + self.logger = set_logger(self.__class__.__name__) + self._context = None + + def register(self, resp_type: Union[List, Tuple, type]): + def decorator(f): + if isinstance(resp_type, list) or isinstance(resp_type, tuple): + for t in resp_type: + self.routes[t] = f + else: + self.routes[resp_type] = f + return f + + return decorator + + def call_routes(self, resp: 'gnes_pb2.Response'): + def get_default_fn(r_type): + self.logger.warning('cant find handler for response type: %s, fall back to the default handler' % r_type) + f = self.routes.get(r_type, self.routes[NotImplementedError]) + return f + + self.logger.info('received a response for request %d' % resp.request_id) + if resp.WhichOneof('body'): + body = getattr(resp, resp.WhichOneof('body')) + resp_type = type(body) + + if resp_type in self.routes: + fn = self.routes.get(resp_type) + else: + fn = get_default_fn(type(resp)) + + self.logger.info('handling response with %s' % fn.__name__) + return fn(self._context, resp) + + class ZmqClient: def __init__(self, args): @@ -63,11 +101,18 @@ def recv_message(self, timeout: int = -1) -> gnes_pb2.Message: class GrpcClient: + """ + A Base Unary gRPC client which the other client application can build from. + + """ + + handler = ResponseHandler() def __init__(self, args): self.args = args self.logger = set_logger(self.__class__.__name__, self.args.verbose) - self.logger.info('setting up channel...') + self.logger.info('setting up grpc insecure channel...') + # A gRPC channel provides a connection to a remote gRPC server. self._channel = grpc.insecure_channel( '%s:%d' % (self.args.grpc_host, self.args.grpc_port), options={ @@ -77,19 +122,44 @@ def __init__(self, args): ) self.logger.info('waiting channel to be ready...') grpc.channel_ready_future(self._channel).result() - self.logger.info('making stub...') + self.logger.critical('gnes client ready!') + + # create new stub + self.logger.info('create new stub...') self._stub = gnes_pb2_grpc.GnesRPCStub(self._channel) - self.logger.critical('ready!') - def send_request(self, request): + # attache response handler + self.handler._context = self + + def call(self, request): + resp = self._stub.call(request) + self.handler.call_routes(resp) + return resp + + def stream_call(self, request_iterator): + response_stream = self._stub.StreamCall(request_iterator) + for resp in response_stream: + self.handler.call_routes(resp) + + @handler.register(NotImplementedError) + def _handler_default(self, msg: 'gnes_pb2.Response'): raise NotImplementedError - def close(self): - self._channel.close() - self._stub = None + @handler.register(gnes_pb2.Response) + def _handler_response_default(self, msg: 'gnes_pb2.Response'): + pass def __enter__(self): + self.open() return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() + + def open(self): + pass + + def close(self): + self._channel.close() + self._stub = None + self.total_response = 0 \ No newline at end of file diff --git a/gnes/client/stream.py b/gnes/client/stream.py index 7630b3a8..75cfb618 100644 --- a/gnes/client/stream.py +++ b/gnes/client/stream.py @@ -13,91 +13,63 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time +import threading import queue from concurrent import futures -from .base import GrpcClient +from .base import GrpcClient, ResponseHandler -class _SyncStream: - - def __init__(self, stub, handle_response): - self._stub = stub - self._handle_response = handle_response - self._is_streaming = False - self._request_queue = queue.Queue() - - def send_request(self, request): - self._request_queue.put(request) - - def start(self): - self._is_streaming = True - response_stream = self._stub.StreamCall(self._request_generator()) - for resp in response_stream: - self._handle_response(self, resp) - - def stop(self): - self._is_streaming = False - - def _request_generator(self): - while self._is_streaming: - try: - request = self._request_queue.get(block=True, timeout=1.0) - yield request - except queue.Empty: - pass - - -class UnarySyncClient(GrpcClient): +class SyncClient(GrpcClient): + handler = ResponseHandler(GrpcClient.handler) def __init__(self, args): super().__init__(args) self._pool = futures.ThreadPoolExecutor( max_workers=self.args.max_concurrency) - self._response_callbacks = [] def send_request(self, request): # Send requests in seperate threads to support multiple outstanding rpcs - self._pool.submit(self._dispatch_request, request) + self._pool.submit(self.call, request) def close(self): self._pool.shutdown(wait=True) super().close() - def _dispatch_request(self, request): - resp = self._stub.Call(request) - self._handle_response(self, resp) - - def _handle_response(self, client, response): - for callback in self._response_callbacks: - callback(client, response) - - def add_response_callback(self, callback): - """callback will be invoked as callback(client, response)""" - self._response_callbacks.append(callback) - - -class StreamingClient(UnarySyncClient): +class StreamingClient(GrpcClient): + handler = ResponseHandler(GrpcClient.handler) def __init__(self, args): super().__init__(args) - self._streams = [ - _SyncStream(self._stub, self._handle_response) - for _ in range(self.args.max_concurrency) - ] - self._curr_stream = 0 + self._request_queue = queue.Queue() + self._is_streaming = threading.Event() + + self._dispatch_thread = threading.Thread(target=self._start) + self._dispatch_thread.setDaemon(1) + self._dispatch_thread.start() def send_request(self, request): - # Use a round_robin scheduler to determine what stream to send on - self._streams[self._curr_stream].send_request(request) - self._curr_stream = (self._curr_stream + 1) % len(self._streams) + self._request_queue.put(request) + + def _start(self): + self._is_streaming.set() + response_stream = self.stream_call(self._request_generator()) + + def _request_generator(self): + while self._is_streaming.is_set(): + try: + request = self._request_queue.get(block=True, timeout=1.0) + yield request + except queue.Empty: + pass - def start(self): - for stream in self._streams: - self._pool.submit(stream.start) + @handler.register(NotImplementedError) + def _handler_default(self, resp: 'gnes_pb2.Response'): + raise NotImplementedError def close(self): - for stream in self._streams: - stream.stop() + self._is_streaming.clear() + self._dispatch_thread.join() super().close()