From 52f87c7fa2d54b25a6b075cf549ce960ed63b59d Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Thu, 1 Aug 2019 14:05:04 +0800 Subject: [PATCH] refactor(base): make pipelineencoder more general and allow pipelinepreprocessor --- gnes/base/__init__.py | 68 +++++++++++++++++++++++++- gnes/encoder/__init__.py | 1 - gnes/encoder/base.py | 69 ++------------------------- gnes/preprocessor/__init__.py | 1 + gnes/preprocessor/base.py | 18 ++++++- gnes/preprocessor/video/ffmpeg.py | 6 +-- gnes/preprocessor/video/shotdetect.py | 8 ++-- tests/test_pipelinepreprocess.py | 45 +++++++++++++++++ 8 files changed, 140 insertions(+), 76 deletions(-) create mode 100644 tests/test_pipelinepreprocess.py diff --git a/gnes/base/__init__.py b/gnes/base/__init__.py index 872201b0..fcef4e89 100644 --- a/gnes/base/__init__.py +++ b/gnes/base/__init__.py @@ -22,7 +22,7 @@ import tempfile import uuid from functools import wraps -from typing import Dict, Any, Union, TextIO, TypeVar, Type +from typing import Dict, Any, Union, TextIO, TypeVar, Type, List, Callable import ruamel.yaml.constructor @@ -355,3 +355,69 @@ def _dump_instance_to_yaml(data): if p: r['gnes_config'] = p return r + + def _copy_from(self, x: 'TrainableBase') -> None: + pass + + +class CompositionalTrainableBase(TrainableBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._component = None # type: List[T] + + @property + def component(self) -> Union[List[T], Dict[str, T]]: + return self._component + + @property + def is_pipeline(self): + return isinstance(self.component, list) + + @component.setter + def component(self, comps: Callable[[], Union[list, dict]]): + if not callable(comps): + raise TypeError('component must be a callable function that returns ' + 'a List[BaseEncoder]') + if not getattr(self, 'init_from_yaml', False): + self._component = comps() + else: + self.logger.info('component is omitted from construction, ' + 'as it is initialized from yaml config') + + def close(self): + 
super().close() + # pipeline + if isinstance(self.component, list): + for be in self.component: + be.close() + # no typology + elif isinstance(self.component, dict): + for be in self.component.values(): + be.close() + elif self.component is None: + pass + else: + raise TypeError('component must be dict or list, received %s' % type(self.component)) + + def _copy_from(self, x: T): + if isinstance(self.component, list): + for be1, be2 in zip(self.component, x.component): + be1._copy_from(be2) + elif isinstance(self.component, dict): + for k, v in self.component.items(): + v._copy_from(x.component[k]) + else: + raise TypeError('component must be dict or list, received %s' % type(self.component)) + + @classmethod + def to_yaml(cls, representer, data): + tmp = super()._dump_instance_to_yaml(data) + tmp['component'] = data.component + return representer.represent_mapping('!' + cls.__name__, tmp) + + @classmethod + def from_yaml(cls, constructor, node): + obj, data, from_dump = super()._get_instance_from_yaml(constructor, node) + if not from_dump and 'component' in data: + obj.component = lambda: data['component'] + return obj diff --git a/gnes/encoder/__init__.py b/gnes/encoder/__init__.py index e77f2411..691623f1 100644 --- a/gnes/encoder/__init__.py +++ b/gnes/encoder/__init__.py @@ -35,7 +35,6 @@ 'BaseBinaryEncoder': 'base', 'BaseTextEncoder': 'base', 'BaseNumericEncoder': 'base', - 'CompositionalEncoder': 'base', 'PipelineEncoder': 'base', 'HashEncoder': 'numeric.hash', 'BasePytorchEncoder': 'image.base', diff --git a/gnes/encoder/base.py b/gnes/encoder/base.py index 964e237e..469a3c1b 100644 --- a/gnes/encoder/base.py +++ b/gnes/encoder/base.py @@ -16,11 +16,11 @@ # pylint: disable=low-comment-ratio -from typing import List, Any, Union, Dict, Callable +from typing import List, Any import numpy as np -from ..base import TrainableBase +from ..base import TrainableBase, CompositionalTrainableBase class BaseEncoder(TrainableBase): @@ -58,70 +58,7 @@ def encode(self, 
data: np.ndarray, *args, **kwargs) -> bytes: return data.tobytes() -class CompositionalEncoder(BaseEncoder): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._component = None # type: List['BaseEncoder'] - - @property - def component(self) -> Union[List['BaseEncoder'], Dict[str, 'BaseEncoder']]: - return self._component - - @property - def is_pipeline(self): - return isinstance(self.component, list) - - @component.setter - def component(self, comps: Callable[[], Union[list, dict]]): - if not callable(comps): - raise TypeError('component must be a callable function that returns ' - 'a List[BaseEncoder]') - if not getattr(self, 'init_from_yaml', False): - self._component = comps() - else: - self.logger.info('component is omitted from construction, ' - 'as it is initialized from yaml config') - - def close(self): - super().close() - # pipeline - if isinstance(self.component, list): - for be in self.component: - be.close() - # no typology - elif isinstance(self.component, dict): - for be in self.component.values(): - be.close() - elif self.component is None: - pass - else: - raise TypeError('component must be dict or list, received %s' % type(self.component)) - - def _copy_from(self, x: 'CompositionalEncoder'): - if isinstance(self.component, list): - for be1, be2 in zip(self.component, x.component): - be1._copy_from(be2) - elif isinstance(self.component, dict): - for k, v in self.component.items(): - v._copy_from(x.component[k]) - else: - raise TypeError('component must be dict or list, received %s' % type(self.component)) - - @classmethod - def to_yaml(cls, representer, data): - tmp = super()._dump_instance_to_yaml(data) - tmp['component'] = data.component - return representer.represent_mapping('!' 
+ cls.__name__, tmp) - - @classmethod - def from_yaml(cls, constructor, node): - obj, data, from_dump = super()._get_instance_from_yaml(constructor, node) - if not from_dump and 'component' in data: - obj.component = lambda: data['component'] - return obj - - -class PipelineEncoder(CompositionalEncoder): +class PipelineEncoder(CompositionalTrainableBase): def encode(self, data: Any, *args, **kwargs) -> Any: if not self.component: raise NotImplementedError diff --git a/gnes/preprocessor/__init__.py b/gnes/preprocessor/__init__.py index 6764e9b0..63222041 100644 --- a/gnes/preprocessor/__init__.py +++ b/gnes/preprocessor/__init__.py @@ -20,6 +20,7 @@ _cls2file_map = { 'BasePreprocessor': 'base', + 'PipelinePreprocessor': 'base', 'TextPreprocessor': 'text.simple', 'BaseImagePreprocessor': 'image.base', 'BaseTextPreprocessor': 'text.base', diff --git a/gnes/preprocessor/base.py b/gnes/preprocessor/base.py index f2d0e61d..356c728e 100644 --- a/gnes/preprocessor/base.py +++ b/gnes/preprocessor/base.py @@ -21,7 +21,7 @@ import numpy as np from PIL import Image -from ..base import TrainableBase +from ..base import TrainableBase, CompositionalTrainableBase from ..proto import gnes_pb2, array2blob @@ -38,6 +38,22 @@ def apply(self, doc: 'gnes_pb2.Document') -> None: doc.doc_type = self.doc_type +class PipelinePreprocessor(CompositionalTrainableBase): + def apply(self, doc: 'gnes_pb2.Document') -> None: + if not self.component: + raise NotImplementedError + for be in self.component: + be.apply(doc) + + def train(self, data, *args, **kwargs): + if not self.component: + raise NotImplementedError + for idx, be in enumerate(self.component): + be.train(data, *args, **kwargs) + if idx + 1 < len(self.component): + data = be.apply(data, *args, **kwargs) + + class BaseUnaryPreprocessor(BasePreprocessor): def __init__(self, doc_type: int, *args, **kwargs): diff --git a/gnes/preprocessor/video/ffmpeg.py b/gnes/preprocessor/video/ffmpeg.py index c21439b9..08e17c58 100644 --- 
a/gnes/preprocessor/video/ffmpeg.py +++ b/gnes/preprocessor/video/ffmpeg.py @@ -25,7 +25,7 @@ class FFmpegPreprocessor(BaseVideoPreprocessor): def __init__(self, - frame_size: str = "192*168", + frame_size: str = '192*168', duplicate_rm: bool = True, use_phash_weight: bool = False, phash_thresh: int = 5, @@ -48,8 +48,8 @@ def apply(self, doc: 'gnes_pb2.Document') -> None: frames = get_video_frames( doc.raw_bytes, s=self.frame_size, - vsync=self._ffmpeg_kwargs.get("vsync", "vfr"), - vf=self._ffmpeg_kwargs.get("vf", "select=eq(pict_type\\,I)")) + vsync=self._ffmpeg_kwargs.get('vsync', 'vfr'), + vf=self._ffmpeg_kwargs.get('vf', 'select=eq(pict_type\\,I)')) # remove dupliated key frames by phash value if self.duplicate_rm: diff --git a/gnes/preprocessor/video/shotdetect.py b/gnes/preprocessor/video/shotdetect.py index 377d8ee5..c1fdca80 100644 --- a/gnes/preprocessor/video/shotdetect.py +++ b/gnes/preprocessor/video/shotdetect.py @@ -26,9 +26,9 @@ class ShotDetectPreprocessor(BaseVideoPreprocessor): store_args_kwargs = True def __init__(self, - frame_size: str = "192*168", - descriptor: str = "block_hsv_histogram", - distance_metric: str = "bhattacharya", + frame_size: str = '192*168', + descriptor: str = 'block_hsv_histogram', + distance_metric: str = 'bhattacharya', *args, **kwargs): super().__init__(*args, **kwargs) @@ -47,7 +47,7 @@ def apply(self, doc: 'gnes_pb2.Document') -> None: frames = get_video_frames( doc.raw_bytes, s=self.frame_size, - vsync="vfr", + vsync='vfr', vf='select=eq(pict_type\\,I)') descriptors = [] diff --git a/tests/test_pipelinepreprocess.py b/tests/test_pipelinepreprocess.py new file mode 100644 index 00000000..b6a11ef3 --- /dev/null +++ b/tests/test_pipelinepreprocess.py @@ -0,0 +1,45 @@ +import os +import unittest + +from gnes.preprocessor.base import BasePreprocessor, PipelinePreprocessor +from gnes.proto import gnes_pb2 + + +class P1(BasePreprocessor): + def apply(self, doc: 'gnes_pb2.Document'): + doc.doc_id += 1 + + +class 
P2(BasePreprocessor):
+    def apply(self, doc: 'gnes_pb2.Document'):
+        doc.doc_id *= 3
+
+
+class TestPipelinePreprocessor(unittest.TestCase):
+    def setUp(self):
+        self.dirname = os.path.dirname(__file__)
+        self.p3_name = 'pipe-p12'
+        self.yml_dump_path = os.path.join(self.dirname, '%s.yml' % self.p3_name)
+        self.bin_dump_path = os.path.join(self.dirname, '%s.bin' % self.p3_name)
+
+    def tearDown(self):
+        if os.path.exists(self.yml_dump_path):
+            os.remove(self.yml_dump_path)
+        if os.path.exists(self.bin_dump_path):
+            os.remove(self.bin_dump_path)
+
+    def test_pipelinepreprocess(self):
+        p3 = PipelinePreprocessor()
+        p3.component = lambda: [P1(), P2()]
+        d = gnes_pb2.Document()
+        d.doc_id = 1
+        p3.apply(d)
+        self.assertEqual(d.doc_id, 6)
+
+        p3.name = self.p3_name
+        p3.dump_yaml()
+        p3.dump()
+
+        p4 = BasePreprocessor.load_yaml(p3.yaml_full_path)
+        p4.apply(d)
+        self.assertEqual(d.doc_id, 21)