Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
refactor(shot-detector): merge code from hub
Browse files Browse the repository at this point in the history
  • Loading branch information
felix committed Sep 16, 2019
1 parent 981085a commit de5b336
Showing 1 changed file with 42 additions and 26 deletions.
68 changes: 42 additions & 26 deletions gnes/preprocessor/video/shotdetect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import numpy as np
from typing import List

from ..base import BaseVideoPreprocessor
from ..helper import compute_descriptor, compare_descriptor, detect_peak_boundary, compare_ecr
from ..io_utils import video as video_util
from ...proto import gnes_pb2, array2blob
from gnes.preprocessor.base import BaseVideoPreprocessor
from gnes.proto import gnes_pb2, array2blob, blob2array
from gnes.preprocessor.io_utils import video
from gnes.preprocessor.helper import compute_descriptor, compare_descriptor, detect_peak_boundary, compare_ecr


class ShotDetectPreprocessor(BaseVideoPreprocessor):
store_args_kwargs = True

def __init__(self,
frame_size: str = '192:168',
scale: str = None,
descriptor: str = 'block_hsv_histogram',
distance_metric: str = 'bhattacharya',
detect_method: str = 'threshold',
frame_rate: int = 10,
frame_num: int = -1,
drop_raw_data: bool = False,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.frame_size = frame_size
self.scale = scale
self.descriptor = descriptor
self.distance_metric = distance_metric
self.detect_method = detect_method
self.frame_rate = frame_rate
self.frame_num = frame_num
self.drop_raw_data = drop_raw_data
self._detector_kwargs = kwargs

def detect_shots(self, frames: 'np.ndarray') -> List[List['np.ndarray']]:
Expand Down Expand Up @@ -71,23 +72,38 @@ def detect_shots(self, frames: 'np.ndarray') -> List[List['np.ndarray']]:
def apply(self, doc: 'gnes_pb2.Document') -> None:
super().apply(doc)

if doc.raw_bytes:
all_frames = video_util.capture_frames(
input_data=doc.raw_bytes,
scale=self.frame_size,
fps=self.frame_rate,
vframes=self.frame_num)
num_frames = len(all_frames)
assert num_frames > 0
shots = self.detect_shots(all_frames)
video_frames = []

for ci, frames in enumerate(shots):
c = doc.chunks.add()
c.doc_id = doc.doc_id
# chunk_data = np.concatenate(frames, axis=0)
chunk_data = np.array(frames)
c.blob.CopyFrom(array2blob(chunk_data))
c.offset = ci
c.weight = len(frames) / num_frames
if doc.WhichOneof('raw_data'):
raw_type = type(getattr(doc, doc.WhichOneof('raw_data')))
if doc.raw_bytes:
video_frames = video.capture_frames(
input_data=doc.raw_bytes,
scale=self.scale,
fps=self.frame_rate,
vframes=self.frame_num)
elif raw_type == gnes_pb2.NdArray:
video_frames = blob2array(doc.raw_video)
if self.frame_num > 0:
stepwise = len(video_frames) / self.frame_num
video_frames = video_frames[0::stepwise, :]

num_frames = len(video_frames)
if num_frames > 0:
shots = self.detect_shots(video_frames)
for ci, frames in enumerate(shots):
c = doc.chunks.add()
c.doc_id = doc.doc_id
chunk_data = np.array(frames)
c.blob.CopyFrom(array2blob(chunk_data))
c.offset = ci
c.weight = len(frames) / num_frames
else:
self.logger.error(
'bad document: "raw_bytes" or "raw_video" is empty!')
else:
self.logger.error('bad document: "raw_bytes" is empty!')
self.logger.error('bad document: "raw_data" is empty!')

if self.drop_raw_data:
self.logger.info("document raw data will be cleaned!")
doc.ClearField('raw_data')

0 comments on commit de5b336

Please sign in to comment.