From de5b336888fa1f6a90fe63c854f44987c3bbb977 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 16 Sep 2019 10:45:42 +0800 Subject: [PATCH] refactor(shot-detector): merge code from hub --- gnes/preprocessor/video/shotdetect.py | 68 +++++++++++++++++---------- 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/gnes/preprocessor/video/shotdetect.py b/gnes/preprocessor/video/shotdetect.py index f734afa1..9177b938 100644 --- a/gnes/preprocessor/video/shotdetect.py +++ b/gnes/preprocessor/video/shotdetect.py @@ -13,35 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - import numpy as np +from typing import List -from ..base import BaseVideoPreprocessor -from ..helper import compute_descriptor, compare_descriptor, detect_peak_boundary, compare_ecr -from ..io_utils import video as video_util -from ...proto import gnes_pb2, array2blob +from gnes.preprocessor.base import BaseVideoPreprocessor +from gnes.proto import gnes_pb2, array2blob, blob2array +from gnes.preprocessor.io_utils import video +from gnes.preprocessor.helper import compute_descriptor, compare_descriptor, detect_peak_boundary, compare_ecr class ShotDetectPreprocessor(BaseVideoPreprocessor): store_args_kwargs = True def __init__(self, - frame_size: str = '192:168', + scale: str = None, descriptor: str = 'block_hsv_histogram', distance_metric: str = 'bhattacharya', detect_method: str = 'threshold', frame_rate: int = 10, frame_num: int = -1, + drop_raw_data: bool = False, *args, **kwargs): super().__init__(*args, **kwargs) - self.frame_size = frame_size + self.scale = scale self.descriptor = descriptor self.distance_metric = distance_metric self.detect_method = detect_method self.frame_rate = frame_rate self.frame_num = frame_num + self.drop_raw_data = drop_raw_data self._detector_kwargs = kwargs def detect_shots(self, frames: 'np.ndarray') -> List[List['np.ndarray']]: @@ -71,23 +72,38 @@ def detect_shots(self, frames: 'np.ndarray') -> List[List['np.ndarray']]: def apply(self, doc: 'gnes_pb2.Document') -> None: super().apply(doc) - if doc.raw_bytes: - all_frames = video_util.capture_frames( - input_data=doc.raw_bytes, - scale=self.frame_size, - fps=self.frame_rate, - vframes=self.frame_num) - num_frames = len(all_frames) - assert num_frames > 0 - shots = self.detect_shots(all_frames) + video_frames = [] - for ci, frames in enumerate(shots): - c = doc.chunks.add() - c.doc_id = doc.doc_id - # chunk_data = np.concatenate(frames, axis=0) - chunk_data = np.array(frames) - c.blob.CopyFrom(array2blob(chunk_data)) - c.offset = ci - c.weight = len(frames) / num_frames + if doc.WhichOneof('raw_data'): + raw_type = type(getattr(doc, doc.WhichOneof('raw_data'))) + if doc.raw_bytes: + video_frames = video.capture_frames( + input_data=doc.raw_bytes, + scale=self.scale, + fps=self.frame_rate, + vframes=self.frame_num) + elif raw_type == gnes_pb2.NdArray: + video_frames = blob2array(doc.raw_video) + if self.frame_num > 0: + stepwise = len(video_frames) / self.frame_num + video_frames = video_frames[0::stepwise, :] + + num_frames = len(video_frames) + if num_frames > 0: + shots = self.detect_shots(video_frames) + for ci, frames in enumerate(shots): + c = doc.chunks.add() + c.doc_id = doc.doc_id + chunk_data = np.array(frames) + c.blob.CopyFrom(array2blob(chunk_data)) + c.offset = ci + c.weight = len(frames) / num_frames + else: + self.logger.error( + 'bad document: "raw_bytes" or "raw_video" is empty!') else: - self.logger.error('bad document: "raw_bytes" is empty!') + self.logger.error('bad document: "raw_data" is empty!') + + if self.drop_raw_data: + self.logger.info("document raw data will be cleaned!") + doc.ClearField('raw_data')