Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
feat(video preprocessor): add edge detect for shotdetect
Browse files Browse the repository at this point in the history
  • Loading branch information
raccoonliukai committed Aug 21, 2019
1 parent 4448992 commit d420f34
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 49 deletions.
156 changes: 149 additions & 7 deletions gnes/preprocessor/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def get_video_frames(buffer_data: bytes, image_format: str = 'cv2',
# example k,v pair:
# (-s, 420*360)
# (-vsync, vfr)
# (-vf, select=eq(pict_type\,I))
# (-r, 10)
for k, v in kwargs.items():
if isinstance(v, (float, int)):
v = str(v)
Expand Down Expand Up @@ -240,7 +240,7 @@ def pyramid_descriptor(image: 'np.ndarray',
return np.array(descriptors)


def rgb_histogram(image: 'np.ndarray') -> 'np.ndarray':
def bgr_histogram(image: 'np.ndarray') -> 'np.ndarray':
_, _, c = image.shape
hist = [
cv2.calcHist([image], [i], None, [256], [0, 256]) for i in range(c)
Expand All @@ -252,7 +252,7 @@ def rgb_histogram(image: 'np.ndarray') -> 'np.ndarray':

def hsv_histogram(image: 'np.ndarray') -> 'np.ndarray':
_, _, c = image.shape
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

# sizes = [180, 256, 256]
# ranges = [(0, 180), (0, 256), (0, 256)]
Expand All @@ -267,26 +267,85 @@ def hsv_histogram(image: 'np.ndarray') -> 'np.ndarray':
return hist


def edge_detect(image: 'np.ndarray', **kwargs) -> 'np.ndarray':
    """Return the Canny edge map of a BGR image.

    The image is converted to grayscale, smoothed with a Gaussian blur
    (sigma fixed at 1.2) and then passed to ``cv2.Canny``.

    Recognised keyword options (any other keyword is ignored):

    * ``compute_threshold`` -- derive both Canny thresholds from the median
      of the grayscale image instead of using the fixed ones (default True).
    * ``low_threshold`` / ``high_threshold`` -- fixed thresholds used when
      ``compute_threshold`` is falsy (defaults 0 and 200).
    * ``sigma`` -- relative spread around the median for the computed
      thresholds (default 0.5).
    * ``gauss_kernel`` -- kernel size handed to ``cv2.GaussianBlur``
      (default ``(9, 9)``).
    * ``L2`` -- forwarded to ``cv2.Canny`` as ``L2gradient`` (default True).

    :param image: input image, converted with ``cv2.COLOR_BGR2GRAY``.
    :return: the edge image produced by ``cv2.Canny``.
    """
    options = {
        'compute_threshold': True,
        'low_threshold': 0,
        'high_threshold': 200,
        'sigma': 0.5,
        'gauss_kernel': (9, 9),
        'L2': True
    }
    # Only known option names may override the defaults.
    options.update({name: value for name, value in kwargs.items() if name in options})

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    if options['compute_threshold']:
        # Automatic Canny thresholds centred on the median intensity.
        median = np.median(gray)
        spread = options['sigma']
        lower = ((1.0 - spread) * median).astype("float32")
        upper = ((1.0 + spread) * median).astype("float32")
    else:
        lower = options['low_threshold']
        upper = options['high_threshold']

    blurred = cv2.GaussianBlur(gray, options['gauss_kernel'], 1.2)
    return cv2.Canny(blurred, lower, upper, L2gradient=options['L2'])


def phash_descriptor(image: 'np.ndarray'):
    """Return the perceptual hash of ``image`` as computed by ``imagehash.phash``.

    :param image: array that is wrapped with ``Image.fromarray`` before hashing.
    :return: whatever ``imagehash.phash`` returns for the wrapped image.
    """
    pil_image = Image.fromarray(image)
    # Imported lazily so the dependency is only needed when this descriptor is used.
    import imagehash
    return imagehash.phash(pil_image)


def compute_descriptor(image: 'np.ndarray',
method: str = 'rgb_histogram',
method: str = 'bgr_histogram',
**kwargs) -> 'np.array':
funcs = {
'rgb_histogram': rgb_histogram,
'bgr_histogram': bgr_histogram,
'hsv_histogram': hsv_histogram,
'block_rgb_histogram': lambda image: block_descriptor(image, rgb_histogram, kwargs.get('num_blocks', 3)),
'edge_detect': lambda image: edge_detect(image, **kwargs),
'block_bgr_histogram': lambda image: block_descriptor(image, bgr_histogram, kwargs.get('num_blocks', 3)),
'block_hsv_histogram': lambda image: block_descriptor(image, hsv_histogram, kwargs.get('num_blocks', 3)),
'pyramid_rgb_histogram': lambda image: pyramid_descriptor(image, rgb_histogram, kwargs.get('max_level', 2)),
'pyramid_bgr_histogram': lambda image: pyramid_descriptor(image, bgr_histogram, kwargs.get('max_level', 2)),
'pyramid_hsv_histogram': lambda image: pyramid_descriptor(image, hsv_histogram, kwargs.get('max_level', 2)),
}
return funcs[method](image)


def compare_ecr(descriptors: List['np.ndarray'], dilate_rate: int = 5, neigh_avg: int = 2) -> List[float]:
    """Score consecutive edge maps with the Edge Change Ratio algorithm.

    :param descriptors: edge images, one per frame.
    :param dilate_rate: side length of the square kernel used to dilate each
        edge image before it is inverted.
    :param neigh_avg: number of smoothing passes; each pass replaces every
        interior score by the maximum of itself and its two neighbours and
        drops the two end scores, so the result shrinks by two per pass.
    :return: the smoothed list of scores.
    """

    def _safe_div(numerator, denominator):
        # A frame without any edge pixels contributes a ratio of 0.
        return 0 if denominator == 0 else numerator / denominator

    def _change_ratio(earlier, later):
        # Larger of the two directional ratios between frames `earlier` and `later`.
        forward = _safe_div(float(np.sum(descriptors[earlier] & inverted[later])),
                            float(edge_sums[earlier]))
        backward = _safe_div(float(np.sum(descriptors[later] & inverted[earlier])),
                             float(edge_sums[later]))
        return max(forward, backward)

    kernel = np.ones((dilate_rate, dilate_rate))
    edge_sums = [np.sum(edges) for edges in descriptors]
    inverted = [255 - cv2.dilate(edges, kernel) for edges in descriptors]

    scores = []
    for idx in range(1, len(descriptors)):
        score = _change_ratio(idx - 1, idx)
        if idx > 10:
            # Additionally weight by the change against the frame ten steps back.
            score *= (1 + _change_ratio(idx - 10, idx))
        scores.append(score)

    for _ in range(neigh_avg):
        # Sliding maximum over windows of three neighbouring scores.
        scores = [max(left, mid, right)
                  for left, mid, right in zip(scores, scores[1:], scores[2:])]

    return scores


def compare_descriptor(descriptor1: 'np.ndarray',
descriptor2: 'np.ndarray',
metric: str = 'chisqr') -> float:
Expand All @@ -303,6 +362,89 @@ def compare_descriptor(descriptor1: 'np.ndarray',
return cv2.compareHist(descriptor1, descriptor2, dist_metric[metric])


def kmeans_algo(distances: List[float], **kwargs) -> List[int]:
    """Locate shot boundaries by clustering frame distances into two groups.

    The distances are clustered with 2-means; the cluster whose centre is the
    largest is taken to hold the shot-change distances. Every distance ``i``
    in that cluster yields the boundary ``i + 2``.

    :param distances: one distance per pair of consecutive frames. Either a
        flat sequence or an ``(n, 1)`` array is accepted.
    :param kwargs: accepted for signature compatibility with the other
        detection algorithms; not used here.
    :return: boundaries starting with ``0`` and ending with the number of
        frames (``len(distances) + 1``).
    """
    num_frames = len(distances) + 1
    if len(distances) < 2:
        # Two clusters cannot be formed from fewer than two samples:
        # report the whole input as a single shot.
        return [0, num_frames]

    # Imported lazily so scikit-learn is only required when this method is used.
    from sklearn.cluster import KMeans

    # KMeans.fit expects a 2-D (n_samples, n_features) array; a flat list of
    # floats would be rejected, so reshape into a single-feature column.
    samples = np.asarray(distances, dtype=float).reshape(-1, 1)
    clt = KMeans(n_clusters=2)
    clt.fit(samples)

    # select which cluster includes shot frames
    big_center = np.argmax(clt.cluster_centers_)

    shots = [0]
    shots.extend(i + 2 for i, label in enumerate(clt.labels_) if label == big_center)

    # The largest possible boundary is (len(distances) - 1) + 2 == num_frames,
    # so only a missing final boundary has to be added.
    if shots[-1] < num_frames:
        shots.append(num_frames)
    else:
        shots[-1] = num_frames
    return shots


def check_motion(prev_dists: List[float], cur_dist: float, motion_threshold: float = 0.75):
    """Decide whether a peak is due to motion rather than a shot change.

    A preceding distance counts as a "close peak" when it exceeds
    ``cur_dist * motion_threshold``. The peak is attributed to motion when at
    least half of the preceding distances are close peaks.

    :param prev_dists: distances observed in the frames before the peak.
    :param cur_dist: distance value at the peak itself.
    :param motion_threshold: fraction of ``cur_dist`` a preceding distance
        must exceed to count as a close peak.
    :return: True if the peak is attributed to motion, False otherwise. An
        empty ``prev_dists`` gives False: with nothing observed there is no
        evidence of motion (previously ``0 >= 0`` made this case True).
    """
    window_size = len(prev_dists)
    if window_size == 0:
        return False

    limit = cur_dist * motion_threshold
    close_peaks = sum(1 for dist in prev_dists if dist > limit)
    return close_peaks >= window_size / 2


def thre_algo(distances: List[float], **kwargs) -> List[int]:
    """Threshold-based shot detection, delegated to ``motion_algo``.

    :param distances: one distance per pair of consecutive frames.
    :param kwargs: options forwarded to ``motion_algo``; any ``motion_step``
        supplied by the caller is overridden with ``0``.
    :return: the boundaries returned by ``motion_algo``.
    """
    # The threshold algorithm does not support motion for now, so force the step to 0.
    options = dict(kwargs, motion_step=0)
    return motion_algo(distances, **options)


def motion_algo(distances: List[float], **kwargs) -> List[int]:
    """Return shot boundaries taken from the peaks of the distance series.

    Peaks are located with ``peakutils.indexes``. The first peak is always
    kept. Every later peak is kept unless motion filtering is enabled and
    ``check_motion`` attributes it to motion, judged on the ``motion_step``
    distances before the peak. A kept peak ``p`` yields the boundary ``p + 2``.

    Recognised keyword options (any other keyword is ignored):

    * ``threshold`` -- passed to ``peakutils.indexes`` as ``thres`` (default 0.6).
    * ``min_dist`` -- passed to ``peakutils.indexes`` as ``min_dist``; also the
      minimum tail length for which a closing boundary is appended (default 10).
    * ``motion_step`` -- number of preceding distances inspected by
      ``check_motion``; ``0`` disables motion filtering (default 15).

    :param distances: one distance per pair of consecutive frames.
    :return: boundaries starting with ``0``; when no peak is found the result
        is ``[0, len(distances) + 1]``.
    """
    arg_dict = {
        'threshold': 0.6,
        'min_dist': 10,
        'motion_step': 15
    }
    arg_dict.update({k: v for k, v in kwargs.items() if k in arg_dict})

    num_frames = len(distances) + 1
    if len(distances) == 0:
        # Nothing to search for peaks in: the input is one single shot.
        return [0, num_frames]

    # Imported lazily so peakutils is only required when this method is used.
    import peakutils

    peaks = peakutils.indexes(np.array(distances).astype('float32').ravel(),
                              thres=arg_dict['threshold'],
                              min_dist=arg_dict['min_dist'])

    motion_step = arg_dict['motion_step']
    shots = [0]
    for rank, peak in enumerate(peaks):
        peak = int(peak)
        if rank > 0 and motion_step:
            # We check that the peak is not due to a motion in the image.
            # Clamp the window start at 0: a negative start would wrap
            # around to the end of the series instead of the frames
            # preceding the peak.
            window = distances[max(0, peak - motion_step):peak]
            if check_motion(window, distances[peak]):
                continue
        shots.append(peak + 2)

    # Close the final shot. When no peak was found, always close it so that at
    # least one complete shot is returned.
    if len(shots) == 1 or shots[-1] < num_frames - arg_dict['min_dist']:
        shots.append(num_frames)
    elif shots[-1] > num_frames:
        # Defensive clamp; peak + 2 should never exceed the frame count.
        shots[-1] = num_frames
    # NOTE(review): when the last boundary lies within min_dist of the end,
    # no closing boundary is added and the trailing frames belong to no shot.
    # Kept as in the original; confirm this is intended.
    return shots


def detect_video_shot(distances: List[float],
                      method: str = 'kmeans',
                      **kwargs) -> List[int]:
    """Dispatch shot detection to the algorithm named by ``method``.

    :param distances: one distance per pair of consecutive frames.
    :param method: one of ``'kmeans'``, ``'threshold'`` or ``'motion'``.
    :param kwargs: forwarded unchanged to the selected algorithm.
    :return: the boundaries returned by the selected algorithm. For an
        unknown ``method`` an error is logged and ``None`` is returned.
    """
    algorithms = {
        'kmeans': kmeans_algo,
        'threshold': thre_algo,
        'motion': motion_algo
    }

    try:
        algorithm = algorithms[method]
    except KeyError:
        logger.error("detect video shot by [%s] not implemented! Please use threshold, kmeans or motion!" % method)
        return None
    return algorithm(distances, **kwargs)


def torch_transform(img):
import torchvision.transforms as transforms
return transforms.Compose([transforms.ToTensor(),
Expand Down
72 changes: 36 additions & 36 deletions gnes/preprocessor/video/shotdetect.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@


import numpy as np
from typing import List

from ..base import BaseVideoPreprocessor
from ..helper import get_video_frames, compute_descriptor, compare_descriptor
from ..helper import get_video_frames, compute_descriptor, compare_descriptor, detect_video_shot, compare_ecr
from ...proto import gnes_pb2, array2blob


Expand All @@ -28,62 +29,61 @@ def __init__(self,
frame_size: str = '192*168',
descriptor: str = 'block_hsv_histogram',
distance_metric: str = 'bhattacharya',
detect_method: str = 'threshold',
frame_rate: str = '10',
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.frame_size = frame_size
self.descriptor = descriptor
self.distance_metric = distance_metric
self.detect_method = detect_method
self.frame_rate = frame_rate
self._detector_kwargs = kwargs

def apply(self, doc: 'gnes_pb2.Document') -> None:
super().apply(doc)
from sklearn.cluster import KMeans

if doc.raw_bytes:
# stream_data = io.BytesIO(doc.raw_bytes)
# vidcap = cv2.VideoCapture(stream_data)
frames = get_video_frames(
doc.raw_bytes,
s=self.frame_size,
vsync='vfr',
vf='select=eq(pict_type\\,I)')
def detect_from_bytes(self, raw_bytes: bytes) -> (List[List['np.ndarray']], int):
frames = get_video_frames(
raw_bytes,
s=self.frame_size,
vsync='vfr',
r=self.frame_rate)

descriptors = []
shots = []
for frame in frames:
descriptor = compute_descriptor(
frame, method=self.descriptor, **self._detector_kwargs)
descriptors.append(descriptor)
descriptors = []
for frame in frames:
descriptor = compute_descriptor(
frame, method=self.descriptor, **self._detector_kwargs)
descriptors.append(descriptor)

# compute distances between frames
# compute distances between frames
if self.distance_metric == 'edgechangeration':
dists = compare_ecr(descriptors)
else:
dists = [
compare_descriptor(pair[0], pair[1], self.distance_metric)
for pair in zip(descriptors[:-1], descriptors[1:])
]

dists = np.array(dists).reshape([-1, 1])
clt = KMeans(n_clusters=2)
clt.fit(dists)
shots = detect_video_shot(dists, self.detect_method)

# select which cluster includes shot frames
big_center = np.argmax(clt.cluster_centers_)
shot_frames = []
for ci in range(0, len(shots) - 1):
shot_frames.append(frames[shots[ci]:shots[ci+1]])

shots = []
prev_shot = 0
for i in range(0, len(clt.labels_)):
if big_center == clt.labels_[i]:
shots.append((prev_shot, i + 2))
prev_shot = i + 2
return shot_frames, len(frames)

for ci, (start, end) in enumerate(shots):
def apply(self, doc: 'gnes_pb2.Document') -> None:
super().apply(doc)
#from sklearn.cluster import KMeans

if doc.raw_bytes:
shot_frames, num_frames = self.detect_from_bytes(doc.raw_bytes)

for ci in range(0, len(shot_frames)):
c = doc.chunks.add()
c.doc_id = doc.doc_id
chunk_pos = start + (end - start) // 2
chunk = frames[chunk_pos]
chunk = np.array(shot_frames[ci]).astype('uint8')
c.blob.CopyFrom(array2blob(chunk))
c.offset_1d = ci
c.weight = (end - start) / len(frames)

c.weight = len(shot_frames[ci]) / num_frames
else:
self.logger.error('bad document: "raw_bytes" is empty!')
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
'flair': ['flair>=0.4.1'],
'annoy': ['annoy==1.15.2'],
'chinese': ['jieba'],
'vision': ['opencv-python>=4.0.0', 'imagehash>=4.0', 'image'],
'vision': ['opencv-python>=4.0.0', 'imagehash>=4.0', 'image', 'peakutils'],
'leveldb': ['plyvel>=1.0.5'],
'test': ['pylint', 'memory_profiler>=0.55.0', 'psutil>=5.6.1', 'gputil>=1.4.0'],
'transformers': ['pytorch-transformers'],
Expand Down
34 changes: 29 additions & 5 deletions tests/test_video_shotdetect_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@ class TestShotDetector(unittest.TestCase):

def setUp(self):
self.dirname = os.path.dirname(__file__)
self.yml_path = os.path.join(self.dirname, 'yaml', 'preprocessor-shotdetect.yml')
self.histogram_yml_path = os.path.join(self.dirname, 'yaml', 'preprocessor-shotdetect_histogram.yml')
self.edge_yml_path = os.path.join(self.dirname, 'yaml', 'preprocessor-shotdetect_edge.yml')
self.video_path = os.path.join(self.dirname, 'videos')

def test_video_preprocessor_service_empty(self):
args = set_preprocessor_parser().parse_args([
'--yaml_path', self.yml_path
'--yaml_path', self.histogram_yml_path
])
with PreprocessorService(args):
pass

def test_video_preprocessor_service_realdata(self):
def test_video_preprocessor_service_realdata_histogram(self):
args = set_preprocessor_parser().parse_args([
'--yaml_path', self.yml_path
'--yaml_path', self.histogram_yml_path
])
c_args = _set_client_parser().parse_args([
'--port_in', str(args.port_out),
Expand All @@ -42,4 +43,27 @@ def test_video_preprocessor_service_realdata(self):
self.assertGreater(len(d.chunks), 0)
for _ in range(len(d.chunks)):
shape = blob2array(d.chunks[_].blob).shape
self.assertEqual(shape, (168, 192, 3))
self.assertEqual(shape[1:], (168, 192, 3))

def test_video_preprocessor_service_realdata_edge(self):
args = set_preprocessor_parser().parse_args([
'--yaml_path', self.edge_yml_path
])
c_args = _set_client_parser().parse_args([
'--port_in', str(args.port_out),
'--port_out', str(args.port_in)
])
video_bytes = [open(os.path.join(self.video_path, _), 'rb').read()
for _ in os.listdir(self.video_path)]

with PreprocessorService(args), ZmqClient(c_args) as client:
for req in RequestGenerator.index(video_bytes):
msg = gnes_pb2.Message()
msg.request.index.CopyFrom(req.index)
client.send_message(msg)
r = client.recv_message()
for d in r.request.index.docs:
self.assertGreater(len(d.chunks), 0)
for _ in range(len(d.chunks)):
shape = blob2array(d.chunks[_].blob).shape
self.assertEqual(shape[1:], (168, 192, 3))
8 changes: 8 additions & 0 deletions tests/yaml/preprocessor-shotdetect_edge.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
!ShotDetectPreprocessor
parameters:
descriptor: "edge_detect"
distance_metric: "edgechangeration"
frame_size: "192*168"
frame_rate: "30"
gnes_config:
is_trained: true

0 comments on commit d420f34

Please sign in to comment.