From 2326fe97d17a5a545d50fd8390f61b8a3c728a90 Mon Sep 17 00:00:00 2001 From: raccoonliukai <903896015@qq.com> Date: Tue, 24 Sep 2019 21:15:28 +0800 Subject: [PATCH] feat(preprocessor): add preprocessor for mp4 and gif decode --- gnes/preprocessor/video/video_decode.py | 54 ++++++++++++++++++++++++ tests/test_video_decode_preprocessor.py | 53 +++++++++++++++++++++++ tests/yaml/preprocessor-video_decode.yml | 7 +++ 3 files changed, 114 insertions(+) create mode 100644 gnes/preprocessor/video/video_decode.py create mode 100644 tests/test_video_decode_preprocessor.py create mode 100644 tests/yaml/preprocessor-video_decode.yml diff --git a/gnes/preprocessor/video/video_decode.py b/gnes/preprocessor/video/video_decode.py new file mode 100644 index 00000000..c3d3855f --- /dev/null +++ b/gnes/preprocessor/video/video_decode.py @@ -0,0 +1,54 @@ +# Tencent is pleased to support the open source community by making GNES available. +# +# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import numpy as np
from typing import List, Optional

from gnes.preprocessor.base import BaseVideoPreprocessor
from gnes.proto import gnes_pb2, array2blob
from gnes.preprocessor.io_utils import video


class VideoDecodePreprocessor(BaseVideoPreprocessor):
    """Decode the raw bytes of a video document into one chunk of frames.

    The document's ``raw_bytes`` are handed to ``video.capture_frames`` and
    the returned frame array is stored, as a single blob, in one new chunk
    of the document.

    :param frame_rate: forwarded to ``video.capture_frames`` as ``fps``.
    :param vframes: forwarded to ``video.capture_frames`` as ``vframes``
        (presumably the maximum number of frames to decode, with ``-1``
        meaning "no limit" -- confirm against ``capture_frames``).
    :param scale: forwarded to ``video.capture_frames`` as ``scale``;
        ``None`` leaves the choice to ``capture_frames``. The bundled test
        config uses ``'299:299'``.
    """
    store_args_kwargs = True

    def __init__(self,
                 frame_rate: int = 10,
                 vframes: int = -1,
                 scale: Optional[str] = None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.frame_rate = frame_rate
        self.vframes = vframes
        self.scale = scale

    def apply(self, doc: 'gnes_pb2.Document') -> None:
        """Decode ``doc.raw_bytes`` and append the frames to ``doc`` as a chunk.

        Nothing is added to ``doc`` (and an error is logged) when
        ``raw_bytes`` is empty or when decoding produced no frames.

        :param doc: the document to process; modified in place.
        """
        super().apply(doc)

        if not doc.raw_bytes:
            self.logger.error('bad document: "raw_bytes" is empty!')
            return

        all_frames = video.capture_frames(
            input_data=doc.raw_bytes,
            scale=self.scale,
            fps=self.frame_rate,
            vframes=self.vframes)

        # Do not emit a chunk with an empty blob: a failed or empty decode
        # would otherwise be passed downstream as if it were valid data.
        if all_frames is None or np.size(all_frames) == 0:
            self.logger.error('bad document: no frames could be decoded from "raw_bytes"!')
            return

        c = doc.chunks.add()
        c.doc_id = doc.doc_id
        c.blob.CopyFrom(array2blob(all_frames))
        c.offset = 0
        c.weight = 1.0


# ---- file: tests/test_video_decode_preprocessor.py ----
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

from gnes.cli.parser import set_preprocessor_parser, _set_client_parser
from gnes.client.base import ZmqClient
from gnes.proto import gnes_pb2, RequestGenerator, blob2array
from gnes.service.base import ServiceManager
from gnes.service.preprocessor import PreprocessorService


class TestVideoDecode(unittest.TestCase):
    """End-to-end check of the video-decode preprocessor behind a service."""

    def setUp(self):
        self.dirname = os.path.dirname(__file__)
        self.yml_path = os.path.join(self.dirname, 'yaml', 'preprocessor-video_decode.yml')
        self.video_path = os.path.join(self.dirname, 'videos')

    def _read_videos(self):
        """Return the bytes of every regular file under ``self.video_path``."""
        contents = []
        # Sorted so the request order does not depend on directory ordering.
        for name in sorted(os.listdir(self.video_path)):
            path = os.path.join(self.video_path, name)
            if not os.path.isfile(path):
                continue
            # Context manager closes each handle; avoids ResourceWarning.
            with open(path, 'rb') as fp:
                contents.append(fp.read())
        return contents

    def test_video_decode_preprocessor(self):
        """Every indexed doc gets at least one chunk of 299x299x3 frames."""
        args = set_preprocessor_parser().parse_args(['--yaml_path', self.yml_path])
        # The client's ports mirror the service's ports.
        c_args = _set_client_parser().parse_args([
            '--port_in', str(args.port_out),
            '--port_out', str(args.port_in)])
        video_bytes = self._read_videos()

        with ServiceManager(PreprocessorService, args), ZmqClient(c_args) as client:
            for req in RequestGenerator.index(video_bytes):
                msg = gnes_pb2.Message()
                msg.request.index.CopyFrom(req.index)
                client.send_message(msg)
                r = client.recv_message()
                for d in r.request.index.docs:
                    self.assertGreater(len(d.chunks), 0)
                    for chunk in d.chunks:
                        shape = blob2array(chunk.blob).shape
                        # Expected frame size comes from scale '299:299' in the yaml.
                        self.assertEqual(shape[1:], (299, 299, 3))


# ---- file: tests/yaml/preprocessor-video_decode.yml ----
# !VideoDecodePreprocessor
# parameters:
#   frame_rate: 5
#   vframes: 1
#   scale: '299:299'
# gnes_config:
#   is_trained: true