diff --git a/LICENSE b/LICENSE
index f0c4b1df..50b561b5 100644
--- a/LICENSE
+++ b/LICENSE
@@ -98,6 +98,9 @@ Copyright (c) 2014-2019 Anthon van der Neut, Ruamel bvba
 6. jieba 0.39
 Copyright (c) 2013 Sun Junyi
 
+7. opencv-python 4.0.0
+Copyright (c) 2016-2018 Olli-Pekka Heinisuo and contributors
+
 
 
 
diff --git a/gnes/preprocessor/__init__.py b/gnes/preprocessor/__init__.py
index 01734028..7dafc751 100644
--- a/gnes/preprocessor/__init__.py
+++ b/gnes/preprocessor/__init__.py
@@ -29,7 +29,7 @@
     'BaseSingletonPreprocessor': 'base',
     'BaseVideoPreprocessor': 'video.base',
     'FFmpegPreprocessor': 'video.ffmpeg',
-
+    'ShotDetectPreprocessor': 'video.shotdetect',
 }
 
 register_all_class(_cls2file_map)
diff --git a/gnes/preprocessor/helper.py b/gnes/preprocessor/helper.py
new file mode 100644
index 00000000..aa0a747d
--- /dev/null
+++ b/gnes/preprocessor/helper.py
@@ -0,0 +1,234 @@
+#  Tencent is pleased to support the open source community by making GNES available.
+#
+#  Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+# pylint: disable=low-comment-ratio
+
+import io
+import subprocess as sp
+from typing import List, Callable
+
+import cv2
+import numpy as np
+from PIL import Image
+import imagehash
+
+
+def get_video_frames(buffer_data: bytes, image_format: str = "cv2",
+                     **kwargs) -> List["np.ndarray"]:
+    """Decode an in-memory video into a list of frames through ffmpeg.
+
+    The video bytes are piped to ``ffmpeg`` on stdin; ffmpeg writes the
+    selected frames to stdout as concatenated PNG files, which are then
+    cut apart and decoded.
+
+    :param buffer_data: raw bytes of the video file
+    :param image_format: ``"cv2"`` decodes frames with ``cv2.imdecode``,
+        ``"pil"`` opens them as ``PIL.Image`` objects
+    :param kwargs: extra ffmpeg options given without the leading dash,
+        e.g. ``s="192*168"`` or ``vsync="vfr"``
+    :return: the decoded frames, ``[]`` when ffmpeg produced no image
+    :raises NotImplementedError: if ``image_format`` is not supported
+    """
+    # every PNG file starts with this fixed 8-byte signature; the full
+    # signature is far less likely to show up inside compressed image
+    # data than its first four bytes
+    png_signature = b'\x89PNG\r\n\x1a\n'
+
+    # (-i, -) input from stdin pipeline
+    # (-f, image2pipe) output format is image pipeline
+    ffmpeg_cmd = ['ffmpeg', '-i', '-', '-f', 'image2pipe']
+
+    # example k,v pair:
+    #    (-s, 420*360)
+    #    (-vsync, vfr)
+    #    (-vf, select=eq(pict_type\,I))
+    for k, v in kwargs.items():
+        ffmpeg_cmd.append('-' + k)
+        # argv entries must be str, options may be given as numbers
+        ffmpeg_cmd.append(str(v))
+
+    # (-c:v, png) output bytes in png format
+    # (-an, -sn) disable audio and subtitle processing
+    # (-) output to stdout pipeline
+    ffmpeg_cmd += ['-c:v', 'png', '-an', '-sn', '-']
+
+    with sp.Popen(
+            ffmpeg_cmd, stdin=sp.PIPE, stdout=sp.PIPE, bufsize=-1,
+            shell=False) as pipe:
+        stream, _ = pipe.communicate(buffer_data)
+
+    # raw bytes of multiple PNGs, split on the PNG signature (file
+    # header); the piece before the first signature is not an image
+    pngs = stream.split(png_signature)[1:]
+
+    if not pngs:
+        return []
+
+    # put the signature back to get complete PNGs for decoding
+    if image_format == 'pil':
+        frames = [Image.open(io.BytesIO(png_signature + _)) for _ in pngs]
+    elif image_format == 'cv2':
+        frames = [
+            cv2.imdecode(np.frombuffer(png_signature + _, np.uint8), 1)
+            for _ in pngs
+        ]
+    else:
+        raise NotImplementedError(
+            'unsupported image_format: %r' % image_format)
+
+    return frames
+
+
+def block_descriptor(image: "np.ndarray",
+                     descriptor_fn: Callable,
+                     num_blocks: int = 3) -> "np.ndarray":
+    """Concatenate the descriptors of the tiles of an image grid.
+
+    The image is cut into tiles of ``ceil(h / num_blocks)`` rows by
+    ``ceil(w / num_blocks)`` columns (border tiles may be smaller) and
+    ``descriptor_fn`` is applied to each tile in row-major order.
+
+    :param image: array of shape (height, width, channel)
+    :param descriptor_fn: maps an image tile to a flat descriptor
+    :param num_blocks: number of tiles requested along each side
+    :return: the tile descriptors concatenated into one array
+    """
+    h, w, _ = image.shape
+    block_h = int(np.ceil(h / num_blocks))
+    block_w = int(np.ceil(w / num_blocks))
+
+    descriptors = []
+    for i in range(0, h, block_h):
+        for j in range(0, w, block_w):
+            block = image[i:i + block_h, j:j + block_w]
+            descriptors.extend(descriptor_fn(block))
+
+    return np.array(descriptors)
+
+
+def pyramid_descriptor(image: "np.ndarray",
+                       descriptor_fn: Callable,
+                       max_level: int = 2) -> "np.ndarray":
+    """Concatenate block descriptors over a pyramid of grid sizes.
+
+    Level ``l`` uses a grid of ``2**l`` tiles per side, for every ``l``
+    from 0 to ``max_level`` inclusive.
+
+    :param image: array of shape (height, width, channel)
+    :param descriptor_fn: maps an image tile to a flat descriptor
+    :param max_level: last (finest) pyramid level
+    :return: the descriptors of all levels concatenated, coarsest first
+    """
+    descriptors = []
+    for level in range(max_level + 1):
+        num_blocks = 2**level
+        descriptors.extend(block_descriptor(image, descriptor_fn, num_blocks))
+
+    return np.array(descriptors)
+
+
+def rgb_histogram(image: "np.ndarray") -> "np.ndarray":
+    """Return the flattened per-channel 256-bin histograms of an image.
+
+    Each channel histogram is normalized to sum to 1.
+
+    :param image: array of shape (height, width, channel)
+    """
+    _, _, c = image.shape
+    hist = [
+        cv2.calcHist([image], [i], None, [256], [0, 256]) for i in range(c)
+    ]
+    # normalize hist
+    hist = np.array([h / np.sum(h) for h in hist]).flatten()
+    return hist
+
+
+def hsv_histogram(image: "np.ndarray") -> "np.ndarray":
+    """Return the flattened per-channel 256-bin HSV histograms of an image.
+
+    Each channel histogram is normalized to sum to 1.
+
+    :param image: BGR array of shape (height, width, 3), as returned by
+        ``get_video_frames`` with ``image_format="cv2"``
+    """
+    _, _, c = image.shape
+    # cv2.imdecode yields BGR frames (phash_descriptor relies on that as
+    # well), so convert from BGR; RGB2HSV would swap red and blue
+    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+
+    # NOTE: hue of 8-bit images only spans [0, 180); 256 bins are kept for
+    # all channels so that the descriptor length does not change
+    hist = [cv2.calcHist([hsv], [i], None, [256], [0, 256]) for i in range(c)]
+    # normalize hist
+    hist = np.array([h / np.sum(h) for h in hist]).flatten()
+    return hist
+
+
+def phash_descriptor(image: "np.ndarray") -> "imagehash.ImageHash":
+    """Return the perceptual hash of a BGR image array."""
+    image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+    return imagehash.phash(image)
+
+
+def compute_descriptor(image: "np.ndarray",
+                       method: str = "rgb_histogram",
+                       **kwargs) -> "np.ndarray":
+    """Compute the descriptor of an image with the named method.
+
+    :param image: array of shape (height, width, channel)
+    :param method: ``rgb_histogram``, ``hsv_histogram`` or one of their
+        ``block_`` / ``pyramid_`` prefixed variants
+    :param kwargs: ``num_blocks`` (default 3) for the ``block_`` variants,
+        ``max_level`` (default 2) for the ``pyramid_`` variants
+    :return: the flat descriptor
+    :raises KeyError: if ``method`` is unknown
+    """
+    funcs = {
+        'rgb_histogram': rgb_histogram,
+        'hsv_histogram': hsv_histogram,
+        'block_rgb_histogram': lambda image: block_descriptor(image, rgb_histogram, kwargs.get("num_blocks", 3)),
+        'block_hsv_histogram': lambda image: block_descriptor(image, hsv_histogram, kwargs.get("num_blocks", 3)),
+        'pyramid_rgb_histogram': lambda image: pyramid_descriptor(image, rgb_histogram, kwargs.get("max_level", 2)),
+        'pyramid_hsv_histogram': lambda image: pyramid_descriptor(image, hsv_histogram, kwargs.get("max_level", 2)),
+    }
+    return funcs[method](image)
+
+
+def compare_descriptor(descriptor1: "np.ndarray",
+                       descriptor2: "np.ndarray",
+                       metric: str = "chisqr") -> float:
+    """Compare two histogram descriptors with ``cv2.compareHist``.
+
+    :param descriptor1: first descriptor
+    :param descriptor2: second descriptor
+    :param metric: name of the comparison method, see ``dist_metric``
+    :return: the comparison score; for ``correlation`` and
+        ``intersection`` a higher score means more similar, for the other
+        metrics a lower score means more similar
+    :raises KeyError: if ``metric`` is unknown
+    """
+    dist_metric = {
+        "correlation": cv2.HISTCMP_CORREL,
+        "chisqr": cv2.HISTCMP_CHISQR,
+        "chisqr_alt": cv2.HISTCMP_CHISQR_ALT,
+        "intersection": cv2.HISTCMP_INTERSECT,
+        "bhattacharya": cv2.HISTCMP_BHATTACHARYYA,
+        "hellinger": cv2.HISTCMP_HELLINGER,
+        # misspelled key kept for backward compatibility
+        "hellinguer": cv2.HISTCMP_HELLINGER,
+        "kl_div": cv2.HISTCMP_KL_DIV
+    }
+
+    return cv2.compareHist(descriptor1, descriptor2, dist_metric[metric])
diff --git a/gnes/preprocessor/video/ffmpeg.py b/gnes/preprocessor/video/ffmpeg.py
index 8acb4f9f..33d6c488 100644
--- a/gnes/preprocessor/video/ffmpeg.py
+++ b/gnes/preprocessor/video/ffmpeg.py
@@ -12,45 +12,31 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
-import io
-import subprocess as sp
-from typing import List
 
