diff --git a/gnes/client/stream.py b/gnes/client/stream.py index 4c7b423c..7adde956 100644 --- a/gnes/client/stream.py +++ b/gnes/client/stream.py @@ -44,27 +44,36 @@ class StreamingClient(GrpcClient): def __init__(self, args): super().__init__(args) - self._request_queue = queue.Queue() + self._request_queue = queue.Queue(maxsize=1000) self._is_streaming = threading.Event() self._dispatch_thread = threading.Thread(target=self._start) - self._dispatch_thread.setDaemon(1) - self._dispatch_thread.start() + self._dispatch_thread.setDaemon(True) def send_request(self, request): - self._request_queue.put(request) + self._request_queue.put(request, block=True) + + # create a new streaming call + if not self._is_streaming.is_set(): + self._dispatch_thread.start() def _start(self): self._is_streaming.set() - response_stream = self.stream_call(self._request_generator()) + self.stream_call(self._request_generator()) + self._is_streaming.clear() def _request_generator(self): while self._is_streaming.is_set(): try: - request = self._request_queue.get(block=True, timeout=1.0) + request = self._request_queue.get(block=True, timeout=5.0) + if request is None: + break yield request except queue.Empty: - pass + continue + except Exception as e: + print('exception: %s' % str(e)) + break @handler.register(NotImplementedError) def _handler_default(self, resp: 'gnes_pb2.Response'): diff --git a/gnes/preprocessor/video/shotdetect.py b/gnes/preprocessor/video/shotdetect.py index 8b8bda16..e076c8fd 100644 --- a/gnes/preprocessor/video/shotdetect.py +++ b/gnes/preprocessor/video/shotdetect.py @@ -31,7 +31,7 @@ def __init__(self, detect_method: str = 'threshold', frame_size: str = None, frame_rate: int = 10, - frame_num: int = -1, + vframes: int = -1, sframes: int = -1, drop_raw_data: bool = False, *args, @@ -42,7 +42,7 @@ def __init__(self, self.distance_metric = distance_metric self.detect_method = detect_method self.frame_rate = frame_rate - self.frame_num = frame_num + self.vframes = vframes self.sframes = sframes self.drop_raw_data = drop_raw_data self._detector_kwargs = kwargs @@ -83,11 +83,11 @@ def apply(self, doc: 'gnes_pb2.Document') -> None: input_data=doc.raw_bytes, scale=self.frame_size, fps=self.frame_rate, - vframes=self.frame_num) + vframes=self.vframes) elif raw_type == gnes_pb2.NdArray: video_frames = blob2array(doc.raw_video) - if self.frame_num > 0: - video_frames = video_frames[0:self.frame_num, :] + if self.vframes > 0: + video_frames = video_frames[0:self.vframes, :] num_frames = len(video_frames) if num_frames > 0: @@ -99,9 +99,12 @@ def apply(self, doc: 'gnes_pb2.Document') -> None: shot_len = len(frames) c.weight = shot_len / num_frames if self.sframes > 0 and shot_len > self.sframes: - start_id = int((shot_len - self.sframes) / 2) - end_id = start_id + self.sframes - frames = frames[start_id:end_id] + begin = 0 + if self.sframes < 3: + begin = (shot_len - self.sframes) // 2 + step = (shot_len) // self.sframes + frames = [frames[_] for _ in range(begin, shot_len, step)] + chunk_data = np.array(frames) c.blob.CopyFrom(array2blob(chunk_data)) else: diff --git a/gnes/service/base.py b/gnes/service/base.py index 18333445..c440951b 100644 --- a/gnes/service/base.py +++ b/gnes/service/base.py @@ -129,6 +129,12 @@ def build_socket(ctx: 'zmq.Context', host: str, port: int, socket_type: 'SocketT sock.setsockopt(zmq.SUBSCRIBE, identity.encode('ascii') if identity else b'') # sock.setsockopt(zmq.SUBSCRIBE, b'') + sock.setsockopt(zmq.RCVHWM, 100) + sock.setsockopt(zmq.RCVBUF, 512 * 1024 * 1024) # network buffer 512M + + sock.setsockopt(zmq.SNDHWM, 100) + sock.setsockopt(zmq.SNDBUF, 512 * 1024 * 1024) + return sock, sock.getsockopt_string(zmq.LAST_ENDPOINT) @@ -423,6 +429,8 @@ def _run(self, ctx): self.logger.info('break from the event loop') except ComponentNotLoad: self.logger.error('component can not be correctly loaded, terminated') + except Exception as e: + self.logger.error("exception occured: %s" % str(e), exc_info=True) finally: self.is_ready.set() self.is_event_loop.clear() diff --git a/gnes/service/frontend.py b/gnes/service/frontend.py index 3f861726..d9316031 100644 --- a/gnes/service/frontend.py +++ b/gnes/service/frontend.py @@ -138,22 +138,25 @@ def Search(self, request, context): def StreamCall(self, request_iterator, context): with self.zmq_context as zmq_client: num_request = 0 + # network traffic control + max_outstanding = 1000 for request in request_iterator: - zmq_client.send_message(self.add_envelope(request, zmq_client), -1) - num_request += 1 + timeout = 25 + if self.args.timeout > 0: + timeout = min(0.5 * self.args.timeout, 50) while num_request > 10: try: - # fetch response in real time to reduce network overload - timeout = 50 - if self.args.timeout > 0: - timeout = min(0.5 * self.args.timeout, 100) - msg = zmq_client.recv_message(timeout) yield self.remove_envelope(msg) num_request -= 1 except TimeoutError: + if num_request > max_outstanding: + self.logger.warning("the network traffic exceed max outstanding (%d > %d)" % (num_request, max_outstanding)) + continue break + zmq_client.send_message(self.add_envelope(request, zmq_client), -1) + num_request += 1 for _ in range(num_request): msg = zmq_client.recv_message(self.args.timeout)