Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge pull request #6 from gnes-ai/video-shot-detect
Browse files Browse the repository at this point in the history
feat(preprocessor): add video shot boundary detector
  • Loading branch information
numb3r3 authored Jul 12, 2019
2 parents d1680e8 + 38fff78 commit 72a8bd9
Show file tree
Hide file tree
Showing 9 changed files with 355 additions and 74 deletions.
3 changes: 3 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ Copyright (c) 2014-2019 Anthon van der Neut, Ruamel bvba
6. jieba 0.39
Copyright (c) 2013 Sun Junyi

7. opencv-python 4.0.0
Copyright (c) 2016-2018 Olli-Pekka Heinisuo and contributors




Expand Down
2 changes: 1 addition & 1 deletion gnes/preprocessor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
'BaseSingletonPreprocessor': 'base',
'BaseVideoPreprocessor': 'video.base',
'FFmpegPreprocessor': 'video.ffmpeg',

'ShotDetectPreprocessor': 'video.shotdetect',
}

register_all_class(_cls2file_map)
159 changes: 159 additions & 0 deletions gnes/preprocessor/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=low-comment-ratio

import io
import subprocess as sp
from typing import List, Callable

import cv2
import numpy as np
from PIL import Image
import imagehash


def get_video_frames(buffer_data: bytes, image_format: str = "cv2",
**kwargs) -> List["np.ndarray"]:
ffmpeg_cmd = ['ffmpeg', '-i', '-', '-f', 'image2pipe']

# example k,v pair:
# (-s, 420*360)
# (-vsync, vfr)
# (-vf, select=eq(pict_type\,I))
for k, v in kwargs.items():
ffmpeg_cmd.append('-' + k)
ffmpeg_cmd.append(v)

# (-c:v, png) output bytes in png format
# (-an, -sn) disable audio processing
# (-) output to stdout pipeline
ffmpeg_cmd += ['-c:v', 'png', '-an', '-sn', '-']

with sp.Popen(
ffmpeg_cmd, stdin=sp.PIPE, stdout=sp.PIPE, bufsize=-1,
shell=False) as pipe:
stream, _ = pipe.communicate(buffer_data)

# raw bytes for multiple PNGs.
# split by PNG EOF b'\x89PNG'
stream = stream.split(b'\x89PNG')

if len(stream) <= 1:
return []

# reformulate the full pngs for feature processings.
if image_format == 'pil':
frames = [
Image.open(io.BytesIO(b'\x89PNG' + _)) for _ in stream[1:]
]
elif image_format == 'cv2':
frames = [
cv2.imdecode(np.frombuffer(b'\x89PNG' + _, np.uint8), 1)
for _ in stream[1:]
]
else:
raise NotImplementedError

return frames


def block_descriptor(image: "np.ndarray",
descriptor_fn: Callable,
num_blocks: int = 3) -> "np.ndarray":
h, w, _ = image.shape # find shape of image and channel
block_h = int(np.ceil(h / num_blocks))
block_w = int(np.ceil(w / num_blocks))

descriptors = []
for i in range(0, h, block_h):
for j in range(0, w, block_w):
block = image[i:i + block_h, j:j + block_w]
descriptors.extend(descriptor_fn(block))

return np.array(descriptors)


def pyramid_descriptor(image: "np.ndarray",
descriptor_fn: Callable,
max_level: int = 2) -> "np.ndarray":
descriptors = []
for level in range(max_level + 1):
num_blocks = 2**level
descriptors.extend(block_descriptor(image, descriptor_fn, num_blocks))

return np.array(descriptors)


def rgb_histogram(image: "np.ndarray") -> "np.ndarray":
_, _, c = image.shape
hist = [
cv2.calcHist([image], [i], None, [256], [0, 256]) for i in range(c)
]
# normalize hist
hist = np.array([h / np.sum(h) for h in hist]).flatten()
return hist


def hsv_histogram(image: "np.ndarray") -> "np.ndarray":
_, _, c = image.shape
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

# sizes = [180, 256, 256]
# ranges = [(0, 180), (0, 256), (0, 256)]

# hist = [
# cv2.calcHist([hsv], [i], None, [sizes[i]], ranges[i]) for i in range(c)
# ]

hist = [cv2.calcHist([hsv], [i], None, [256], [0, 256]) for i in range(c)]
# normalize hist
hist = np.array([h / np.sum(h) for h in hist]).flatten()
return hist


def phash_descriptor(image: "np.ndarray") -> "imagehash.ImageHash":
image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
return imagehash.phash(image)


def compute_descriptor(image: "np.ndarray",
method: str = "rgb_histogram",
**kwargs) -> "np.array":
funcs = {
'rgb_histogram': rgb_histogram,
'hsv_histogram': hsv_histogram,
'block_rgb_histogram': lambda image: block_descriptor(image, rgb_histogram, kwargs.get("num_blocks", 3)),
'block_hsv_histogram': lambda image: block_descriptor(image, hsv_histogram, kwargs.get("num_blocks", 3)),
'pyramid_rgb_histogram': lambda image: pyramid_descriptor(image, rgb_histogram, kwargs.get("max_level", 2)),
'pyramid_hsv_histogram': lambda image: pyramid_descriptor(image, hsv_histogram, kwargs.get("max_level", 2)),
}
return funcs[method](image)


def compare_descriptor(descriptor1: "np.ndarray",
descriptor2: "np.ndarray",
metric: str = "chisqr") -> float:
dist_metric = {
"correlation": cv2.HISTCMP_CORREL,
"chisqr": cv2.HISTCMP_CHISQR,
"chisqr_alt": cv2.HISTCMP_CHISQR_ALT,
"intersection": cv2.HISTCMP_INTERSECT,
"bhattacharya": cv2.HISTCMP_BHATTACHARYYA,
"hellinguer": cv2.HISTCMP_HELLINGER,
"kl_div": cv2.HISTCMP_KL_DIV
}