+from typing import List
 import numpy as np
-from PIL import Image
 
 from .base import BaseVideoPreprocessor
 from ...proto import gnes_pb2, array2blob
+from ..helper import get_video_frames, phash_descriptor
 
 
 class FFmpegPreprocessor(BaseVideoPreprocessor):
 
     def __init__(self,
+                 frame_size: str = "192*168",
                  duplicate_rm: bool = True,
                  use_phash_weight: bool = False,
                  phash_thresh: int = 5,
-                 *args, **kwargs):
+                 *args,
+                 **kwargs):
         super().__init__(*args, **kwargs)
+        self.frame_size = frame_size
         self.phash_thresh = phash_thresh
         self.duplicate_rm = duplicate_rm
         self.use_phash_weight = use_phash_weight
-        # (-i, -) input from stdin pipeline
-        # (-f, image2pipe) output format is image pipeline
-        self.cmd = ['ffmpeg',
-                    '-i', '-',
-                    '-f', 'image2pipe']
-
-        # example k,v pair:
-        #    (-s, 420*360)
-        #    (-vsync, vfr)
-        #    (-vf, select=eq(pict_type\,I))
-        for k, v in kwargs.items():
-            self.cmd.append('-' + k)
-            self.cmd.append(v)
-
-        # (-c:v, png) output bytes in png format
-        # (-) output to stdout pipeline
-        self.cmd += ['-c:v', 'png', '-']
+
+        self._ffmpeg_kwargs = kwargs
 
     def apply(self, doc: 'gnes_pb2.Document') -> None:
         super().apply(doc)
@@ -58,63 +44,72 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
         # video could't be processed from ndarray!
         # only bytes can be passed into ffmpeg pipeline
         if doc.raw_bytes:
-            pipe = sp.Popen(self.cmd, stdin=sp.PIPE, stdout=sp.PIPE, bufsize=-1)
-            stream, _ = pipe.communicate(doc.raw_bytes)
-
-            # raw bytes for multiple PNGs.
-            # split by PNG EOF b'\x89PNG'
-            stream = stream.split(b'\x89PNG')
-            if len(stream) <= 1:
-                self.logger.info('no image extracted from video!')
+            frames = get_video_frames(
+                doc.raw_bytes,
+                s=self.frame_size,
+                vsync=self._ffmpeg_kwargs.get("vsync", "vfr"),
+                vf=self._ffmpeg_kwargs.get("vf", "select=eq(pict_type\\,I)"))
+
+            # ffmpeg may extract nothing (e.g. corrupted or unsupported
+            # video); stop here, otherwise the uniform weight below would
+            # divide by zero and pic_weight would index an empty list
+            if not frames:
+                self.logger.info('no image extracted from video!')
+                return
+
+            # remove duplicated key frames by phash value
+            if self.duplicate_rm:
+                frames = self.duplicate_rm_hash(frames)
+
+            if self.use_phash_weight:
+                weight = FFmpegPreprocessor.pic_weight(frames)
             else:
