diff --git a/gnes/preprocessor/__init__.py b/gnes/preprocessor/__init__.py index 8f07ce88..38f48cb8 100644 --- a/gnes/preprocessor/__init__.py +++ b/gnes/preprocessor/__init__.py @@ -37,7 +37,8 @@ 'BaseAudioPreprocessor': 'base', 'RawChunkPreprocessor': 'base', 'GifChunkPreprocessor': 'video.ffmpeg', - 'VggishPreprocessor': 'audio.vggish_example' + 'VggishPreprocessor': 'audio.vggish_example', + 'VideoDecodePreprocessor': 'video.video_decode' } register_all_class(_cls2file_map, 'preprocessor') diff --git a/gnes/preprocessor/video/video_decode.py b/gnes/preprocessor/video/video_decode.py new file mode 100644 index 00000000..94e8d325 --- /dev/null +++ b/gnes/preprocessor/video/video_decode.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making GNES available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class VideoDecodePreprocessor(BaseVideoPreprocessor):
    """Turn a document's raw video payload into a single chunk of frames.

    The payload is read from the document's ``raw_data`` oneof. Encoded
    bytes are decoded through ``video.capture_frames``; an already decoded
    ``NdArray`` payload is converted with ``blob2array``. All resulting
    frames are stored together in one chunk blob.
    """

    # NOTE(review): presumably tells the framework to keep the constructor
    # arguments for (de)serialisation -- confirm against the base class.
    store_args_kwargs = True

    def __init__(self,
                 frame_rate: int = 10,
                 vframes: int = -1,
                 scale: str = None,
                 *args,
                 **kwargs):
        """Store the decoding options.

        :param frame_rate: forwarded as ``fps`` to ``video.capture_frames``.
        :param vframes: maximum number of frames to keep; values ``<= 0``
            disable truncation of an ``NdArray`` payload. Also forwarded
            unchanged to ``video.capture_frames``.
        :param scale: forwarded as ``scale`` to ``video.capture_frames``
            (the accompanying test config uses the form ``'299:299'``).
        """
        super().__init__(*args, **kwargs)
        self.frame_rate = frame_rate
        self.vframes = vframes
        self.scale = scale

    def apply(self, doc: 'gnes_pb2.Document') -> None:
        """Decode ``doc``'s raw video data and append it as one chunk.

        On success a chunk is added to ``doc.chunks`` carrying all frames
        as a blob, with ``offset`` 0 and ``weight`` 1.0. If no ``raw_data``
        field is set, or no frames were obtained, an error is logged and
        the document is left without a new chunk.

        :param doc: the document to preprocess; modified in place.
        """
        # ``blob2array`` is not part of this module's top-level imports, so
        # it is imported here to keep the NdArray branch from raising
        # NameError.
        from gnes.proto import blob2array

        super().apply(doc)

        # Resolve the populated oneof field once instead of querying twice.
        raw_field = doc.WhichOneof('raw_data')
        if not raw_field:
            # NOTE(review): message kept verbatim, although this branch
            # means that no ``raw_data`` field is set at all.
            self.logger.error('bad document: "raw_bytes" is empty!')
            return

        all_frames = []
        if doc.raw_bytes:
            all_frames = video.capture_frames(
                input_data=doc.raw_bytes,
                scale=self.scale,
                fps=self.frame_rate,
                vframes=self.vframes)
        elif isinstance(getattr(doc, raw_field), gnes_pb2.NdArray):
            all_frames = blob2array(doc.raw_video)
            if self.vframes > 0:
                # Keep only the leading ``vframes`` frames. The original
                # code sliced an undefined name (``video_frames``) here.
                all_frames = all_frames[0:self.vframes, :].copy()

        if len(all_frames) > 0:
            c = doc.chunks.add()
            c.doc_id = doc.doc_id
            c.blob.CopyFrom(array2blob(all_frames))
            c.offset = 0
            c.weight = 1.0
        else:
            self.logger.error('bad document: "raw_bytes" or "raw_video" is empty!')
class TestVideoDecode(unittest.TestCase):
    """End-to-end check of the video-decode preprocessor service."""

    def setUp(self):
        """Resolve the YAML config and the video fixture directory."""
        self.dirname = os.path.dirname(__file__)
        self.yml_path = os.path.join(self.dirname, 'yaml', 'preprocessor-video_decode.yml')
        self.video_path = os.path.join(self.dirname, 'videos')

    def test_video_decode_preprocessor(self):
        """Index every fixture video and verify the shape of each chunk."""
        args = set_preprocessor_parser().parse_args(['--yaml_path', self.yml_path])
        # The client ports mirror the service ports: the client listens on
        # the service's output port and sends to its input port.
        c_args = _set_client_parser().parse_args([
            '--port_in', str(args.port_out),
            '--port_out', str(args.port_in)])

        # Read each fixture through a context manager so the file handle is
        # closed deterministically instead of being leaked.
        video_bytes = []
        for file_name in os.listdir(self.video_path):
            with open(os.path.join(self.video_path, file_name), 'rb') as fp:
                video_bytes.append(fp.read())

        with ServiceManager(PreprocessorService, args), ZmqClient(c_args) as client:
            for req in RequestGenerator.index(video_bytes):
                msg = gnes_pb2.Message()
                msg.request.index.CopyFrom(req.index)
                client.send_message(msg)
                r = client.recv_message()
                for d in r.request.index.docs:
                    self.assertGreater(len(d.chunks), 0)
                    for chunk in d.chunks:
                        # Everything after the leading axis must be
                        # 299x299 with 3 channels.
                        shape = blob2array(chunk.blob).shape
                        self.assertEqual(shape[1:], (299, 299, 3))
+1,7 @@ +!VideoDecodePreprocessor +parameters: + frame_rate: 5 + vframes: 1 + scale: '299:299' +gnes_config: + is_trained: true