This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Merge pull request #123 from gnes-ai/feat-contrib-2
feat(contrib): no need to give module name in advance
mergify[bot] authored Aug 19, 2019
2 parents 8e829b7 + b5b1a1e commit 0c586fa
Showing 79 changed files with 269 additions and 192 deletions.
14 changes: 7 additions & 7 deletions .drone-cd.yml
@@ -46,7 +46,7 @@ steps:
from_secret: BOT_URL
commands:
- export MSG_LINK=$DRONE_REPO_LINK
- export MSG_TITLE="✅🎁 All images is successfully updated!"
- export MSG_TITLE="✅🎁 All images are successfully delivered!"
- export MSG_CONTENT=""
- ./shell/push-wechatwork.sh

@@ -90,9 +90,9 @@ steps:
BOT_URL:
from_secret: BOT_URL
commands:
- export MSG_LINK=${DRONE_BUILD_LINK}
- export MSG_TITLE="⌛🗳 Start to build docker image for new release \`$DRONE_SOURCE_BRANCH\`(\`${DRONE_BUILD_NUMBER}\`)"
- export MSG_CONTENT="click the link below to see the status"
- export MSG_LINK=$DRONE_BUILD_LINK
- export MSG_TITLE="⌛🗳 Start to build docker image for new release \`$DRONE_SOURCE_BRANCH\`(\`${DRONE_TAG}\`)"
- export MSG_CONTENT="[tag link](https://github.com/gnes-ai/gnes/tree/${DRONE_TAG}), click the link below to see the status"
- ./shell/push-wechatwork.sh

- name: build and push docker images
@@ -123,8 +123,8 @@ steps:
BOT_URL:
from_secret: BOT_URL
commands:
- export MSG_LINK=$DRONE_REPO_LINK
- export MSG_TITLE="✅🎁 All images is successfully updated!"
- export MSG_LINK="https://github.com/gnes-ai/gnes/tree/${DRONE_TAG}"
- export MSG_TITLE="✅🎁 All images are successfully delivered!"
- export MSG_CONTENT=""
- ./shell/push-wechatwork.sh

@@ -136,7 +136,7 @@ steps:
commands:
- export MSG_LINK=$DRONE_BUILD_LINK
- export MSG_TITLE="❌🎁 fail to build docker image!"
- export MSG_CONTENT="please inform [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL) to modify and fix [\`$DRONE_SOURCE_BRANCH\`]($DRONE_COMMIT_LINK). click the link below to see the details."
- export MSG_CONTENT="[tag link](https://github.com/gnes-ai/gnes/tree/${DRONE_TAG}) please inform [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL) to modify and fix [\`$DRONE_SOURCE_BRANCH\`]($DRONE_COMMIT_LINK). click the link below to see the details."
- ./shell/push-wechatwork.sh
when:
status:
8 changes: 4 additions & 4 deletions .drone.yml
@@ -14,7 +14,7 @@ steps:
commands:
- export MSG_LINK=${DRONE_BUILD_LINK}
- export MSG_TITLE="⌛🏗 Start a CI pipeline \`$DRONE_SOURCE_BRANCH\`(\`${DRONE_BUILD_NUMBER}\`)"
- export MSG_CONTENT="submit by [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL), click the link below to see the status"
- export MSG_CONTENT="submit by [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL) [PR link](https://github.com/gnes-ai/gnes/pull/${DRONE_PULL_REQUEST}), click the link below to see the status"
- ./shell/push-wechatwork.sh

- name: check commit style
@@ -65,7 +65,7 @@ steps:
BOT_URL:
from_secret: BOT_URL
commands:
- export MSG_LINK=$DRONE_REPO_LINK
- export MSG_LINK="https://github.com/gnes-ai/gnes/pull/${DRONE_PULL_REQUEST}"
- export MSG_TITLE="✅😃 All tests passed, good job! \`$DRONE_SOURCE_BRANCH\`(\`${DRONE_BUILD_NUMBER}\`)"
- export MSG_CONTENT="the branch \`$DRONE_SOURCE_BRANCH\` submit by [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL) is ready to merge to master"
- ./shell/push-wechatwork.sh
@@ -77,9 +77,9 @@ steps:
BOT_URL:
from_secret: BOT_URL
commands:
- export MSG_LINK=$DRONE_BUILD_LINK
- export MSG_LINK=${DRONE_BUILD_LINK}
- export MSG_TITLE="❌😥 CI pipeline \`$DRONE_SOURCE_BRANCH\`(\`${DRONE_BUILD_NUMBER}\`) is failed!"
- export MSG_CONTENT="please inform [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL) to modify and fix [\`$DRONE_SOURCE_BRANCH\`]($DRONE_COMMIT_LINK). click the link below to see the details."
- export MSG_CONTENT="[PR link](https://github.com/gnes-ai/gnes/pull/${DRONE_PULL_REQUEST}) please inform [$DRONE_COMMIT_AUTHOR]($DRONE_COMMIT_AUTHOR_EMAIL) to modify and fix [\`$DRONE_SOURCE_BRANCH\`]($DRONE_COMMIT_LINK). click the link below to see the details."
- ./shell/push-wechatwork.sh
when:
status:
10 changes: 5 additions & 5 deletions README.md
@@ -261,7 +261,7 @@ Now let's see what the YAML config says. First impression, it is pretty intuitiv

```yaml
!TextPreprocessor
parameter:
parameters:
start_doc_id: 0
random_doc_id: True
deliminator: "[.!?]+"
@@ -277,19 +277,19 @@ gnes_config:
!PipelineEncoder
components:
- !GPT2Encoder
parameter:
parameters:
model_dir: $GPT2_CI_MODEL
pooling_stragy: REDUCE_MEAN
gnes_config:
is_trained: true
- !PCALocalEncoder
parameter:
parameters:
output_dim: 32
num_locals: 8
gnes_config:
batch_size: 2048
- !PQEncoder
parameter:
parameters:
cluster_per_byte: 8
num_bytes: 8
gnes_config:
@@ -304,7 +304,7 @@ gnes_config:
```yaml
!BIndexer
parameter:
parameters:
num_bytes: 8
data_path: /out_data/idx.binary
gnes_config:
10 changes: 5 additions & 5 deletions docs/chapter/yaml-config.md
@@ -24,12 +24,12 @@ Together they define the behavior of a GNES system. Roughly speaking,
!PipelineEncoder
components:
- !Word2VecEncoder
parameter:
parameters:
model_dir: /ext_data/sgns.wiki.bigram-char.refine
property:
is_trained: true
- !PCALocalEncoder
parameter:
parameters:
output_dim: 200
num_locals: 10
property:
@@ -42,18 +42,18 @@ One can also append extra component to this pipeline, e.g. adding quantization.
!PipelineEncoder
components:
- !Word2VecEncoder
parameter:
parameters:
model_dir: /ext_data/sgns.wiki.bigram-char.refine
property:
is_trained: true
- !PCALocalEncoder
parameter:
parameters:
output_dim: 200
num_locals: 10
property:
batch_size: 2048
- !PQEncoder
parameter:
parameters:
cluster_per_byte: 20
num_bytes: 10
```
35 changes: 17 additions & 18 deletions gnes/base/__init__.py
@@ -88,6 +88,7 @@ def __call__(cls, *args, **kwargs):
for k, v in TrainableType.default_gnes_config.items():
if k in gnes_config:
v = gnes_config[k]
v = _expand_env_var(v)
if not hasattr(obj, k):
setattr(obj, k, v)

@@ -172,9 +173,9 @@ def _post_init_wrapper(self):
_name = '%s-%s' % (self.__class__.__name__, _id)
self.logger.warning(
'this object is not named ("- gnes_config: - name" is not found in YAML config), '
'i will call it as "%s". '
'naming the object is important especially when you need to '
'serialize/deserialize/store/load the object.' % _name)
'i will call it "%s". '
'naming the object is important as it provides an unique identifier when '
'serializing/deserializing this object.' % _name)
setattr(self, 'name', _name)

_before = set(list(self.__dict__.keys()))
@@ -199,8 +200,6 @@ def yaml_full_path(self):
def __getstate__(self):
d = dict(self.__dict__)
del d['logger']
if '_file_lock' in d:
del d['_file_lock']
for k in self._post_init_vars:
del d[k]
return d
@@ -309,15 +308,15 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
cls.init_from_yaml = True

if cls.store_args_kwargs:
p = data.get('parameter', {}) # type: Dict[str, Any]
p = data.get('parameters', {}) # type: Dict[str, Any]
a = p.pop('args') if 'args' in p else ()
k = p.pop('kwargs') if 'kwargs' in p else {}
# maybe there are some hanging kwargs in "parameter"
tmp_a = (cls._convert_env_var(v) for v in a)
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in {**k, **p}.items()}
# maybe there are some hanging kwargs in "parameters"
tmp_a = (_expand_env_var(v) for v in a)
tmp_p = {kk: _expand_env_var(vv) for kk, vv in {**k, **p}.items()}
obj = cls(*tmp_a, **tmp_p, gnes_config=data.get('gnes_config', {}))
else:
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in data.get('parameter', {}).items()}
tmp_p = {kk: _expand_env_var(vv) for kk, vv in data.get('parameters', {}).items()}
obj = cls(**tmp_p, gnes_config=data.get('gnes_config', {}))

obj.logger.info('initialize %s from a yaml config' % cls.__name__)
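One consequence of the key rename in this hunk is worth spelling out: because the loader reads the config with `data.get('parameters', {})`, a YAML file that still uses the old `parameter` key is not rejected, its values are simply not passed to the constructor. The following is a minimal sketch of that behaviour, not the GNES implementation; `DummyEncoder` and `build_from_config` are hypothetical names, and environment-variable expansion is left out.

```python
from typing import Any, Dict, Optional


class DummyEncoder:
    """Hypothetical stand-in for a GNES component."""

    def __init__(self, bar: int = 0, gnes_config: Optional[Dict[str, Any]] = None):
        self.bar = bar
        self.gnes_config = gnes_config or {}


def build_from_config(cls, data: Dict[str, Any]):
    """Mimic the lookup order in the hunk above (assumed, simplified)."""
    params = dict(data.get('parameters', {}))  # only the new key is read
    args = params.pop('args', ())
    kwargs = params.pop('kwargs', {})
    # "hanging" keyword arguments directly under `parameters` override `kwargs`
    merged = {**kwargs, **params}
    return cls(*args, **merged, gnes_config=data.get('gnes_config', {}))


new_style = build_from_config(DummyEncoder, {'parameters': {'bar': 531}})
old_style = build_from_config(DummyEncoder, {'parameter': {'bar': 531}})
print(new_style.bar, old_style.bar)  # → 531 0
```

In this sketch the old-style config falls back to the constructor default, so any YAML files outside this commit would need the same rename.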
@@ -335,21 +334,14 @@ def _get_dump_path_from_config(gnes_config: Dict):
if os.path.exists(dump_path):
return dump_path

@staticmethod
def _convert_env_var(v):
if isinstance(v, str):
return parse_arg(os.path.expandvars(v))
else:
return v

@staticmethod
def _dump_instance_to_yaml(data):
# note: we only dump non-default property for the sake of clarity
p = {k: getattr(data, k) for k, v in TrainableType.default_gnes_config.items() if getattr(data, k) != v}
a = {k: v for k, v in data._init_kwargs_dict.items() if k not in TrainableType.default_gnes_config}
r = {}
if a:
r['parameter'] = a
r['parameters'] = a
if p:
r['gnes_config'] = p
return r
@@ -419,3 +411,10 @@ def from_yaml(cls, constructor, node):
if not from_dump and 'components' in data:
obj.components = lambda: data['components']
return obj


def _expand_env_var(v: str) -> str:
if isinstance(v, str):
return parse_arg(os.path.expandvars(v))
else:
return v
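For reviewers unfamiliar with the helper that was moved to module level here, the sketch below shows the expansion step in isolation. It is an illustration under assumptions: the name `expand_env_var` is hypothetical, and the `parse_arg` call that the real function applies to the expanded string is omitted because its behaviour is not shown in this diff.

```python
import os
from typing import Any


def expand_env_var(v: Any) -> Any:
    """Expand $VAR and ${VAR} in strings; return other values unchanged."""
    if isinstance(v, str):
        # the real helper also passes this result through gnes' parse_arg
        return os.path.expandvars(v)
    return v


os.environ['GPT2_CI_MODEL'] = '/models/gpt2'  # example value, not from the repo
params = {'model_dir': '$GPT2_CI_MODEL', 'output_dim': 32}
expanded = {k: expand_env_var(v) for k, v in params.items()}
print(expanded)  # → {'model_dir': '/models/gpt2', 'output_dim': 32}
```

With the change to `__call__` above, the same expansion now also applies to values under `gnes_config`, not only to constructor parameters.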
6 changes: 3 additions & 3 deletions gnes/composer/base.py
@@ -410,7 +410,7 @@ def rule5():
# a shortcut fn: based on c3(): (N)-2-(N) with pub sub connection
rule3()
router_layers[0].components[0]['socket_out'] = str(SocketType.PUB_BIND)
router_layers[0].components[0]['yaml_path'] = '"!PublishRouter {parameter: {num_part: %d}}"' \
router_layers[0].components[0]['yaml_path'] = '"!PublishRouter {parameters: {num_part: %d}}"' \
% len(layer.components)
for c in layer.components:
c['socket_in'] = str(SocketType.SUB_CONNECT)
@@ -439,7 +439,7 @@ def rule7():
router_layer = YamlComposer.Layer(layer_id=self._num_layer)
self._num_layer += 1
r0 = CommentedMap({'name': 'Router',
'yaml_path': '"!PublishRouter {parameter: {num_part: %d}}"' % len(layer.components),
'yaml_path': '"!PublishRouter {parameters: {num_part: %d}}"' % len(layer.components),
'socket_in': str(SocketType.PULL_BIND),
'socket_out': str(SocketType.PUB_BIND),
'port_in': self._get_random_port(),
@@ -468,7 +468,7 @@ def rule10():
router_layer = YamlComposer.Layer(layer_id=self._num_layer)
self._num_layer += 1
r0 = CommentedMap({'name': 'Router',
'yaml_path': '"!PublishRouter {parameter: {num_part: %d}}"' % len(layer.components),
'yaml_path': '"!PublishRouter {parameters: {num_part: %d}}"' % len(layer.components),
'socket_in': str(SocketType.PULL_BIND),
'socket_out': str(SocketType.PUB_BIND),
'port_in': self._get_random_port(),
13 changes: 5 additions & 8 deletions gnes/helper.py
@@ -504,13 +504,11 @@ def load_contrib_module():

if contrib:
default_logger.info(
'find a value in $GNES_CONTRIB_MODULE=%s, will try to load these modules from external' % contrib)
for c in contrib.split(','):
if ':' in c:
_name, _path = c.split(':')
m = PathImporter.add_modules(_path)
modules.append(m)
default_logger.info('successfully register %s class, you can now use it via yaml.' % m)
'find a value in $GNES_CONTRIB_MODULE=%s, will load them as external modules' % contrib)
for p in contrib.split(','):
m = PathImporter.add_modules(p)
modules.append(m)
default_logger.info('successfully registered %s class, you can now use it via yaml.' % m)
return modules


@@ -528,7 +526,6 @@ def add_modules(*paths):
if not os.path.exists(p):
raise FileNotFoundError('cannot import module from %s, file not exist' % p)
module, spec = PathImporter._path_import(p)
sys.modules[spec.name] = module
return module

@staticmethod
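The change to `load_contrib_module` means `$GNES_CONTRIB_MODULE` is now a plain comma-separated list of file paths, with no `name:path` pairs. The sketch below illustrates that idea with the standard `importlib` machinery. It is not the body of `PathImporter._path_import`, which this diff does not show; `import_from_path`, `load_contrib_modules` and the temporary `dummy_contrib.py` file are assumptions made for the example.

```python
import importlib.util
import os
import tempfile


def import_from_path(path: str):
    """Import a Python source file by path, deriving the module name from the file name."""
    if not os.path.exists(path):
        raise FileNotFoundError('cannot import module from %s, file not exist' % path)
    name = os.path.splitext(os.path.basename(path))[0]
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file, defining its classes
    return module


def load_contrib_modules(env_value: str):
    """Treat the value as comma-separated paths, as the new loop does."""
    return [import_from_path(p) for p in env_value.split(',') if p]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'dummy_contrib.py')
    with open(path, 'w') as fp:
        fp.write('class FooContribEncoder:\n    bar = 531\n')
    modules = load_contrib_modules(path)

print(modules[0].FooContribEncoder.bar)  # → 531
```

Deriving the name from the path is what makes the old `name:` prefix unnecessary in this sketch; how GNES then registers the imported classes for YAML loading is outside what the diff shows.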
2 changes: 1 addition & 1 deletion gnes/service/base.py
@@ -427,7 +427,7 @@ def __init__(self, service_cls, args):
else:
_head_router.socket_out = SocketType.PUB_BIND
_head_router.yaml_path = resolve_yaml_path(
'!PublishRouter {parameter: {num_part: %d}}' % args.num_parallel)
'!PublishRouter {parameters: {num_part: %d}}' % args.num_parallel)

if args.parallel_type.is_block:
_tail_router.yaml_path = resolve_yaml_path('BaseReduceRouter')
4 changes: 2 additions & 2 deletions gnes/service/grpc.py
@@ -30,10 +30,10 @@ def post_init(self):
options=[('grpc.max_send_message_length', self.args.max_message_size * 1024 * 1024),
('grpc.max_receive_message_length', self.args.max_message_size * 1024 * 1024)])

foo = PathImporter.add_modules(self.args.pb2_path, self.args.pb2_grpc_path)
m = PathImporter.add_modules(self.args.pb2_path, self.args.pb2_grpc_path)

# build stub
self.stub = getattr(foo, self.args.stub_name)(self.channel)
self.stub = getattr(m, self.args.stub_name)(self.channel)

def close(self):
self.channel.close()
9 changes: 9 additions & 0 deletions release.sh
@@ -76,11 +76,20 @@ function make_release_note {
}

BRANCH=$(git rev-parse --abbrev-ref HEAD)

if [[ "$BRANCH" != "master" ]]; then
printf "You are not at master branch, exit\n";
exit 1;
fi

LAST_UPDATE=`git show --no-notes --format=format:"%H" $BRANCH | head -n 1`
LAST_COMMIT=`git show --no-notes --format=format:"%H" origin/$BRANCH | head -n 1`

if [ $LAST_COMMIT != $LAST_UPDATE ]; then
printf "Your local $BRANCH is behind the remote master, exit\n"
exit 1;
fi

if [[ -z "${BOT_URL}" ]]; then
printf "BOT_URL is not set! Need to export BOT_URL=xxx"
exit 1;
2 changes: 1 addition & 1 deletion tests/contrib/dummy.yml
@@ -1,3 +1,3 @@
!FooContribEncoder
parameter:
parameters:
bar: 531
15 changes: 15 additions & 0 deletions tests/contrib/transformer.py
@@ -0,0 +1,15 @@
from typing import List

import numpy as np

from gnes.encoder.base import BaseTextEncoder


class PyTorchTransformers(BaseTextEncoder):

def __init__(self, model_name: str = 'bert-base-uncased', *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_name = model_name

def encode(self, text: List[str], *args, **kwargs):
return np.random.random([5, 128])
7 changes: 7 additions & 0 deletions tests/contrib/transformer.yml
@@ -0,0 +1,7 @@
!PyTorchTransformers
parameters:
model_name: bert-base-uncased
gnes_config:
name: my_transformer # a customized name
is_trained: true # indicate the model has been trained
work_dir: ./ # path for serialize/deserialize
7 changes: 3 additions & 4 deletions tests/test_contrib_module.py
@@ -3,7 +3,6 @@

dirname = os.path.dirname(__file__)
module_path = os.path.join(dirname, 'contrib', 'dummy_contrib.py')
cls_name = 'FooContribEncoder'


@unittest.SkipTest
@@ -20,7 +19,7 @@ def tearDown(self):
os.remove(self.dump_yaml_path)

def test_load_contrib(self):
os.environ['GNES_CONTRIB_MODULE'] = '%s:%s' % (cls_name, module_path)
os.environ['GNES_CONTRIB_MODULE'] = module_path
from gnes.encoder.base import BaseEncoder, BaseTextEncoder
a = BaseEncoder.load_yaml(self.yaml_path)
self.assertIsInstance(a, BaseTextEncoder)
@@ -32,14 +31,14 @@ def test_load_contrib(self):
self.assertEqual(b.encode([]), 'hello 531')

def test_bad_name(self):
os.environ['GNES_CONTRIB_MODULE'] = '%s:%s' % ('blah', module_path)
os.environ['GNES_CONTRIB_MODULE'] = module_path
try:
from gnes.encoder.base import BaseEncoder
except AttributeError:
pass

def test_bad_path(self):
os.environ['GNES_CONTRIB_MODULE'] = '%s:%s' % (cls_name, 'blah')
os.environ['GNES_CONTRIB_MODULE'] = 'blah'
try:
from gnes.encoder.base import BaseEncoder
except AttributeError:
