From e5fefcee9ea003c7d244bd58c889606a03e12936 Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Tue, 6 Aug 2019 18:42:24 +0800 Subject: [PATCH] fix(encoder): fix batching in encoder --- gnes/cli/parser.py | 3 +++ gnes/composer/base.py | 36 ++++++++--------------------- gnes/encoder/image/base.py | 4 ++-- gnes/encoder/image/inception.py | 2 +- gnes/encoder/video/incep_mixture.py | 4 ++-- tests/test_router.py | 4 ++-- 6 files changed, 19 insertions(+), 34 deletions(-) diff --git a/gnes/cli/parser.py b/gnes/cli/parser.py index 95d5cec4..8b69d988 100644 --- a/gnes/cli/parser.py +++ b/gnes/cli/parser.py @@ -31,6 +31,9 @@ def resolve_yaml_path(path): elif path.isidentifier(): # possible class name return io.StringIO('!%s {}' % path) + elif path.startswith('!'): + # possible YAML content + return io.StringIO(path) else: raise argparse.ArgumentTypeError('%s can not be resolved, it should be a readable stream,' ' or a valid file path, or a supported class name.' % path) diff --git a/gnes/composer/base.py b/gnes/composer/base.py index 8a51e234..1fa44ef5 100644 --- a/gnes/composer/base.py +++ b/gnes/composer/base.py @@ -384,7 +384,7 @@ def rule3(): self._num_layer += 1 last_layer.components[0]['socket_out'] = str(SocketType.PUSH_CONNECT) r = CommentedMap({'name': 'Router', - 'yaml_path': None, + 'yaml_path': 'BaseRouter', 'socket_in': str(SocketType.PULL_BIND), 'socket_out': str(SocketType.PUSH_BIND), 'port_in': last_layer.components[0]['port_out'], @@ -405,6 +405,8 @@ def rule5(): # a shortcut fn: based on c3(): (N)-2-(N) with pub sub connection rule3() router_layers[0].components[0]['socket_out'] = str(SocketType.PUB_BIND) + router_layers[0].components[0]['yaml_path'] = '"!PublishRouter {parameter: {num_part: %d}}"' \ + % len(layer.components) for c in layer.components: c['socket_in'] = str(SocketType.SUB_CONNECT) @@ -415,7 +417,7 @@ def rule6(): for c in layer.components: income = self.Layer.get_value(c, 'income') r = CommentedMap({'name': 'Router', - 'yaml_path': None, + 'yaml_path': 'BaseReduceRouter', 'socket_in': str(SocketType.SUB_CONNECT), 'socket_out': str(SocketType.PUSH_BIND) if income == 'pull' else str( SocketType.PUB_BIND), @@ -432,7 +434,7 @@ def rule7(): router_layer = YamlComposer.Layer(layer_id=self._num_layer) self._num_layer += 1 r0 = CommentedMap({'name': 'Router', - 'yaml_path': None, + 'yaml_path': '"!PublishRouter {parameter: {num_part: %d}}"' % len(layer.components), 'socket_in': str(SocketType.PULL_BIND), 'socket_out': str(SocketType.PUB_BIND), 'port_in': self._get_random_port(), @@ -445,7 +447,7 @@ def rule7(): self._num_layer += 1 for c in layer.components: r = CommentedMap({'name': 'Router', - 'yaml_path': None, + 'yaml_path': 'BaseRouter', 'socket_in': str(SocketType.SUB_CONNECT), 'socket_out': str(SocketType.PUSH_BIND), 'port_in': r0['port_out'], @@ -461,7 +463,7 @@ def rule10(): router_layer = YamlComposer.Layer(layer_id=self._num_layer) self._num_layer += 1 r0 = CommentedMap({'name': 'Router', - 'yaml_path': None, + 'yaml_path': '"!PublishRouter {parameter: {num_part: %d}}"' % len(layer.components), 'socket_in': str(SocketType.PULL_BIND), 'socket_out': str(SocketType.PUB_BIND), 'port_in': self._get_random_port(), @@ -478,7 +480,7 @@ def rule8(): router_layer = YamlComposer.Layer(layer_id=self._num_layer) self._num_layer += 1 r = CommentedMap({'name': 'Router', - 'yaml_path': None, + 'yaml_path': 'BaseReduceRouter', 'socket_in': str(SocketType.PULL_BIND), 'socket_out': str(SocketType.PUSH_BIND), 'port_in': self._get_random_port(), @@ -489,7 +491,7 @@ def rule8(): if last_income == 'sub': c['socket_out'] = str(SocketType.PUSH_CONNECT) r_c = CommentedMap({'name': 'Router', - 'yaml_path': None, + 'yaml_path': 'BaseReduceRouter', 'socket_in': str(SocketType.PULL_BIND), 'socket_out': str(SocketType.PUSH_CONNECT), 'port_in': self._get_random_port(), @@ -519,26 +521,6 @@ def rule9(): last_layer.components[0]['socket_out'] = str(SocketType.PUSH_CONNECT) layer.components[0]['socket_in'] = str(SocketType.PULL_BIND) - def rule11(): - # a shortcut fn: (N)-2-(N) with push pull connection - router_layer = YamlComposer.Layer(layer_id=self._num_layer) - self._num_layer += 1 - r = CommentedMap({'name': 'Router', - 'yaml_path': None, - 'socket_in': str(SocketType.PULL_BIND), - 'socket_out': str(SocketType.PUSH_BIND), - 'port_in': self._get_random_port(), - 'port_out': self._get_random_port()}) - - for c in last_layer.components: - c['socket_out'] = str(SocketType.PUSH_CONNECT) - c['port_out'] = r['port_in'] - for c in layer.components: - c['socket_in'] = str(SocketType.PULL_CONNECT) - c['port_in'] = r['port_out'] - router_layer.append(r) - router_layers.append(router_layer) - router_layers = [] # type: List['self.Layer'] # bind the last out to current in diff --git a/gnes/encoder/image/base.py b/gnes/encoder/image/base.py index 2409f607..5d1d05a1 100644 --- a/gnes/encoder/image/base.py +++ b/gnes/encoder/image/base.py @@ -118,9 +118,9 @@ def _encode(_, img: List['np.ndarray']): # for video if len(img[0].shape) == 4: padding_image, max_lenth = _padding(img) - output = _encode(None, padding_image) + output = _encode(self, padding_image) # for image else: - output = _encode(None, img) + output = _encode(self, img) return output diff --git a/gnes/encoder/image/inception.py b/gnes/encoder/image/inception.py index 9919fba7..222c13fe 100644 --- a/gnes/encoder/image/inception.py +++ b/gnes/encoder/image/inception.py @@ -74,4 +74,4 @@ def _encode(_, data): feed_dict={self.inputs: data}) return end_points_[self.select_layer] - return _encode(None, img).astype(np.float32) + return _encode(self, img).astype(np.float32) diff --git a/gnes/encoder/video/incep_mixture.py b/gnes/encoder/video/incep_mixture.py index b3991ca1..c7b13057 100644 --- a/gnes/encoder/video/incep_mixture.py +++ b/gnes/encoder/video/incep_mixture.py @@ -118,7 +118,7 @@ def _encode1(_, data): feed_dict={self.inputs: data}) return end_points_[self.select_layer] - v = [_ for vi in _encode1(None, img) for _ in vi] + v = [_ for vi in _encode1(self, img) for _ in vi] v_input = [v[s:e] for s, e in zip(pos_start, pos_end)] v_input = [(vi + [[0.0] * self.input_size] * (max_len - len(vi)))[:max_len] for vi in v_input] @@ -129,4 +129,4 @@ def _encode2(_, data): return self.sess2.run(self.mix_model.repre, feed_dict={self.mix_model.feeds: data}) - return _encode2(None, v_input).astype(np.float32) + return _encode2(self, v_input).astype(np.float32) diff --git a/tests/test_router.py b/tests/test_router.py index f3e40f41..4e751d36 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -14,8 +14,8 @@ class TestProto(unittest.TestCase): def setUp(self): self.dirname = os.path.dirname(__file__) - self.publish_router_yaml = os.path.join(self.dirname, 'yaml', 'router-publish.yml') - self.batch_router_yaml = os.path.join(self.dirname, 'yaml', 'router-batch.yml') + self.publish_router_yaml = '!PublishRouter {parameter: {num_part: 2}}' + self.batch_router_yaml = '!DocBatchRouter {gnes_config: {batch_size: 2}}' self.reduce_router_yaml = 'BaseReduceRouter' self.chunk_router_yaml = 'ChunkReduceRouter' self.doc_router_yaml = 'DocReduceRouter'