class ConcatEmbedRouter(BaseReduceRouter):
    """
    Gather all embeddings from multiple encoders and concat them on a specific axis.
    By default, the concatenation happens on the last axis.
    """

    # Axis handed to ``np.concatenate``. ``-1`` selects the last axis, which is
    # the same as the previously hard-coded ``axis=1`` for 2-D embeddings and
    # also works for embeddings of any other rank.
    concat_axis = -1

    def _concat_embeddings(self, blobs):
        """Decode every blob to an array, concat them on ``concat_axis`` and re-encode as a blob."""
        arrays = [blob2array(b) for b in blobs]
        return array2blob(np.concatenate(arrays, axis=self.concat_axis))

    def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], *args, **kwargs):
        """
        Replace each chunk embedding in ``msg`` with the concatenation of the
        embeddings found at the same position in every message of ``accum_msgs``.

        Query requests are processed chunk by chunk; index requests are
        processed document by document and chunk by chunk. For any other
        payload type an error is logged and the embeddings are left untouched.
        In every case ``msg`` is then passed on to ``BaseReduceRouter.apply``.

        :param msg: the message that is updated in place
        :param accum_msgs: the messages whose embeddings are concatenated, in list order
        """
        body = getattr(msg, msg.WhichOneof('body'))
        msg_type = type(getattr(body, body.WhichOneof('body')))

        if msg_type == gnes_pb2.Request.QueryRequest:
            for i, chunk in enumerate(msg.request.search.query.chunks):
                # The concatenated blob is fully built before ``chunk`` is
                # overwritten, so it is safe for ``msg`` to appear in ``accum_msgs``.
                merged = self._concat_embeddings(
                    m.request.search.query.chunks[i].embedding for m in accum_msgs)
                chunk.embedding.CopyFrom(merged)
        elif msg_type == gnes_pb2.Request.IndexRequest:
            for i, doc in enumerate(msg.request.index.docs):
                for j, chunk in enumerate(doc.chunks):
                    merged = self._concat_embeddings(
                        m.request.index.docs[i].chunks[j].embedding for m in accum_msgs)
                    chunk.embedding.CopyFrom(merged)
        else:
            self.logger.error('dont know how to handle %s' % msg_type)

        super().apply(msg, accum_msgs)
msg.envelope.num_part.extend([1, 3]) c1.send_message(msg) c1.send_message(msg) @@ -355,11 +358,14 @@ def test_concat_router(self): r = c1.recv_message() self.assertSequenceEqual(r.envelope.num_part, [1]) print(r.envelope.routes) - self.assertEqual(r.request.search.query.chunk_embeddings.shape, [5, 6]) + for i in range(10): + self.assertEqual(r.request.search.query.chunks[i].embedding.shape, [5, 6]) for j in range(1, 4): d = msg.request.index.docs.add() - d.chunk_embeddings.CopyFrom(array2blob(np.random.random([5, 2 * j]))) + for k in range(10): + c = d.chunks.add() + c.embedding.CopyFrom(array2blob(np.random.random([5, 2]))) c1.send_message(msg) c1.send_message(msg) @@ -367,7 +373,8 @@ def test_concat_router(self): r = c1.recv_message() self.assertSequenceEqual(r.envelope.num_part, [1]) for j in range(1, 4): - self.assertEqual(r.request.index.docs[j - 1].chunk_embeddings.shape, [5, 6 * j]) + for i in range(10): + self.assertEqual(r.request.index.docs[j - 1].chunks[i].embedding.shape, [5, 6]) def test_multimap_multireduce(self): # p1 ->