diff --git a/gnes/preprocessor/video/video_decode.py b/gnes/preprocessor/video/video_decode.py index d5f68890..7c3432f1 100644 --- a/gnes/preprocessor/video/video_decode.py +++ b/gnes/preprocessor/video/video_decode.py @@ -53,14 +53,16 @@ def apply(self, doc: 'gnes_pb2.Document') -> None: else: self.logger.error('the document "raw_bytes" is empty!') - if self.chunk_spliter == 'frame_split': + if self.chunk_spliter == 'base': for i, frame in enumerate(video_frames): c = doc.chunks.add() c.doc_id = doc.doc_id c.blob.CopyFrom(array2blob(frame)) c.offset = i c.weight = 1.0 - elif self.chunk_spliter == 'none': + elif self.chunk_spliter == 'shot': + raise NotImplementedError + else: chunk = doc.chunks.add() chunk.doc_id = doc.doc_id chunk.blob.CopyFrom(array2blob(video_frames)) diff --git a/tests/test_video_decode_preprocessor.py b/tests/test_video_decode_preprocessor.py index e4a78770..d64e28bf 100644 --- a/tests/test_video_decode_preprocessor.py +++ b/tests/test_video_decode_preprocessor.py @@ -50,4 +50,4 @@ def test_video_decode_preprocessor(self): self.assertGreater(len(d.chunks), 0) for _ in range(len(d.chunks)): shape = blob2array(d.chunks[_].blob).shape - self.assertEqual(shape[1:], (299, 299, 3)) + self.assertEqual(shape, (299, 299, 3)) diff --git a/tests/yaml/preprocessor-video_decode.yml b/tests/yaml/preprocessor-video_decode.yml index 60a5ab29..c8e36f75 100644 --- a/tests/yaml/preprocessor-video_decode.yml +++ b/tests/yaml/preprocessor-video_decode.yml @@ -3,6 +3,6 @@ parameters: frame_rate: 5 vframes: 1 frame_size: '299:299' - chunk_spliter: 'none' + chunk_spliter: 'base' gnes_config: is_trained: true