From 7031fe20e9583720fb2f4b1b930f029e50f55303 Mon Sep 17 00:00:00 2001 From: Larry Yan Date: Wed, 7 Aug 2019 13:45:36 +0800 Subject: [PATCH] fix(preprocessor): add random sampling to ffmpeg --- gnes/preprocessor/video/ffmpeg.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/gnes/preprocessor/video/ffmpeg.py b/gnes/preprocessor/video/ffmpeg.py index 66b19e59..8288adfb 100644 --- a/gnes/preprocessor/video/ffmpeg.py +++ b/gnes/preprocessor/video/ffmpeg.py @@ -16,7 +16,7 @@ from typing import List import numpy as np - +import random from .base import BaseVideoPreprocessor from ..helper import get_video_frames, phash_descriptor from ...proto import gnes_pb2, array2blob @@ -107,18 +107,24 @@ def __init__(self, segment_method: str = 'cut_by_frame', segment_interval: int = -1, segment_num: int = 3, + max_frames_per_doc: int = -1, *args, **kwargs): super().__init__(*args, **kwargs) self.segment_method = segment_method self.segment_interval = segment_interval self.segment_num = segment_num + self.max_frames_per_doc = max_frames_per_doc self._ffmpeg_kwargs = kwargs def apply(self, doc: 'gnes_pb2.Document') -> None: super().apply(doc) if doc.raw_bytes: frames = get_video_frames(doc.raw_bytes, **self._ffmpeg_kwargs) + if self.max_frames_per_doc > 0: + random_id = random.sample(range(len(frames)), + k=min(self.max_frames_per_doc, len(frames))) + frames = [frames[i] for i in sorted(random_id)] sub_videos = [] if len(frames) >= 1: