Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge pull request #67 from gnes-ai/fix-composer-preprocessor
Browse files Browse the repository at this point in the history
feat(composer): more interaction for gnes board
  • Loading branch information
Han Xiao authored Aug 2, 2019
2 parents 323879a + 823bded commit 888aac8
Show file tree
Hide file tree
Showing 28 changed files with 320 additions and 139 deletions.
74 changes: 70 additions & 4 deletions gnes/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
import tempfile
import uuid
from functools import wraps
from typing import Dict, Any, Union, TextIO, TypeVar, Type
from typing import Dict, Any, Union, TextIO, TypeVar, Type, List, Callable

import ruamel.yaml.constructor

from ..helper import set_logger, profiling, yaml, parse_arg, load_contrib_module

__all__ = ['TrainableBase']
__all__ = ['TrainableBase', 'CompositionalTrainableBase']

T = TypeVar('T', bound='TrainableBase')

Expand Down Expand Up @@ -295,7 +295,7 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
if stop_on_import_error:
raise RuntimeError('Cannot import module, pip install may required') from ex

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
if node.tag in {'!PipelineEncoder', '!CompositionalTrainableBase'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '0'

data = ruamel.yaml.constructor.SafeConstructor.construct_mapping(
Expand Down Expand Up @@ -325,7 +325,7 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
obj.logger.info('initialize %s from a yaml config' % cls.__name__)
cls.init_from_yaml = False

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
if node.tag in {'!PipelineEncoder', '!CompositionalTrainableBase'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1'

return obj, data, load_from_dump
Expand Down Expand Up @@ -355,3 +355,69 @@ def _dump_instance_to_yaml(data):
if p:
r['gnes_config'] = p
return r

def _copy_from(self, x: 'TrainableBase') -> None:
pass


class CompositionalTrainableBase(TrainableBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._component = None # type: List[T]

@property
def component(self) -> Union[List[T], Dict[str, T]]:
return self._component

@property
def is_pipeline(self):
return isinstance(self.component, list)

@component.setter
def component(self, comps: Callable[[], Union[list, dict]]):
if not callable(comps):
raise TypeError('component must be a callable function that returns '
'a List[BaseEncoder]')
if not getattr(self, 'init_from_yaml', False):
self._component = comps()
else:
self.logger.info('component is omitted from construction, '
'as it is initialized from yaml config')

def close(self):
super().close()
# pipeline
if isinstance(self.component, list):
for be in self.component:
be.close()
# no typology
elif isinstance(self.component, dict):
for be in self.component.values():
be.close()
elif self.component is None:
pass
else:
raise TypeError('component must be dict or list, received %s' % type(self.component))

def _copy_from(self, x: T):
if isinstance(self.component, list):
for be1, be2 in zip(self.component, x.component):
be1._copy_from(be2)
elif isinstance(self.component, dict):
for k, v in self.component.items():
v._copy_from(x.component[k])
else:
raise TypeError('component must be dict or list, received %s' % type(self.component))

@classmethod
def to_yaml(cls, representer, data):
tmp = super()._dump_instance_to_yaml(data)
tmp['component'] = data.component
return representer.represent_mapping('!' + cls.__name__, tmp)

@classmethod
def from_yaml(cls, constructor, node):
obj, data, from_dump = super()._get_instance_from_yaml(constructor, node)
if not from_dump and 'component' in data:
obj.component = lambda: data['component']
return obj
6 changes: 3 additions & 3 deletions gnes/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def set_composer_parser(parser=None):
'gnes', '/'.join(('resources', 'config', 'compose', 'default.yml'))),
help='yaml config of the service')
parser.add_argument('--html_path', type=argparse.FileType('w', encoding='utf8'),
default='./gnes-board.html',
help='output path of the HTML file, will contain all possible generations')
parser.add_argument('--shell_path', type=argparse.FileType('w', encoding='utf8'),
help='output path of the shell-based starting script')
Expand Down Expand Up @@ -214,10 +213,11 @@ def set_grpc_frontend_parser(parser=None):
from ..service.base import SocketType
if not parser:
parser = set_base_parser()
_set_client_parser(parser)
set_service_parser(parser)
_set_grpc_parser(parser)
parser.set_defaults(socket_in=SocketType.PULL_BIND,
socket_out=SocketType.PUSH_BIND)
socket_out=SocketType.PUSH_BIND,
read_only=True)
parser.add_argument('--max_concurrency', type=int, default=10,
help='maximum concurrent client allowed')
parser.add_argument('--max_send_size', type=int, default=100,
Expand Down
30 changes: 22 additions & 8 deletions gnes/composer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,24 @@ def build_layers(self) -> List['YamlComposer.Layer']:
last_layer = self._layers[idx - 1]
for l in self._add_router(last_layer, layer):
all_layers.append(copy.deepcopy(l))
# # add frontend
# for l in self._add_router(all_layers[-1], all_layers[0]):
# all_layers.append(copy.deepcopy(l))
all_layers[0] = copy.deepcopy(self._layers[0])

# gRPCfrontend should always on the bind role
assert all_layers[0].is_single_component
assert all_layers[0].components[0]['name'] == 'gRPCFrontend'

if all_layers[0].components[0]['socket_in'] == str(SocketType.SUB_CONNECT):
# change to sub bind
all_layers[0].components[0]['socket_in'] = str(SocketType.SUB_BIND)
for c in all_layers[-1].components:
c['socket_out'] = str(SocketType.PUB_CONNECT)

if all_layers[0].components[0]['socket_in'] == str(SocketType.PULL_CONNECT):
# change to sub bind
all_layers[0].components[0]['socket_in'] = str(SocketType.PULL_BIND)
for c in all_layers[-1].components:
c['socket_out'] = str(SocketType.PUSH_CONNECT)

return all_layers

@staticmethod
Expand Down Expand Up @@ -292,11 +306,11 @@ def build_mermaid(all_layers: List['YamlComposer.Layer'], mermaid_leftright: boo
# if len(last_layer.components) > 1:
# self.mermaid_graph.append('\tend')

style = ['classDef gRPCFrontendCLS fill:#FFAA04,stroke:#277CE8,stroke-width:1px;',
'classDef EncoderCLS fill:#27E1E8,stroke:#277CE8,stroke-width:1px;',
'classDef IndexerCLS fill:#27E1E8,stroke:#277CE8,stroke-width:1px;',
'classDef RouterCLS fill:#2BFFCB,stroke:#277CE8,stroke-width:1px;',
'classDef PreprocessorCLS fill:#27E1E8,stroke:#277CE8,stroke-width:1px;']
style = ['classDef gRPCFrontendCLS fill:#FFE0E0,stroke:#FFE0E0,stroke-width:1px;',
'classDef EncoderCLS fill:#FFDAAF,stroke:#FFDAAF,stroke-width:1px;',
'classDef IndexerCLS fill:#FFFBC1,stroke:#FFFBC1,stroke-width:1px;',
'classDef RouterCLS fill:#C9E8D2,stroke:#C9E8D2,stroke-width:1px;',
'classDef PreprocessorCLS fill:#CEEEEF,stroke:#CEEEEF,stroke-width:1px;']
class_def = ['class %s %s;' % (','.join(v), k) for k, v in cls_dict.items()]
mermaid_str = '\n'.join(
['graph %s' % ('LR' if mermaid_leftright else 'TD')] + mermaid_graph + style + class_def)
Expand Down
19 changes: 10 additions & 9 deletions gnes/composer/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import io

from .base import YamlComposer
from ..cli.parser import set_composer_parser
Expand All @@ -37,23 +36,25 @@ def _create_flask_app(self):

# support up to 10 concurrent HTTP requests
app = Flask(__name__)
args = set_composer_parser().parse_args([])
default_html = YamlComposer(args).build_all()['html']

@app.route('/', methods=['GET'])
def _get_homepage():
return YamlComposer(set_composer_parser().parse_args([])).build_all()['html']
return default_html

@app.route('/generate', methods=['POST'])
def _regenerate():
data = request.form if request.form else request.json
if not data or 'yaml-config' not in data:
return '<h1>Bad POST request</h1> your POST request does not contain "yaml-config" field!', 406
f = tempfile.NamedTemporaryFile('w', delete=False).name
with open(f, 'w', encoding='utf8') as fp:
fp.write(data['yaml-config'])
try:
return YamlComposer(set_composer_parser().parse_args([
'--yaml_path', f
])).build_all()['html']
args.yaml_path = io.StringIO(data['yaml-config'])
if data.get('mermaid_direction', 'top-down').lower() == 'left-right':
args.mermaid_leftright = True
if 'docker-image' in data:
args.docker_img = data['docker-image']
return YamlComposer(args).build_all()['html']
except Exception as e:
self.logger.error(e)
return '<h1>Bad YAML input</h1> please kindly check the format, indent and content of your YAML file!', 400
Expand Down
1 change: 0 additions & 1 deletion gnes/encoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
'BaseTextEncoder': 'base',
'BaseVideoEncoder': 'base',
'BaseNumericEncoder': 'base',
'CompositionalEncoder': 'base',
'PipelineEncoder': 'base',
'HashEncoder': 'numeric.hash',
'BasePytorchEncoder': 'image.base',
Expand Down
69 changes: 3 additions & 66 deletions gnes/encoder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
# pylint: disable=low-comment-ratio


from typing import List, Any, Union, Dict, Callable
from typing import List, Any

import numpy as np

from ..base import TrainableBase
from ..base import TrainableBase, CompositionalTrainableBase


class BaseEncoder(TrainableBase):
Expand Down Expand Up @@ -64,70 +64,7 @@ def encode(self, data: np.ndarray, *args, **kwargs) -> bytes:
return data.tobytes()


class CompositionalEncoder(BaseEncoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._component = None # type: List['BaseEncoder']

@property
def component(self) -> Union[List['BaseEncoder'], Dict[str, 'BaseEncoder']]:
return self._component

@property
def is_pipeline(self):
return isinstance(self.component, list)

@component.setter
def component(self, comps: Callable[[], Union[list, dict]]):
if not callable(comps):
raise TypeError('component must be a callable function that returns '
'a List[BaseEncoder]')
if not getattr(self, 'init_from_yaml', False):
self._component = comps()
else:
self.logger.info('component is omitted from construction, '
'as it is initialized from yaml config')

def close(self):
super().close()
# pipeline
if isinstance(self.component, list):
for be in self.component:
be.close()
# no typology
elif isinstance(self.component, dict):
for be in self.component.values():
be.close()
elif self.component is None:
pass
else:
raise TypeError('component must be dict or list, received %s' % type(self.component))

def _copy_from(self, x: 'CompositionalEncoder'):
if isinstance(self.component, list):
for be1, be2 in zip(self.component, x.component):
be1._copy_from(be2)
elif isinstance(self.component, dict):
for k, v in self.component.items():
v._copy_from(x.component[k])
else:
raise TypeError('component must be dict or list, received %s' % type(self.component))

@classmethod
def to_yaml(cls, representer, data):
tmp = super()._dump_instance_to_yaml(data)
tmp['component'] = data.component
return representer.represent_mapping('!' + cls.__name__, tmp)

@classmethod
def from_yaml(cls, constructor, node):
obj, data, from_dump = super()._get_instance_from_yaml(constructor, node)
if not from_dump and 'component' in data:
obj.component = lambda: data['component']
return obj


class PipelineEncoder(CompositionalEncoder):
class PipelineEncoder(CompositionalTrainableBase):
def encode(self, data: Any, *args, **kwargs) -> Any:
if not self.component:
raise NotImplementedError
Expand Down
4 changes: 2 additions & 2 deletions gnes/encoder/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import numpy as np

from ..base import CompositionalEncoder, BaseTextEncoder
from ..base import CompositionalTrainableBase, BaseTextEncoder
from ...helper import batching


Expand All @@ -45,7 +45,7 @@ def close(self):
self.bc_encoder.close()


class BertEncoderWithServer(CompositionalEncoder):
class BertEncoderWithServer(CompositionalTrainableBase):
def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
return self.component['bert_client'].encode(text, *args, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion gnes/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ class ColoredFormatter(Formatter):
'WARNING': dict(color='red', on_color='on_yellow'), # yellow
'ERROR': dict(color='white', on_color='on_red'), # 31 for red
'CRITICAL': dict(color='red', on_color='on_white'), # white on red bg
'SUCCESS': dict(color='white', on_color='on_green'), # green
}

PREFIX = '\033['
Expand Down Expand Up @@ -535,4 +536,3 @@ def load_contrib_module():
profiling = time_profile

yaml = _get_yaml()

5 changes: 2 additions & 3 deletions gnes/indexer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

import numpy as np

from ..base import TrainableBase
from ..encoder.base import CompositionalEncoder
from ..base import TrainableBase, CompositionalTrainableBase


class BaseIndexer(TrainableBase):
Expand Down Expand Up @@ -71,7 +70,7 @@ def normalize_score(self, *args, **kwargs):
pass


class JointIndexer(CompositionalEncoder):
class JointIndexer(CompositionalTrainableBase):

@property
def component(self):
Expand Down
2 changes: 2 additions & 0 deletions gnes/preprocessor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

_cls2file_map = {
'BasePreprocessor': 'base',
'PipelinePreprocessor': 'base',
'TextPreprocessor': 'text.simple',
'BaseImagePreprocessor': 'image.base',
'BaseTextPreprocessor': 'text.base',
Expand All @@ -28,6 +29,7 @@
'WeightedSlidingPreprocessor': 'image.sliding_window',
'SegmentPreprocessor': 'image.segmentation',
'BaseUnaryPreprocessor': 'base',
'ResizeChunkPreprocessor': 'image.resize',
'BaseVideoPreprocessor': 'video.base',
'FFmpegPreprocessor': 'video.ffmpeg',
'FFmpegVideoSegmentor': 'video.ffmpeg',
Expand Down
Loading

0 comments on commit 888aac8

Please sign in to comment.