This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

feat(contrib): no need to give module name in advance #123

Merged 6 commits on Aug 19, 2019
14 changes: 7 additions & 7 deletions .drone-cd.yml
@@ -46,7 +46,7 @@ steps:
from_secret: BOT_URL
commands:
- export MSG_LINK=$DRONE_REPO_LINK
- export MSG_TITLE="✅🎁 All images is successfully updated!"
- export MSG_TITLE="✅🎁 All images are successfully delivered!"
- export MSG_CONTENT=""
- ./shell/push-wechatwork.sh

@@ -90,9 +90,9 @@ steps:
BOT_URL:
from_secret: BOT_URL
commands:
- export MSG_LINK=${DRONE_BUILD_LINK}
- export MSG_TITLE="⌛🗳 Start to build docker image for new release \`$DRONE_SOURCE_BRANCH\`(\`${DRONE_BUILD_NUMBER}\`)"
- export MSG_CONTENT="click the link below to see the status"
- export MSG_LINK=$DRONE_BUILD_LINK
- export MSG_TITLE="⌛🗳 Start to build docker image for new release \`$DRONE_SOURCE_BRANCH\`(\`${DRONE_TAG}\`)"
- export MSG_CONTENT="[tag link](https://github.com/gnes-ai/gnes/tree/${DRONE_TAG}), click the link below to see the status"
- ./shell/push-wechatwork.sh

- name: build and push docker images
@@ -123,8 +123,8 @@ steps:
BOT_URL:
from_secret: BOT_URL
commands:
- export MSG_LINK=$DRONE_REPO_LINK
- export MSG_TITLE="✅🎁 All images is successfully updated!"
- export MSG_LINK="https://github.com/gnes-ai/gnes/tree/${DRONE_TAG}"
- export MSG_TITLE="✅🎁 All images are successfully delivered!"
- export MSG_CONTENT=""
- ./shell/push-wechatwork.sh

@@ -136,7 +136,7 @@ steps:
commands:
- export MSG_LINK=$DRONE_BUILD_LINK
- export MSG_TITLE="❌🎁 fail to build docker image!"
- export MSG_CONTENT="please inform [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL) to modify and fix [\`$DRONE_SOURCE_BRANCH\`]($DRONE_COMMIT_LINK). click the link below to see the details."
- export MSG_CONTENT="[tag link](https://github.com/gnes-ai/gnes/tree/${DRONE_TAG}) please inform [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL) to modify and fix [\`$DRONE_SOURCE_BRANCH\`]($DRONE_COMMIT_LINK). click the link below to see the details."
- ./shell/push-wechatwork.sh
when:
status:
8 changes: 4 additions & 4 deletions .drone.yml
@@ -14,7 +14,7 @@ steps:
commands:
- export MSG_LINK=${DRONE_BUILD_LINK}
- export MSG_TITLE="⌛🏗 Start a CI pipeline \`$DRONE_SOURCE_BRANCH\`(\`${DRONE_BUILD_NUMBER}\`)"
- export MSG_CONTENT="submit by [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL), click the link below to see the status"
- export MSG_CONTENT="submit by [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL) [PR link](https://github.com/gnes-ai/gnes/pull/${DRONE_PULL_REQUEST}), click the link below to see the status"
- ./shell/push-wechatwork.sh

- name: check commit style
@@ -65,7 +65,7 @@ steps:
BOT_URL:
from_secret: BOT_URL
commands:
- export MSG_LINK=$DRONE_REPO_LINK
- export MSG_LINK="https://github.com/gnes-ai/gnes/pull/${DRONE_PULL_REQUEST}"
- export MSG_TITLE="✅😃 All tests passed, good job! \`$DRONE_SOURCE_BRANCH\`(\`${DRONE_BUILD_NUMBER}\`)"
- export MSG_CONTENT="the branch \`$DRONE_SOURCE_BRANCH\` submit by [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL) is ready to merge to master"
- ./shell/push-wechatwork.sh
@@ -77,9 +77,9 @@ steps:
BOT_URL:
from_secret: BOT_URL
commands:
- export MSG_LINK=$DRONE_BUILD_LINK
- export MSG_LINK=${DRONE_BUILD_LINK}
- export MSG_TITLE="❌😥 CI pipeline \`$DRONE_SOURCE_BRANCH\`(\`${DRONE_BUILD_NUMBER}\`) is failed!"
- export MSG_CONTENT="please inform [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL) to modify and fix [\`$DRONE_SOURCE_BRANCH\`]($DRONE_COMMIT_LINK). click the link below to see the details."
- export MSG_CONTENT="[PR link](https://github.com/gnes-ai/gnes/pull/${DRONE_PULL_REQUEST}) please inform [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL) to modify and fix [\`$DRONE_SOURCE_BRANCH\`]($DRONE_COMMIT_LINK). click the link below to see the details."
- ./shell/push-wechatwork.sh
when:
status:
10 changes: 5 additions & 5 deletions README.md
@@ -261,7 +261,7 @@ Now let's see what the YAML config says. First impression, it is pretty intuitiv

```yaml
!TextPreprocessor
parameter:
parameters:
start_doc_id: 0
random_doc_id: True
deliminator: "[.!?]+"
@@ -277,19 +277,19 @@ gnes_config:
!PipelineEncoder
components:
- !GPT2Encoder
parameter:
parameters:
model_dir: $GPT2_CI_MODEL
pooling_stragy: REDUCE_MEAN
gnes_config:
is_trained: true
- !PCALocalEncoder
parameter:
parameters:
output_dim: 32
num_locals: 8
gnes_config:
batch_size: 2048
- !PQEncoder
parameter:
parameters:
cluster_per_byte: 8
num_bytes: 8
gnes_config:
@@ -304,7 +304,7 @@ gnes_config:

```yaml
!BIndexer
parameter:
parameters:
num_bytes: 8
data_path: /out_data/idx.binary
gnes_config:
10 changes: 5 additions & 5 deletions docs/chapter/yaml-config.md
@@ -24,12 +24,12 @@ Together they define the behavior of a GNES system. Roughly speaking,
!PipelineEncoder
components:
- !Word2VecEncoder
parameter:
parameters:
model_dir: /ext_data/sgns.wiki.bigram-char.refine
property:
is_trained: true
- !PCALocalEncoder
parameter:
parameters:
output_dim: 200
num_locals: 10
property:
@@ -42,18 +42,18 @@ One can also append extra component to this pipeline, e.g. adding quantization.
!PipelineEncoder
components:
- !Word2VecEncoder
parameter:
parameters:
model_dir: /ext_data/sgns.wiki.bigram-char.refine
property:
is_trained: true
- !PCALocalEncoder
parameter:
parameters:
output_dim: 200
num_locals: 10
property:
batch_size: 2048
- !PQEncoder
parameter:
parameters:
cluster_per_byte: 20
num_bytes: 10
```
35 changes: 17 additions & 18 deletions gnes/base/__init__.py
@@ -88,6 +88,7 @@ def __call__(cls, *args, **kwargs):
for k, v in TrainableType.default_gnes_config.items():
if k in gnes_config:
v = gnes_config[k]
v = _expand_env_var(v)
if not hasattr(obj, k):
setattr(obj, k, v)

@@ -172,9 +173,9 @@ def _post_init_wrapper(self):
_name = '%s-%s' % (self.__class__.__name__, _id)
self.logger.warning(
'this object is not named ("- gnes_config: - name" is not found in YAML config), '
'i will call it as "%s". '
'naming the object is important especially when you need to '
'serialize/deserialize/store/load the object.' % _name)
'i will call it "%s". '
'naming the object is important as it provides a unique identifier when '
'serializing/deserializing this object.' % _name)
setattr(self, 'name', _name)

_before = set(list(self.__dict__.keys()))
@@ -199,8 +200,6 @@ def yaml_full_path(self):
def __getstate__(self):
d = dict(self.__dict__)
del d['logger']
if '_file_lock' in d:
del d['_file_lock']
for k in self._post_init_vars:
del d[k]
return d
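The `__getstate__` hunk above drops the special case for `_file_lock` and relies on `_post_init_vars` alone to exclude attributes from the serialized state. A minimal sketch of that pattern follows. The class name `Component`, the `_lock` attribute and the `__setstate__` method are assumptions made for illustration and are not taken from the GNES source.

```python
import logging
import threading


class Component:
    """Hypothetical class illustrating the state-filtering pattern."""

    def __init__(self, name: str):
        self.name = name
        self.logger = logging.getLogger(name)
        self._post_init_vars = set()
        self._post_init_wrapper()

    def post_init(self):
        # Anything created here (for example a lock) is treated as
        # runtime-only state and is rebuilt after loading.
        self._lock = threading.Lock()

    def _post_init_wrapper(self):
        # Record which attributes post_init() adds, so they can be
        # excluded from the serialized state later.
        before = set(self.__dict__)
        self.post_init()
        self._post_init_vars = set(self.__dict__) - before

    def __getstate__(self):
        d = dict(self.__dict__)
        del d['logger']
        for k in self._post_init_vars:
            del d[k]
        return d

    def __setstate__(self, d):
        # Assumed counterpart: restore the plain state, then rebuild
        # the logger and the runtime-only attributes.
        self.__dict__.update(d)
        self.logger = logging.getLogger(self.name)
        self._post_init_wrapper()
```

With this arrangement, an attribute created in `post_init` needs no dedicated `del` in `__getstate__`, which is presumably why the explicit `_file_lock` branch could be removed.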
@@ -309,15 +308,15 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
cls.init_from_yaml = True

if cls.store_args_kwargs:
p = data.get('parameter', {}) # type: Dict[str, Any]
p = data.get('parameters', {}) # type: Dict[str, Any]
a = p.pop('args') if 'args' in p else ()
k = p.pop('kwargs') if 'kwargs' in p else {}
# maybe there are some hanging kwargs in "parameter"
tmp_a = (cls._convert_env_var(v) for v in a)
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in {**k, **p}.items()}
# maybe there are some hanging kwargs in "parameters"
tmp_a = (_expand_env_var(v) for v in a)
tmp_p = {kk: _expand_env_var(vv) for kk, vv in {**k, **p}.items()}
obj = cls(*tmp_a, **tmp_p, gnes_config=data.get('gnes_config', {}))
else:
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in data.get('parameter', {}).items()}
tmp_p = {kk: _expand_env_var(vv) for kk, vv in data.get('parameters', {}).items()}
obj = cls(**tmp_p, gnes_config=data.get('gnes_config', {}))

obj.logger.info('initialize %s from a yaml config' % cls.__name__)
@@ -335,21 +334,14 @@ def _get_dump_path_from_config(gnes_config: Dict):
if os.path.exists(dump_path):
return dump_path

@staticmethod
def _convert_env_var(v):
if isinstance(v, str):
return parse_arg(os.path.expandvars(v))
else:
return v

@staticmethod
def _dump_instance_to_yaml(data):
# note: we only dump non-default property for the sake of clarity
p = {k: getattr(data, k) for k, v in TrainableType.default_gnes_config.items() if getattr(data, k) != v}
a = {k: v for k, v in data._init_kwargs_dict.items() if k not in TrainableType.default_gnes_config}
r = {}
if a:
r['parameter'] = a
r['parameters'] = a
if p:
r['gnes_config'] = p
return r
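The dump logic above writes only non-default `gnes_config` values and places the remaining constructor arguments under the renamed `parameters` key. The following standalone sketch reproduces that behavior on plain dictionaries. `DEFAULT_CONFIG` and `dump_instance` are hypothetical names, and the default values shown are assumptions, not the real contents of `TrainableType.default_gnes_config`.

```python
from typing import Any, Dict

# Hypothetical defaults; the real default_gnes_config may differ.
DEFAULT_CONFIG = {'is_trained': False, 'batch_size': None, 'work_dir': None, 'name': None}


def dump_instance(init_kwargs: Dict[str, Any], current_config: Dict[str, Any]) -> Dict[str, Any]:
    # Keep only config values that differ from their defaults.
    p = {k: current_config[k] for k, v in DEFAULT_CONFIG.items()
         if current_config.get(k, v) != v}
    # Constructor arguments that are not config keys go under "parameters".
    a = {k: v for k, v in init_kwargs.items() if k not in DEFAULT_CONFIG}
    r = {}
    if a:
        r['parameters'] = a
    if p:
        r['gnes_config'] = p
    return r
```

For instance, `dump_instance({'output_dim': 32, 'num_locals': 8}, {'batch_size': 2048})` yields a mapping with `parameters` holding the two constructor arguments and `gnes_config` holding only `batch_size`.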
@@ -419,3 +411,10 @@ def from_yaml(cls, constructor, node):
if not from_dump and 'components' in data:
obj.components = lambda: data['components']
return obj


def _expand_env_var(v: str) -> str:
    if isinstance(v, str):
        return parse_arg(os.path.expandvars(v))
    else:
        return v
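The helper above expands `$VAR` references in string values and passes non-string values through unchanged, so a config entry such as `model_dir: $GPT2_CI_MODEL` is resolved when the object is built. Below is a minimal standalone sketch of the same idea. `parse_value` is a hypothetical stand-in for `gnes.helper.parse_arg`, whose exact conversion rules are not shown in this diff; the int/float/bool casting here is an assumption.

```python
import os
from typing import Any


def parse_value(s: str) -> Any:
    """Hypothetical stand-in for parse_arg: try int, float, then bool."""
    for cast in (int, float):
        try:
            return cast(s)
        except ValueError:
            pass
    lowered = s.lower()
    if lowered in ('true', 'false'):
        return lowered == 'true'
    return s


def expand_env_var(v: Any) -> Any:
    # Only strings can contain $VAR references; everything else is
    # returned untouched.
    if isinstance(v, str):
        return parse_value(os.path.expandvars(v))
    return v
```

Note that `os.path.expandvars` leaves unknown variables in place rather than raising, so a misspelled variable name ends up in the config as a literal `$NAME` string.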
6 changes: 3 additions & 3 deletions gnes/composer/base.py
@@ -410,7 +410,7 @@ def rule5():
# a shortcut fn: based on c3(): (N)-2-(N) with pub sub connection
rule3()
router_layers[0].components[0]['socket_out'] = str(SocketType.PUB_BIND)
router_layers[0].components[0]['yaml_path'] = '"!PublishRouter {parameter: {num_part: %d}}"' \
router_layers[0].components[0]['yaml_path'] = '"!PublishRouter {parameters: {num_part: %d}}"' \
% len(layer.components)
for c in layer.components:
c['socket_in'] = str(SocketType.SUB_CONNECT)
@@ -439,7 +439,7 @@ def rule7():
router_layer = YamlComposer.Layer(layer_id=self._num_layer)
self._num_layer += 1
r0 = CommentedMap({'name': 'Router',
'yaml_path': '"!PublishRouter {parameter: {num_part: %d}}"' % len(layer.components),
'yaml_path': '"!PublishRouter {parameters: {num_part: %d}}"' % len(layer.components),
'socket_in': str(SocketType.PULL_BIND),
'socket_out': str(SocketType.PUB_BIND),
'port_in': self._get_random_port(),
@@ -468,7 +468,7 @@ def rule10():
router_layer = YamlComposer.Layer(layer_id=self._num_layer)
self._num_layer += 1
r0 = CommentedMap({'name': 'Router',
'yaml_path': '"!PublishRouter {parameter: {num_part: %d}}"' % len(layer.components),
'yaml_path': '"!PublishRouter {parameters: {num_part: %d}}"' % len(layer.components),
'socket_in': str(SocketType.PULL_BIND),
'socket_out': str(SocketType.PUB_BIND),
'port_in': self._get_random_port(),
13 changes: 5 additions & 8 deletions gnes/helper.py
@@ -504,13 +504,11 @@ def load_contrib_module():

if contrib:
default_logger.info(
'find a value in $GNES_CONTRIB_MODULE=%s, will try to load these modules from external' % contrib)
for c in contrib.split(','):
if ':' in c:
_name, _path = c.split(':')
m = PathImporter.add_modules(_path)
modules.append(m)
default_logger.info('successfully register %s class, you can now use it via yaml.' % m)
'find a value in $GNES_CONTRIB_MODULE=%s, will load them as external modules' % contrib)
for p in contrib.split(','):
m = PathImporter.add_modules(p)
modules.append(m)
default_logger.info('successfully registered %s class, you can now use it via yaml.' % m)
return modules
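After this change, `$GNES_CONTRIB_MODULE` is a comma-separated list of file paths with no `name:` prefix. A minimal sketch of loading modules from such a value with the standard library is shown below. The function names `import_from_path` and `load_contrib_modules` are hypothetical, and the body of the real `PathImporter._path_import` is not shown in this diff, so the `importlib` calls are an assumption about how it might work.

```python
import importlib.util
import os


def import_from_path(path: str):
    """Import a Python file as a module, named after the file."""
    if not os.path.exists(path):
        raise FileNotFoundError('cannot import module from %s, file does not exist' % path)
    name = os.path.splitext(os.path.basename(path))[0]
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    # Executing the module runs its class definitions, which is the
    # point where a class could register itself for use from YAML.
    spec.loader.exec_module(module)
    return module


def load_contrib_modules(env_value: str):
    # The value holds paths only, separated by commas.
    return [import_from_path(p) for p in env_value.split(',') if p]
```

Like the updated `add_modules`, this sketch does not register the module in `sys.modules`; the caller keeps the returned module object instead.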


@@ -528,7 +526,6 @@ def add_modules(*paths):
if not os.path.exists(p):
raise FileNotFoundError('cannot import module from %s, file does not exist' % p)
module, spec = PathImporter._path_import(p)
sys.modules[spec.name] = module
return module

@staticmethod
2 changes: 1 addition & 1 deletion gnes/service/base.py
@@ -427,7 +427,7 @@ def __init__(self, service_cls, args):
else:
_head_router.socket_out = SocketType.PUB_BIND
_head_router.yaml_path = resolve_yaml_path(
'!PublishRouter {parameter: {num_part: %d}}' % args.num_parallel)
'!PublishRouter {parameters: {num_part: %d}}' % args.num_parallel)

if args.parallel_type.is_block:
_tail_router.yaml_path = resolve_yaml_path('BaseReduceRouter')
4 changes: 2 additions & 2 deletions gnes/service/grpc.py
@@ -30,10 +30,10 @@ def post_init(self):
options=[('grpc.max_send_message_length', self.args.max_message_size * 1024 * 1024),
('grpc.max_receive_message_length', self.args.max_message_size * 1024 * 1024)])

foo = PathImporter.add_modules(self.args.pb2_path, self.args.pb2_grpc_path)
m = PathImporter.add_modules(self.args.pb2_path, self.args.pb2_grpc_path)

# build stub
self.stub = getattr(foo, self.args.stub_name)(self.channel)
self.stub = getattr(m, self.args.stub_name)(self.channel)

def close(self):
self.channel.close()
9 changes: 9 additions & 0 deletions release.sh
@@ -76,11 +76,20 @@ function make_release_note {
}

BRANCH=$(git rev-parse --abbrev-ref HEAD)

if [[ "$BRANCH" != "master" ]]; then
printf "You are not at master branch, exit\n";
exit 1;
fi

LAST_UPDATE=`git show --no-notes --format=format:"%H" $BRANCH | head -n 1`
LAST_COMMIT=`git show --no-notes --format=format:"%H" origin/$BRANCH | head -n 1`

if [ "$LAST_COMMIT" != "$LAST_UPDATE" ]; then
printf "Your local $BRANCH is behind the remote master, exit\n"
exit 1;
fi

if [[ -z "${BOT_URL}" ]]; then
printf "BOT_URL is not set! Need to export BOT_URL=xxx"
exit 1;
2 changes: 1 addition & 1 deletion tests/contrib/dummy.yml
@@ -1,3 +1,3 @@
!FooContribEncoder
parameter:
parameters:
bar: 531
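The config above passes `bar: 531` to the encoder's constructor through the renamed `parameters` key. The sketch below mirrors the loading logic from `_get_instance_from_yaml` on an already-parsed mapping, to show that the old `parameter` key is no longer read. `FooEncoder` and `build_from_config` are hypothetical names used only for this illustration.

```python
from typing import Any, Dict, Optional


class FooEncoder:
    """Hypothetical component used only for this illustration."""

    def __init__(self, bar: int = 0, gnes_config: Optional[Dict[str, Any]] = None):
        self.bar = bar
        self.gnes_config = gnes_config or {}


def build_from_config(cls, data: Dict[str, Any]):
    # Copy so the caller's mapping is not modified by pop().
    p = dict(data.get('parameters', {}))
    a = p.pop('args', ())
    k = p.pop('kwargs', {})
    # Keys left directly under "parameters" are merged with "kwargs".
    return cls(*a, **{**k, **p}, gnes_config=data.get('gnes_config', {}))
```

Under these assumptions, `build_from_config(FooEncoder, {'parameters': {'bar': 531}})` sets `bar` to 531, while a config that still uses the singular `parameter` key is silently built with the constructor defaults.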
15 changes: 15 additions & 0 deletions tests/contrib/transformer.py
@@ -0,0 +1,15 @@
from typing import List

import numpy as np

from gnes.encoder.base import BaseTextEncoder


class PyTorchTransformers(BaseTextEncoder):

    def __init__(self, model_name: str = 'bert-base-uncased', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_name = model_name

    def encode(self, text: List[str], *args, **kwargs):
        return np.random.random([5, 128])
7 changes: 7 additions & 0 deletions tests/contrib/transformer.yml
@@ -0,0 +1,7 @@
!PyTorchTransformers
parameters:
  model_name: bert-base-uncased
gnes_config:
  name: my_transformer  # a customized name
  is_trained: true  # indicate the model has been trained
  work_dir: ./  # path for serialize/deserialize
7 changes: 3 additions & 4 deletions tests/test_contrib_module.py
@@ -3,7 +3,6 @@

dirname = os.path.dirname(__file__)
module_path = os.path.join(dirname, 'contrib', 'dummy_contrib.py')
cls_name = 'FooContribEncoder'


@unittest.SkipTest
@@ -20,7 +19,7 @@ def tearDown(self):
os.remove(self.dump_yaml_path)

def test_load_contrib(self):
os.environ['GNES_CONTRIB_MODULE'] = '%s:%s' % (cls_name, module_path)
os.environ['GNES_CONTRIB_MODULE'] = module_path
from gnes.encoder.base import BaseEncoder, BaseTextEncoder
a = BaseEncoder.load_yaml(self.yaml_path)
self.assertIsInstance(a, BaseTextEncoder)
@@ -32,14 +31,14 @@ def test_load_contrib(self):
self.assertEqual(b.encode([]), 'hello 531')

def test_bad_name(self):
os.environ['GNES_CONTRIB_MODULE'] = '%s:%s' % ('blah', module_path)
os.environ['GNES_CONTRIB_MODULE'] = module_path
try:
from gnes.encoder.base import BaseEncoder
except AttributeError:
pass

def test_bad_path(self):
os.environ['GNES_CONTRIB_MODULE'] = '%s:%s' % (cls_name, 'blah')
os.environ['GNES_CONTRIB_MODULE'] = 'blah'
try:
from gnes.encoder.base import BaseEncoder
except AttributeError: