diff --git a/gnes/preprocessor/helper.py b/gnes/preprocessor/helper.py index 373d3ac7..71cbf1d8 100644 --- a/gnes/preprocessor/helper.py +++ b/gnes/preprocessor/helper.py @@ -354,20 +354,21 @@ def motion_algo(distances: List[float], **kwargs) -> List[int]: 'motion_step': 15 } arg_dict.update(kwargs) + neigh_avg = kwargs.get('neigh_avg', 2) shots = [] - num_frames = len(distances) + 2 * 2 + 1 + num_frames = len(distances) + 2 * neigh_avg + 1 p = peakutils.indexes(np.array(distances).astype('float32'), thres=arg_dict['threshold'], min_dist=arg_dict['min_dist']) if len(distances) else [] if len(p) == 0: return [0, num_frames] shots.append(0) - shots.append(p[0] + 2 + 1) + shots.append(p[0] + neigh_avg + 1) for i in range(1, len(p)): # We check that the peak is not due to a motion in the image valid_dist = not arg_dict['motion_step'] or not check_motion(distances[p[i]-arg_dict['motion_step']:p[i]], distances[p[i]]) if valid_dist: - shots.append(p[i] + 2 + 1) + shots.append(p[i] + neigh_avg + 1) if shots[-1] < num_frames - arg_dict['min_dist']: shots.append(num_frames) elif shots[-1] > num_frames: diff --git a/gnes/preprocessor/video/shotdetect.py b/gnes/preprocessor/video/shotdetect.py index 1fd6bec4..1da422f4 100644 --- a/gnes/preprocessor/video/shotdetect.py +++ b/gnes/preprocessor/video/shotdetect.py @@ -63,8 +63,9 @@ def detect_shots(self, frames: 'np.ndarray') -> List[List['np.ndarray']]: compare_descriptor(pair[0], pair[1], self.distance_metric) for pair in zip(descriptors[:-1], descriptors[1:]) ] + self._detector_kwargs['neigh_avg'] = 0 - shot_bounds = detect_peak_boundary(dists, self.detect_method) + shot_bounds = detect_peak_boundary(dists, self.detect_method, **self._detector_kwargs) shots = [] for ci in range(0, len(shot_bounds) - 1):