Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge pull request #225 from gnes-ai/concat_reducer
Browse files Browse the repository at this point in the history
feat(reducer): add concat reducer
  • Loading branch information
mergify[bot] authored Sep 6, 2019
2 parents 3d2a74f + 9286084 commit 8021d18
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
1 change: 1 addition & 0 deletions gnes/router/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
'DocFillReducer': 'reduce',
'PublishRouter': 'map',
'DocBatchRouter': 'map',
'ConcatEmbedRouter': 'reduce'
}

register_all_class(_cls2file_map, 'router')
31 changes: 31 additions & 0 deletions gnes/router/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# limitations under the License.

from typing import List
import numpy as np

from .base import BaseReduceRouter, BaseTopkReduceRouter
from ..proto import gnes_pb2, blob2array, array2blob


class DocFillReducer(BaseReduceRouter):
Expand Down Expand Up @@ -74,3 +76,32 @@ def get_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult') -> str:

def set_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult', k: str):
x.chunk.doc_id, x.chunk.offset = map(int, k.split('-'))


class ConcatEmbedRouter(BaseReduceRouter):
    """
    Gather the chunk embeddings produced by multiple encoders and concatenate
    them along the last axis.

    Every message in ``accum_msgs`` is expected to carry the same documents
    and chunks in the same order; chunks are matched across messages by their
    position. The concatenated embedding is written back into ``msg`` in
    place before the base-class reduce logic runs.

    Only query requests and index requests are handled; for any other
    request type an error is logged and the embeddings are left untouched.
    """

    @staticmethod
    def _concat_blobs(blobs, axis: int = -1):
        """Decode each embedding blob, join the arrays along ``axis`` and re-encode the result."""
        return array2blob(np.concatenate([blob2array(b) for b in blobs], axis=axis))

    def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], *args, **kwargs):
        """
        Replace every chunk embedding in ``msg`` with the concatenation of the
        embeddings found at the same position in each of ``accum_msgs``.

        :param msg: the message that is modified in place and handed to the
            parent router afterwards
        :param accum_msgs: the accumulated messages; embeddings are joined in
            the order the messages appear in this list
        """
        body = getattr(msg, msg.WhichOneof('body'))
        payload = getattr(body, body.WhichOneof('body'))

        if isinstance(payload, gnes_pb2.Request.QueryRequest):
            for i, chunk in enumerate(msg.request.search.query.chunks):
                chunk.embedding.CopyFrom(
                    self._concat_blobs(m.request.search.query.chunks[i].embedding for m in accum_msgs))

        elif isinstance(payload, gnes_pb2.Request.IndexRequest):
            for i, doc in enumerate(msg.request.index.docs):
                for j, chunk in enumerate(doc.chunks):
                    chunk.embedding.CopyFrom(
                        self._concat_blobs(m.request.index.docs[i].chunks[j].embedding for m in accum_msgs))
        else:
            # Unsupported payloads are reported but still forwarded to the
            # parent router below, exactly as before.
            self.logger.error('dont know how to handle %s' % type(payload))

        super().apply(msg, accum_msgs)
17 changes: 12 additions & 5 deletions tests/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def test_doc_sum_reduce_router(self):
self.assertGreaterEqual(r.response.search.topk_results[0].score.value,
r.response.search.topk_results[-1].score.value)

@unittest.SkipTest
# @unittest.SkipTest
def test_concat_router(self):
args = set_router_parser().parse_args([
'--yaml_path', self.concat_router_yaml,
Expand All @@ -345,29 +345,36 @@ def test_concat_router(self):
'--port_out', str(args.port_in),
'--socket_in', str(SocketType.PULL_CONNECT)
])
# 10 chunks in each doc, dimension of chunk embedding is (5, 2)
with RouterService(args), ZmqClient(c_args) as c1:
msg = gnes_pb2.Message()
msg.request.search.query.chunk_embeddings.CopyFrom(array2blob(np.random.random([5, 2])))
for i in range(10):
c = msg.request.search.query.chunks.add()
c.embedding.CopyFrom(array2blob(np.random.random([5, 2])))
msg.envelope.num_part.extend([1, 3])
c1.send_message(msg)
c1.send_message(msg)
c1.send_message(msg)
r = c1.recv_message()
self.assertSequenceEqual(r.envelope.num_part, [1])
print(r.envelope.routes)
self.assertEqual(r.request.search.query.chunk_embeddings.shape, [5, 6])
for i in range(10):
self.assertEqual(r.request.search.query.chunks[i].embedding.shape, [5, 6])

for j in range(1, 4):
d = msg.request.index.docs.add()
d.chunk_embeddings.CopyFrom(array2blob(np.random.random([5, 2 * j])))
for k in range(10):
c = d.chunks.add()
c.embedding.CopyFrom(array2blob(np.random.random([5, 2])))

c1.send_message(msg)
c1.send_message(msg)
c1.send_message(msg)
r = c1.recv_message()
self.assertSequenceEqual(r.envelope.num_part, [1])
for j in range(1, 4):
self.assertEqual(r.request.index.docs[j - 1].chunk_embeddings.shape, [5, 6 * j])
for i in range(10):
self.assertEqual(r.request.index.docs[j - 1].chunks[i].embedding.shape, [5, 6])

def test_multimap_multireduce(self):
# p1 ->
Expand Down

0 comments on commit 8021d18

Please sign in to comment.