diff --git a/gnes/preprocessor/video/shotdetect.py b/gnes/preprocessor/video/shotdetect.py index 2563d254..7d1bbfd4 100644 --- a/gnes/preprocessor/video/shotdetect.py +++ b/gnes/preprocessor/video/shotdetect.py @@ -13,24 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. - import numpy as np from typing import List from ..base import BaseVideoPreprocessor -from ..helper import get_video_frames, compute_descriptor, compare_descriptor, detect_peak_boundary, compare_ecr from ...proto import gnes_pb2, array2blob +from ..io_utils import video as video_util +from ..helper import compute_descriptor, compare_descriptor, detect_peak_boundary, compare_ecr class ShotDetectPreprocessor(BaseVideoPreprocessor): store_args_kwargs = True def __init__(self, - frame_size: str = '192*168', + frame_size: str = '192:168', descriptor: str = 'block_hsv_histogram', distance_metric: str = 'bhattacharya', detect_method: str = 'threshold', - frame_rate: str = '10', + frame_rate: int = 10, *args, **kwargs): super().__init__(*args, **kwargs) @@ -41,13 +41,8 @@ def __init__(self, self.frame_rate = frame_rate self._detector_kwargs = kwargs - def detect_from_bytes(self, raw_bytes: bytes) -> (List[List['np.ndarray']], int): - frames = get_video_frames( - raw_bytes, - s=self.frame_size, - vsync='vfr', - r=self.frame_rate) - + def detect_shots(self, + frames: List['np.ndarray']) -> List[List['np.ndarray']]: descriptors = [] for frame in frames: descriptor = compute_descriptor( @@ -63,25 +58,30 @@ def detect_from_bytes(self, raw_bytes: bytes) -> (List[List['np.ndarray']], int) for pair in zip(descriptors[:-1], descriptors[1:]) ] - shots = detect_peak_boundary(dists, self.detect_method) + shot_bounds = detect_peak_boundary(dists, self.detect_method) - shot_frames = [] - for ci in range(0, len(shots) - 1): - shot_frames.append(frames[shots[ci]:shots[ci+1]]) + shots = [] + for ci in range(0, len(shot_bounds) - 1): + shots.append(frames[shot_bounds[ci]:shot_bounds[ci + 1]]) - return shot_frames, len(frames) + return shots def apply(self, doc: 'gnes_pb2.Document') -> None: super().apply(doc) - #from sklearn.cluster import KMeans if doc.raw_bytes: - shot_frames, num_frames = self.detect_from_bytes(doc.raw_bytes) + all_frames = video_util.capture_frames( + video_data=doc.raw_bytes, + scale=self.frame_size, + fps=self.frame_rate) + num_frames = len(all_frames) + assert num_frames > 0 + shots = self.detect_shots(all_frames) - for ci, value in enumerate(shot_frames): + for ci, frames in enumerate(shots): c = doc.chunks.add() c.doc_id = doc.doc_id - chunk = np.array(value).astype('uint8') + chunk_data = np.concatenate(frames, axis=0) c.blob.CopyFrom(array2blob(chunk)) c.offset_1d = ci c.weight = len(value) / num_frames