Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge pull request #56 from gnes-ai/fix-composer
Browse files Browse the repository at this point in the history
fix(base): fix gnes_config mixed in kwargs
  • Loading branch information
jemmyshin authored Jul 25, 2019
2 parents 50269c7 + c52c2cc commit af50969
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 15 deletions.
25 changes: 16 additions & 9 deletions gnes/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
T = TypeVar('T', bound='TrainableBase')



def register_all_class(cls2file_map: Dict, module_name: str):
import importlib
for k, v in cls2file_map.items():
Expand Down Expand Up @@ -78,12 +77,17 @@ def __call__(cls, *args, **kwargs):
# do _preload_package
getattr(cls, '_pre_init', lambda *x: None)()

if 'gnes_config' in kwargs:
gnes_config = kwargs.pop('gnes_config')
else:
gnes_config = {}

obj = type.__call__(cls, *args, **kwargs)

# set attribute
for k, v in TrainableType.default_gnes_config.items():
if k in kwargs:
v = kwargs[k]
if k in gnes_config:
v = gnes_config[k]
if not hasattr(obj, k):
setattr(obj, k, v)

Expand Down Expand Up @@ -164,7 +168,7 @@ def __init__(self, *args, **kwargs):
self._post_init_vars = set()

def _post_init_wrapper(self):
if not getattr(self, 'name', None):
if not getattr(self, 'name', None) and os.environ.get('GNES_WARN_UNNAMED_COMPONENT', '1') == '1':
_id = str(uuid.uuid4()).split('-')[0]
_name = '%s-%s' % (self.__class__.__name__, _id)
self.logger.warning(
Expand Down Expand Up @@ -290,6 +294,9 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
if stop_on_import_error:
raise RuntimeError('Cannot import module, pip install may required') from ex

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '0'

data = ruamel.yaml.constructor.SafeConstructor.construct_mapping(
constructor, node, deep=True)

Expand All @@ -309,14 +316,17 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
# maybe there are some hanging kwargs in "parameter"
tmp_a = (cls._convert_env_var(v) for v in a)
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in {**k, **p}.items()}
obj = cls(*tmp_a, **tmp_p, **data.get('gnes_config', {}))
obj = cls(*tmp_a, **tmp_p, gnes_config=data.get('gnes_config', {}))
else:
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in data.get('parameter', {}).items()}
obj = cls(**tmp_p, **data.get('gnes_config', {}))
obj = cls(**tmp_p, gnes_config=data.get('gnes_config', {}))

obj.logger.info('initialize %s from a yaml config' % cls.__name__)
cls.init_from_yaml = False

if node.tag in {'!PipelineEncoder', '!CompositionalEncoder'}:
os.environ['GNES_WARN_UNNAMED_COMPONENT'] = '1'

return obj, data, load_from_dump

@staticmethod
Expand Down Expand Up @@ -344,6 +354,3 @@ def _dump_instance_to_yaml(data):
if p:
r['gnes_config'] = p
return r



39 changes: 33 additions & 6 deletions gnes/composer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,9 @@ def rule3():
'socket_out': str(SocketType.PUSH_BIND),
'port_in': last_layer.components[0]['port_out'],
'port_out': self._get_random_port()})
layer.components[0]['socket_in'] = str(SocketType.PULL_CONNECT)
layer.components[0]['port_in'] = r['port_out']
for c in layer.components:
c['socket_in'] = str(SocketType.PULL_CONNECT)
c['port_in'] = r['port_out']
router_layer.append(r)
router_layers.append(router_layer)

Expand Down Expand Up @@ -458,7 +459,6 @@ def rule10():
c['port_in'] = r0['port_out']

def rule8():
last_layer.components[0]['socket_out'] = str(SocketType.PUSH_CONNECT)
router_layer = YamlComposer.Layer(layer_id=self._num_layer)
self._num_layer += 1
r = CommentedMap({'name': 'Router',
Expand Down Expand Up @@ -503,11 +503,33 @@ def rule9():
last_layer.components[0]['socket_out'] = str(SocketType.PUSH_CONNECT)
layer.components[0]['socket_in'] = str(SocketType.PULL_BIND)

def rule11():
# a shortcut fn: (N)-2-(N) with push pull connection
router_layer = YamlComposer.Layer(layer_id=self._num_layer)
self._num_layer += 1
r = CommentedMap({'name': 'Router',
'yaml_path': None,
'socket_in': str(SocketType.PULL_BIND),
'socket_out': str(SocketType.PUSH_BIND),
'port_in': self._get_random_port(),
'port_out': self._get_random_port()})

for c in last_layer.components:
c['socket_out'] = str(SocketType.PUSH_CONNECT)
c['port_out'] = r['port_in']
for c in layer.components:
c['socket_in'] = str(SocketType.PULL_CONNECT)
c['port_in'] = r['port_out']
router_layer.append(r)
router_layers.append(router_layer)

router_layers = [] # type: List['self.Layer']
# bind the last out to current in
last_layer.components[0]['port_out'] = self._get_random_port()
layer.components[0]['port_in'] = last_layer.components[0]['port_out']

if last_layer.is_single_component:
last_layer.components[0]['port_out'] = self._get_random_port()
for c in layer.components:
c['port_in'] = last_layer.components[0]['port_out']
# 1-to-?
if layer.is_single_component:
# 1-to-(1)
Expand All @@ -530,8 +552,13 @@ def rule9():
rule6()
elif last_layer.is_homo_multi_component:
# (N)-to-?
last_layer.components[0]['port_out'] = self._get_random_port()

last_income = self.Layer.get_value(last_layer.components[0], 'income')

for c in layer.components:
c['port_in'] = last_layer.components[0]['port_out']

if layer.is_single_component:
if last_income == 'pull':
# (N)-to-1
Expand Down Expand Up @@ -560,7 +587,7 @@ def rule9():
else:
raise NotImplementedError('replica type: %s is not recognized!' % last_income)
elif last_layer.is_heto_single_component:
rule3()
rule8()
else:
rule8()
return [last_layer, *router_layers]
14 changes: 14 additions & 0 deletions tests/test_load_dump_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ def test_base(self):
b = BaseEncoder.load_yaml(self.yaml_path)
self.assertTrue(b.is_trained)

def test_name_warning(self):
d1 = DummyTFEncoder()
d2 = DummyTFEncoder()
d1.name = ''
d2.name = ''
d3 = PipelineEncoder()
d3.component = lambda: [d1, d2]
d3.name = 'dummy-pipeline'
d3.work_dir = './'
d3.dump()
d3.dump_yaml()
print('there should not be any warning after this line')
BaseEncoder.load_yaml(d3.yaml_full_path)

def test_dummytf(self):
d1 = DummyTFEncoder()
self.assertEqual(d1.encode(1), 2)
Expand Down

0 comments on commit af50969

Please sign in to comment.