-                # reformulate the full pngs for feature processings.
-                stream = [b'\x89PNG' + _ for _ in stream[1:]]
-
-                # remove dupliated key frames by phash value
-                if self.duplicate_rm:
-                    stream = self.duplicate_rm_hash(stream)
-
-                stream = [np.array(Image.open(io.BytesIO(chunk)), dtype=np.uint8)
-                          for chunk in stream]
-
-                if self.use_phash_weight:
-                    weight = FFmpegPreprocessor.pic_weight(stream)
-                else:
-                    weight = [1 / len(stream)] * len(stream)
-
-                for ci, chunk in enumerate(stream):
-                    c = doc.chunks.add()
-                    c.doc_id = doc.doc_id
-                    c.blob.CopyFrom(array2blob(chunk))
-                    c.offset_1d = ci
-                    c.weight = weight[ci]
-            # close the stdout stream
-            pipe.stdout.close()
+                weight = [1 / len(frames)] * len(frames)
+
+            for ci, chunk in enumerate(frames):
+                c = doc.chunks.add()
+                c.doc_id = doc.doc_id
+                c.blob.CopyFrom(array2blob(chunk))
+                c.offset_1d = ci
+                c.weight = weight[ci]
+
         else:
             self.logger.error('bad document: "raw_bytes" is empty!')
 
     @staticmethod
-    def phash(image_bytes: bytes):
-        import imagehash
-        return imagehash.phash(Image.open(io.BytesIO(image_bytes)))
-
-    @staticmethod
-    def pic_weight(image_array: List[np.ndarray]) -> List[float]:
-        weight = np.zeros([len(image_array)])
+    def pic_weight(images: List['np.ndarray']) -> List[float]:
+        """Weight every image by the variance of its color histograms.
+
+        The summed per-channel histogram variances are normalized over all
+        images, mapped through ``exp(-10 * x)`` and normalized again, so
+        images with a lower histogram variance get a higher weight.
+
+        :param images: non-empty list of (height, width, channel) arrays
+        :return: one weight per image, summing to 1
+        """
+        import cv2
+        weight = np.zeros([len(images)])
         # n_channel is usually 3 for RGB images
-        n_channel = image_array[0].shape[-1]
-        for i in range(len(image_array)):
-            # calcualte the variance of histgram of pixels
-            weight[i] = sum([np.histogram(image_array[i][:, :, _])[0].var()
-                             for _ in range(n_channel)])
+        n_channel = images[0].shape[-1]
+        for i, image in enumerate(images):
+            weight[i] = sum([
+                cv2.calcHist([image], [_], None, [256], [0, 256]).var()
+                for _ in range(n_channel)
+            ])
         weight = weight / weight.sum()
 
         # normalized result
-        weight = np.exp(- weight * 10)
+        weight = np.exp(-weight * 10)
         return weight / weight.sum()
 
-    def duplicate_rm_hash(self, image_list: List[bytes]) -> List[bytes]:
-        hash_list = [FFmpegPreprocessor.phash(_) for _ in image_list]
+    def duplicate_rm_hash(self,
+                          images: List['np.ndarray']) -> List['np.ndarray']:
+        """Remove duplicated images, compared by perceptual hash.
+
+        :param images: BGR image arrays
+        :return: the kept images, in their original order
+        """
+        hash_list = [phash_descriptor(_) for _ in images]
         ret = []
         for i, h in enumerate(hash_list):
             flag = 1
@@ -129,4 +124,4 @@ def duplicate_rm_hash(self, image_list: List[bytes]) -> List[bytes]:
             if flag:
                 ret.append((i, h))
 
