diff --git a/gnes/preprocessor/video/ffmpeg.py b/gnes/preprocessor/video/ffmpeg.py index c54e13bf..66b19e59 100644 --- a/gnes/preprocessor/video/ffmpeg.py +++ b/gnes/preprocessor/video/ffmpeg.py @@ -112,6 +112,7 @@ def __init__(self, super().__init__(*args, **kwargs) self.segment_method = segment_method self.segment_interval = segment_interval + self.segment_num = segment_num self._ffmpeg_kwargs = kwargs def apply(self, doc: 'gnes_pb2.Document') -> None: diff --git a/tests/test_video_preprocessor.py b/tests/test_video_preprocessor.py index 9d92e986..22bf525c 100644 --- a/tests/test_video_preprocessor.py +++ b/tests/test_video_preprocessor.py @@ -15,6 +15,8 @@ def setUp(self): self.yml_path_2 = os.path.join(self.dirname, 'yaml', 'preprocessor-ffmpeg2.yml') self.yml_path_3 = os.path.join(self.dirname, 'yaml', 'preprocessor-ffmpeg3.yml') self.video_path = os.path.join(self.dirname, 'videos') + self.video_bytes = [open(os.path.join(self.video_path, _), 'rb').read() + for _ in os.listdir(self.video_path)] def test_video_preprocessor_service_empty(self): args = set_preprocessor_service_parser().parse_args([ @@ -28,23 +30,13 @@ def test_video_preprocessor_service_realdata(self): '--yaml_path', self.yml_path ]) - args_2 = set_preprocessor_service_parser().parse_args([ - '--yaml_path', self.yml_path_2 - ]) - - args_3 = set_preprocessor_service_parser().parse_args([ - '--yaml_path', self.yml_path_3 - ]) - c_args = _set_client_parser().parse_args([ '--port_in', str(args.port_out), '--port_out', str(args.port_in) ]) - video_bytes = [open(os.path.join(self.video_path, _), 'rb').read() - for _ in os.listdir(self.video_path)] with PreprocessorService(args), ZmqClient(c_args) as client: - for req in RequestGenerator.index(video_bytes): + for req in RequestGenerator.index(self.video_bytes): msg = gnes_pb2.Message() msg.request.index.CopyFrom(req.index) client.send_message(msg) @@ -55,8 +47,17 @@ def test_video_preprocessor_service_realdata(self): shape = blob2array(d.chunks[_].blob).shape self.assertEqual(shape, (168, 192, 3)) - with PreprocessorService(args_2), ZmqClient(c_args) as client: - for req in RequestGenerator.index(video_bytes): + def test_video_cut_by_frame(self): + args = set_preprocessor_service_parser().parse_args([ + '--yaml_path', self.yml_path_2, + ]) + c_args = _set_client_parser().parse_args([ + '--port_in', str(args.port_out), + '--port_out', str(args.port_in) + ]) + + with PreprocessorService(args), ZmqClient(c_args) as client: + for req in RequestGenerator.index(self.video_bytes): msg = gnes_pb2.Message() msg.request.index.CopyFrom(req.index) client.send_message(msg) @@ -69,8 +70,17 @@ def test_video_preprocessor_service_realdata(self): shape = blob2array(d.chunks[-1].blob).shape self.assertLessEqual(shape[0], 30) - with PreprocessorService(args_2), ZmqClient(c_args) as client: - for req in RequestGenerator.index(video_bytes): + def test_video_cut_by_num(self): + args = set_preprocessor_service_parser().parse_args([ + '--yaml_path', self.yml_path_3 + ]) + c_args = _set_client_parser().parse_args([ + '--port_in', str(args.port_out), + '--port_out', str(args.port_in) + ]) + + with PreprocessorService(args), ZmqClient(c_args) as client: + for req in RequestGenerator.index(self.video_bytes): msg = gnes_pb2.Message() msg.request.index.CopyFrom(req.index) client.send_message(msg)