Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
refactor(shotdetector): use updated ffmpeg api to capture frames from videos
Browse files Browse the repository at this point in the history
  • Loading branch information
felix committed Aug 23, 2019
1 parent e2fa073 commit 4497d76
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions gnes/preprocessor/video/shotdetect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
from typing import List

from ..base import BaseVideoPreprocessor
from ..helper import get_video_frames, compute_descriptor, compare_descriptor, detect_peak_boundary, compare_ecr
from ...proto import gnes_pb2, array2blob
from ..io_utils import video as video_util
from ..helper import compute_descriptor, compare_descriptor, detect_peak_boundary, compare_ecr


class ShotDetectPreprocessor(BaseVideoPreprocessor):
store_args_kwargs = True

def __init__(self,
frame_size: str = '192*168',
frame_size: str = '192:168',
descriptor: str = 'block_hsv_histogram',
distance_metric: str = 'bhattacharya',
detect_method: str = 'threshold',
frame_rate: str = '10',
frame_rate: int = 10,
*args,
**kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -41,13 +41,8 @@ def __init__(self,
self.frame_rate = frame_rate
self._detector_kwargs = kwargs

def detect_from_bytes(self, raw_bytes: bytes) -> (List[List['np.ndarray']], int):
frames = get_video_frames(
raw_bytes,
s=self.frame_size,
vsync='vfr',
r=self.frame_rate)

def detect_shots(self,
frames: List['np.ndarray']) -> List[List['np.ndarray']]:
descriptors = []
for frame in frames:
descriptor = compute_descriptor(
Expand All @@ -63,25 +58,30 @@ def detect_from_bytes(self, raw_bytes: bytes) -> (List[List['np.ndarray']], int)
for pair in zip(descriptors[:-1], descriptors[1:])
]

shots = detect_peak_boundary(dists, self.detect_method)
shot_bounds = detect_peak_boundary(dists, self.detect_method)

shot_frames = []
for ci in range(0, len(shots) - 1):
shot_frames.append(frames[shots[ci]:shots[ci+1]])
shots = []
for ci in range(0, len(shot_bounds) - 1):
shots.append(frames[shot_bounds[ci]:shot_bounds[ci + 1]])

return shot_frames, len(frames)
return shots

def apply(self, doc: 'gnes_pb2.Document') -> None:
super().apply(doc)
#from sklearn.cluster import KMeans

if doc.raw_bytes:
shot_frames, num_frames = self.detect_from_bytes(doc.raw_bytes)
all_frames = video_util.capture_frames(
video_data=doc.raw_bytes,
scale=self.frame_size,
fps=self.frame_rate)
num_frames = len(all_frames)
assert num_frames > 0
shots = self.detect_shots(all_frames)

for ci, value in enumerate(shot_frames):
for ci, frames in enumerate(shots):
c = doc.chunks.add()
c.doc_id = doc.doc_id
chunk = np.array(value).astype('uint8')
chunk_data = np.concatenate(frames, axis=0)
c.blob.CopyFrom(array2blob(chunk))
c.offset_1d = ci
c.weight = len(value) / num_frames
Expand Down

0 comments on commit 4497d76

Please sign in to comment.