-        return [image_list[_[0]] for _ in ret]
+        return [images[_[0]] for _ in ret]
diff --git a/gnes/preprocessor/video/shotdetect.py b/gnes/preprocessor/video/shotdetect.py
new file mode 100644
index 00000000..29a86f58
--- /dev/null
+++ b/gnes/preprocessor/video/shotdetect.py
@@ -0,0 +1,130 @@
+#  Tencent is pleased to support the open source community by making GNES available.
+#
+#  Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+# pylint: disable=low-comment-ratio
+
+import numpy as np
+from .base import BaseVideoPreprocessor
+from ...proto import gnes_pb2, array2blob
+from ..helper import get_video_frames, compute_descriptor, compare_descriptor
+
+
+class ShotDetectPreprocessor(BaseVideoPreprocessor):
+    """Split a video into shots and emit one key frame per shot.
+
+    Consecutive key frames are compared through their descriptors; the
+    frame-to-frame scores are clustered in two groups with KMeans and the
+    group of dissimilar pairs marks the shot boundaries.
+    """
+    store_args_kwargs = True
+
+    def __init__(self,
+                 frame_size: str = "192*168",
+                 descriptor: str = "block_hsv_histogram",
+                 distance_metric: str = "bhattacharya",
+                 *args,
+                 **kwargs):
+        """
+        :param frame_size: size of the extracted frames, passed to ffmpeg
+            as its ``-s`` option
+        :param descriptor: ``method`` given to ``compute_descriptor``
+        :param distance_metric: ``metric`` given to ``compare_descriptor``
+        :param kwargs: also forwarded to ``compute_descriptor``
+            (e.g. ``num_blocks``)
+        """
+        super().__init__(*args, **kwargs)
+        self.frame_size = frame_size
+        self.descriptor = descriptor
+        self.distance_metric = distance_metric
+        self._detector_kwargs = kwargs
+
+    def apply(self, doc: 'gnes_pb2.Document') -> None:
+        """Add one chunk per detected shot to ``doc``.
+
+        Each chunk holds the middle frame of its shot and is weighted by
+        the share of frames that belong to the shot.
+        """
+        super().apply(doc)
+
+        if not doc.raw_bytes:
+            self.logger.error('bad document: "raw_bytes" is empty!')
+            return
+
+        frames = get_video_frames(
+            doc.raw_bytes,
+            s=self.frame_size,
+            vsync="vfr",
+            vf='select=eq(pict_type\\,I)')
+        if not frames:
+            self.logger.info('no image extracted from video!')
+            return
+
+        for ci, (start, end) in enumerate(self._detect_shots(frames)):
+            c = doc.chunks.add()
+            c.doc_id = doc.doc_id
+            # the middle frame represents the shot
+            chunk = frames[start + (end - start) // 2]
+            c.blob.CopyFrom(array2blob(chunk))
+            c.offset_1d = ci
+            c.weight = (end - start) / len(frames)
+
+    def _detect_shots(self, frames):
+        """Return the shots of ``frames`` as (start, end) index pairs.
+
+        ``end`` is exclusive and the shots cover all frames, so a video
+        without any detected boundary is a single shot.
+        """
+        descriptors = [
+            compute_descriptor(
+                frame, method=self.descriptor, **self._detector_kwargs)
+            for frame in frames
+        ]
+
+        # dists[i] compares frame i with frame i + 1
+        dists = np.array([
+            compare_descriptor(pair[0], pair[1], self.distance_metric)
+            for pair in zip(descriptors[:-1], descriptors[1:])
+        ])
+
+        boundaries = []
+        # two clusters need at least two distinct scores; with less than
+        # that there is nothing to tell a cut from a normal transition
+        if np.unique(dists).size >= 2:
+            from sklearn.cluster import KMeans
+            clt = KMeans(n_clusters=2)
+            clt.fit(dists.reshape([-1, 1]))
+
+            # select which cluster holds the shot boundaries, i.e. the
+            # least similar frame pairs; for 'correlation' and
+            # 'intersection' a HIGHER score means more similar
+            if self.distance_metric in ('correlation', 'intersection'):
+                cut_cluster = np.argmin(clt.cluster_centers_)
+            else:
+                cut_cluster = np.argmax(clt.cluster_centers_)
+
+            # a cut between frame i and i + 1 makes i + 1 the first frame
+            # of the next shot
+            boundaries = [
+                i + 1 for i, label in enumerate(clt.labels_)
+                if label == cut_cluster
+            ]
+
+        shots = []
+        prev_shot = 0
+        # closing with len(frames) keeps the frames after the last cut
+        for boundary in boundaries + [len(frames)]:
+            shots.append((prev_shot, boundary))
+            prev_shot = boundary
+        return shots
diff --git a/setup.py b/setup.py
index f0e6f147..6df1718a 100644
--- a/setup.py
+++ b/setup.py
@@ -49,7 +49,7 @@
 annoy_dep = ['annoy==1.15.2']
 chinese_dep = ['jieba']
 cn_nlp_dep = list(set(chinese_dep + nlp_dep))
-vision_dep = ['torchvision==0.3.0', 'imagehash>=4.0']
+vision_dep = ['opencv-python>=4.0.0', 'torchvision==0.3.0', 'imagehash>=4.0']
 leveldb_dep = ['plyvel>=1.0.5']
 test_dep = ['pylint', 'memory_profiler>=0.55.0', 'psutil>=5.6.1', 'gputil>=1.4.0']
 all_dep = list(set(base_dep + cn_nlp_dep + vision_dep + leveldb_dep + test_dep + annoy_dep))
diff --git a/tests/test_video_shotdetect_preprocessor.py b/tests/test_video_shotdetect_preprocessor.py
new file mode 100644
index 00000000..04fc7476
--- /dev/null
+++ b/tests/test_video_shotdetect_preprocessor.py
@@ -0,0 +1,48 @@
+import os
+import unittest
+
+from gnes.cli.parser import set_preprocessor_service_parser, _set_client_parser
+from gnes.proto import gnes_pb2, RequestGenerator, blob2array
+from gnes.service.grpc import ZmqClient
+from gnes.service.preprocessor import PreprocessorService
+
+
+class TestShotDetector(unittest.TestCase):
+
+    def setUp(self):
+        self.dirname = os.path.dirname(__file__)
+        self.yml_path = os.path.join(self.dirname, 'yaml', 'preprocessor-shotdetect.yml')
+        self.video_path = os.path.join(self.dirname, 'videos')
+
+    def test_video_preprocessor_service_empty(self):
+        args = set_preprocessor_service_parser().parse_args([
+            '--yaml_path', self.yml_path
+        ])
+        with PreprocessorService(args):
+            pass
+
+    def test_video_preprocessor_service_realdata(self):
+        args = set_preprocessor_service_parser().parse_args([
+            '--yaml_path', self.yml_path
+        ])
+        c_args = _set_client_parser().parse_args([
+            '--port_in', str(args.port_out),
+            '--port_out', str(args.port_in)
+        ])
+        video_bytes = []
+        for name in os.listdir(self.video_path):
+            # close every video file right after reading it
+            with open(os.path.join(self.video_path, name), 'rb') as fp:
+                video_bytes.append(fp.read())
+
+        with PreprocessorService(args), ZmqClient(c_args) as client:
+            for req in RequestGenerator.index(video_bytes):
+                msg = gnes_pb2.Message()
+                msg.request.index.CopyFrom(req.index)
+                client.send_message(msg)
+                r = client.recv_message()
+                for d in r.request.index.docs:
+                    self.assertGreater(len(d.chunks), 0)
+                    for chunk in d.chunks:
+                        shape = blob2array(chunk.blob).shape
+                        self.assertEqual(shape, (168, 192, 3))
diff --git a/tests/yaml/preprocessor-ffmpeg.yml b/tests/yaml/preprocessor-ffmpeg.yml
index 1ccce77b..610820ea 100644
--- a/tests/yaml/preprocessor-ffmpeg.yml
+++ b/tests/yaml/preprocessor-ffmpeg.yml
@@ -1,10 +1,12 @@
 !FFmpegPreprocessor
 parameter:
+  frame_size: "192*168"
   duplicate_rm: True
   use_phash_weight: False
   phash_thresh: 5
-  s: "192*168"
-  vsync: vfr
-  vf: select=eq(pict_type\,I)
+
+  kwargs:
+    vsync: vfr
+    vf: select=eq(pict_type\,I)
 property:
   is_trained: true
\ No newline at end of file
diff --git a/tests/yaml/preprocessor-shotdetect.yml b/tests/yaml/preprocessor-shotdetect.yml
new file mode 100644
index 00000000..3079cfce
--- /dev/null
+++ b/tests/yaml/preprocessor-shotdetect.yml
@@ -0,0 +1,9 @@
+!ShotDetectPreprocessor
+parameter:
+  descriptor: "block_hsv_histogram"
+  distance_metric: "bhattacharya"
+  frame_size: "192*168"
+  kwargs:
+    num_blocks: 3
+property:
+  is_trained: true
\ No newline at end of file