Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(base): fix duplicate load and init from yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Jul 19, 2019
1 parent 69a486e commit 991e442
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 32 deletions.
73 changes: 47 additions & 26 deletions gnes/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _import(module_name, class_name):


class TrainableType(type):
default_property = {
default_gnes_config = {
'is_trained': False,
'batch_size': None,
'work_dir': os.environ.get('GNES_VOLUME', os.getcwd()),
Expand All @@ -77,12 +77,8 @@ def __call__(cls, *args, **kwargs):

obj = type.__call__(cls, *args, **kwargs)

# set attribute
for k, v in TrainableType.default_property.items():
if not hasattr(obj, k):
setattr(obj, k, v)
obj._set_gnes_config(**kwargs)

# do _post_init()
getattr(obj, '_post_init_wrapper', lambda *x: None)()
return obj

Expand Down Expand Up @@ -179,6 +175,23 @@ def post_init(self):
def pre_init(cls):
pass

def _set_gnes_config(self, **kwargs):
# set attribute
for k, v in TrainableType.default_gnes_config.items():
if k in kwargs:
v = kwargs[k]
setattr(self, k, v)

if not getattr(self, 'name', None):
_id = str(uuid.uuid4()).split('-')[0]
_name = '%s-%s' % (self.__class__.__name__, _id)
self.logger.warning(
'this object is not named ("- gnes_config: - name" is not found in YAML config), '
'i will call it as "%s". '
'However, naming the object is important especially when you need to '
'serialize/deserialize/store/load the object.' % _name)
setattr(self, 'name', _name)

@property
def dump_full_path(self):
return os.path.join(self.work_dir, '%s.bin' % self.name)
Expand Down Expand Up @@ -286,30 +299,38 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):

data = ruamel.yaml.constructor.SafeConstructor.construct_mapping(
constructor, node, deep=True)
cls.init_from_yaml = True

if cls.store_args_kwargs:
p = data.get('parameter', {}) # type: Dict[str, Any]
a = p.pop('args') if 'args' in p else ()
k = p.pop('kwargs') if 'kwargs' in p else {}
# maybe there are some hanging kwargs in "parameter"
tmp_a = (cls._convert_env_var(v) for v in a)
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in {**k, **p}.items()}
obj = cls(*tmp_a, **tmp_p)
else:
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in data.get('parameter', {}).items()}
obj = cls(**tmp_p)

for k, v in data.get('gnes_config', {}).items():
old = getattr(obj, k, None)
setattr(obj, k, v)
if old and old != v:
obj.logger.info('gnes_config: %r is replaced from %r to %r' % (k, old, v))
dump_path = cls._get_dump_path_from_config(data)
if dump_path:
obj = cls.load(dump_path)
obj.logger.info('restore %s from %s' % (cls.__name__, dump_path))
else:
cls.init_from_yaml = True

if cls.store_args_kwargs:
p = data.get('parameter', {}) # type: Dict[str, Any]
a = p.pop('args') if 'args' in p else ()
k = p.pop('kwargs') if 'kwargs' in p else {}
# maybe there are some hanging kwargs in "parameter"
tmp_a = (cls._convert_env_var(v) for v in a)
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in {**k, **p}.items()}
obj = cls(*tmp_a, **tmp_p, **data.get('gnes_config', {}))
else:
tmp_p = {kk: cls._convert_env_var(vv) for kk, vv in data.get('parameter', {}).items()}
obj = cls(**tmp_p, **data.get('gnes_config', {}))

cls.init_from_yaml = False
obj.logger.info('initialize %s from a yaml config' % cls.__name__)
cls.init_from_yaml = False

return obj, data

@staticmethod
def _get_dump_path_from_config(gnes_config: Dict):
if 'work_dir' in gnes_config and 'name' in gnes_config:
dump_path = os.path.join(gnes_config['work_dir'], '%s.bin' % gnes_config['name'])
if os.path.exists(dump_path):
return dump_path

@staticmethod
def _convert_env_var(v):
if isinstance(v, str):
Expand All @@ -320,7 +341,7 @@ def _convert_env_var(v):
@staticmethod
def _dump_instance_to_yaml(data):
# note: we only dump non-default property for the sake of clarity
p = {k: getattr(data, k) for k, v in TrainableType.default_property.items() if getattr(data, k) != v}
p = {k: getattr(data, k) for k, v in TrainableType.default_gnes_config.items() if getattr(data, k) != v}
a = {k: v for k, v in data._init_kwargs_dict.items()}
r = {}
if a:
Expand Down
7 changes: 1 addition & 6 deletions gnes/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,14 +304,9 @@ def post_init(self):

def load_model(self, base_class: Type[TrainableBase]) -> T:
try:
model = base_class.load_yaml(self.args.yaml_path)
return base_class.load_yaml(self.args.yaml_path)
except FileNotFoundError:
raise ComponentNotLoad
try:
model = model.__class__.load(model.dump_full_path)
except FileNotFoundError:
self.logger.warning('load an empty %s from %s' % (model.__class__.__name__, self.args.yaml_path))
return model

@handler.register(NotImplementedError)
def _handler_default(self, msg: 'gnes_pb2.Message'):
Expand Down

0 comments on commit 991e442

Please sign in to comment.