return cv2.compareHist(descriptor1, descriptor2, dist_metric[metric])
112 changes: 43 additions & 69 deletions gnes/preprocessor/video/ffmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,109 +12,83 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import subprocess as sp
from typing import List

from typing import List
import numpy as np
from PIL import Image

from .base import BaseVideoPreprocessor
from ...proto import gnes_pb2, array2blob
from ..helper import get_video_frames, phash_descriptor


class FFmpegPreprocessor(BaseVideoPreprocessor):

def __init__(self,
frame_size: str = "192*168",
duplicate_rm: bool = True,
use_phash_weight: bool = False,
phash_thresh: int = 5,
*args, **kwargs):
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.frame_size = frame_size
self.phash_thresh = phash_thresh
self.duplicate_rm = duplicate_rm
self.use_phash_weight = use_phash_weight
# (-i, -) input from stdin pipeline
# (-f, image2pipe) output format is image pipeline
self.cmd = ['ffmpeg',
'-i', '-',
'-f', 'image2pipe']

# example k,v pair:
# (-s, 420*360)
# (-vsync, vfr)
# (-vf, select=eq(pict_type\,I))
for k, v in kwargs.items():
self.cmd.append('-' + k)
self.cmd.append(v)

# (-c:v, png) output bytes in png format
# (-) output to stdout pipeline
self.cmd += ['-c:v', 'png', '-']

self._ffmpeg_kwargs = kwargs

def apply(self, doc: 'gnes_pb2.Document') -> None:
super().apply(doc)

# video could't be processed from ndarray!
# only bytes can be passed into ffmpeg pipeline
if doc.raw_bytes:
pipe = sp.Popen(self.cmd, stdin=sp.PIPE, stdout=sp.PIPE, bufsize=-1)
stream, _ = pipe.communicate(doc.raw_bytes)

# raw bytes for multiple PNGs.
# split by PNG EOF b'\x89PNG'
stream = stream.split(b'\x89PNG')
if len(stream) <= 1:
self.logger.info('no image extracted from video!')
frames = get_video_frames(
doc.raw_bytes,
s=self.frame_size,
vsync=self._ffmpeg_kwargs.get("vsync", "vfr"),
vf=self._ffmpeg_kwargs.get("vf", "select=eq(pict_type\\,I)"))

# remove dupliated key frames by phash value
if self.duplicate_rm:
frames = self.duplicate_rm_hash(frames)

if self.use_phash_weight:
weight = FFmpegPreprocessor.pic_weight(frames)
else:
# reformulate the full pngs for feature processings.
stream = [b'\x89PNG' + _ for _ in stream[1:]]

# remove dupliated key frames by phash value
if self.duplicate_rm:
stream = self.duplicate_rm_hash(stream)

stream = [np.array(Image.open(io.BytesIO(chunk)), dtype=np.uint8)
for chunk in stream]

if self.use_phash_weight:
weight = FFmpegPreprocessor.pic_weight(stream)
else:
weight = [1 / len(stream)] * len(stream)

for ci, chunk in enumerate(stream):
c = doc.chunks.add()
c.doc_id = doc.doc_id
c.blob.CopyFrom(array2blob(chunk))
c.offset_1d = ci
c.weight = weight[ci]
# close the stdout stream
pipe.stdout.close()
weight = [1 / len(frames)] * len(frames)

for ci, chunk in enumerate(frames):
c = doc.chunks.add()
c.doc_id = doc.doc_id
c.blob.CopyFrom(array2blob(chunk))
c.offset_1d = ci
c.weight = weight[ci]

else:
self.logger.error('bad document: "raw_bytes" is empty!')

@staticmethod
def phash(image_bytes: bytes):
import imagehash
return imagehash.phash(Image.open(io.BytesIO(image_bytes)))

@staticmethod
def pic_weight(image_array: List[np.ndarray]) -> List[float]:
weight = np.zeros([len(image_array)])
def pic_weight(images: List['np.ndarray']) -> List[float]:
import cv2
weight = np.zeros([len(images)])
# n_channel is usually 3 for RGB images
n_channel = image_array[0].shape[-1]
for i in range(len(image_array)):
# calcualte the variance of histgram of pixels
weight[i] = sum([np.histogram(image_array[i][:, :, _])[0].var()
for _ in range(n_channel)])
n_channel = images[0].shape[-1]
for i, image in enumerate(images):
weight[i] = sum([
cv2.calcHist([image], [_], None, [256], [0, 256]).var()
for _ in range(n_channel)
])
weight = weight / weight.sum()

# normalized result
weight = np.exp(- weight * 10)
weight = np.exp(-weight * 10)
return weight / weight.sum()

def duplicate_rm_hash(self, image_list: List[bytes]) -> List[bytes]:
hash_list = [FFmpegPreprocessor.phash(_) for _ in image_list]
def duplicate_rm_hash(self,
images: List['np.ndarray']) -> List['np.ndarray']:
hash_list = [phash_descriptor(_) for _ in images]
ret = []
for i, h in enumerate(hash_list):
flag = 1
Expand All @@ -129,4 +103,4 @@ def duplicate_rm_hash(self, image_list: List[bytes]) -> List[bytes]:
if flag:
ret.append((i, h))

return [image_list[_[0]] for _ in ret]
return [images[_[0]] for _ in ret]
Loading

0 comments on commit 72a8bd9

Please sign in to comment.