diff --git a/gnes/preprocessor/helper.py b/gnes/preprocessor/helper.py index 2badba2f..0cec1116 100644 --- a/gnes/preprocessor/helper.py +++ b/gnes/preprocessor/helper.py @@ -215,21 +215,17 @@ def hsv_histogram(image: 'np.ndarray') -> 'np.ndarray': def canny_edge(image: 'np.ndarray', **kwargs) -> 'np.ndarray': import cv2 - arg_dict = { - 'sigma': 0.5, - 'gauss_kernel': (9, 9), - 'l2_gradient': True - } - arg_dict.update(kwargs) + sigma = kwargs.get('sigma', 0.5) + gauss_kernel = kwargs.get('gauss_kernel', (9, 9)) + l2_gradient = kwargs.get('l2_gradient', True) image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) # apply automatic Canny edge detection using the computed median v = np.median(image) - sigma = arg_dict['sigma'] low_threshold = ((1.0 - sigma) * v).astype("float32") high_threshold = ((1.0 + sigma) * v).astype("float32") - tmp_image = cv2.GaussianBlur(image, arg_dict['gauss_kernel'], 1.2) - edge_image = cv2.Canny(tmp_image, low_threshold, high_threshold, L2gradient=arg_dict['l2_gradient']) + tmp_image = cv2.GaussianBlur(image, gauss_kernel, 1.2) + edge_image = cv2.Canny(tmp_image, low_threshold, high_threshold, L2gradient=l2_gradient) return edge_image @@ -256,10 +252,12 @@ def compute_descriptor(image: 'np.ndarray', return funcs[method](image) -def compare_ecr(descriptors: List['np.ndarray'], dilate_rate: int = 5, neigh_avg: int = 2) -> List[float]: +def compare_ecr(descriptors: List['np.ndarray'], **kwargs) -> List[float]: import cv2 """ Apply the Edge Change Ratio Algorithm""" + dilate_rate = kwargs.get('dilate_rate', 5) + neigh_avg = kwargs.get('neigh_avg', 2) divd = lambda x, y: 0 if y == 0 else x / y dicts = [] @@ -319,7 +317,7 @@ def kmeans_algo(distances: List[float], **kwargs) -> List[int]: shots.append(0) for i in range(0, len(clt.labels_)): if big_center == clt.labels_[i]: - shots.append((i + 2)) + shots.append((i + 1)) if shots[-1] < num_frames: shots.append(num_frames) else: @@ -348,27 +346,26 @@ def thre_algo(distances: List[float], **kwargs) -> List[int]: def motion_algo(distances: List[float], **kwargs) -> List[int]: import peakutils - arg_dict = { - 'threshold': 0.6, - 'min_dist': 10, - 'motion_step': 15 - } - arg_dict.update(kwargs) + + threshold = kwargs.get('threshold', 0.6) + min_dist = kwargs.get('min_dist', 10) + motion_step = kwargs.get('motion_step', 15) + neigh_avg = kwargs.get('neigh_avg', 2) shots = [] - num_frames = len(distances) + 1 - p = peakutils.indexes(np.array(distances).astype('float32'), thres=arg_dict['threshold'], min_dist=arg_dict['min_dist']) if len(distances) else [] + num_frames = len(distances) + 2 * neigh_avg + 1 + p = peakutils.indexes(np.array(distances).astype('float32'), thres=threshold, min_dist=min_dist) if len(distances) else [] if len(p) == 0: return [0, num_frames] shots.append(0) - shots.append(p[0] + 2) + shots.append(p[0] + neigh_avg + 1) for i in range(1, len(p)): # We check that the peak is not due to a motion in the image - valid_dist = arg_dict['motion_step'] or not check_motion(distances[p[i]-arg_dict['motion_step']:p[i]], distances[p[i]]) + valid_dist = not motion_step or not check_motion(distances[p[i]-motion_step:p[i]], distances[p[i]]) if valid_dist: - shots.append(p[i] + 2) - if shots[-1] < num_frames - arg_dict['min_dist']: + shots.append(p[i] + neigh_avg + 1) + if shots[-1] < num_frames - min_dist: shots.append(num_frames) elif shots[-1] > num_frames: shots[-1] = num_frames diff --git a/gnes/preprocessor/video/shotdetect.py b/gnes/preprocessor/video/shotdetect.py index 1fd6bec4..01b30e41 100644 --- a/gnes/preprocessor/video/shotdetect.py +++ b/gnes/preprocessor/video/shotdetect.py @@ -57,14 +57,15 @@ def detect_shots(self, frames: 'np.ndarray') -> List[List['np.ndarray']]: # compute distances between frames if self.distance_metric == 'edge_change_ration': - dists = compare_ecr(descriptors) + dists = compare_ecr(descriptors, **self._detector_kwargs) else: dists = [ compare_descriptor(pair[0], pair[1], self.distance_metric) for pair in zip(descriptors[:-1], descriptors[1:]) ] + self._detector_kwargs['neigh_avg'] = 0 - shot_bounds = detect_peak_boundary(dists, self.detect_method) + shot_bounds = detect_peak_boundary(dists, self.detect_method, **self._detector_kwargs) shots = [] for ci in range(0, len(shot_bounds) - 1):