Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge branch 'master' into fix_vlad
Browse files Browse the repository at this point in the history
  • Loading branch information
jemmyshin authored Sep 9, 2019
2 parents 63b56f2 + d1eb573 commit 0eccfdc
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 23 deletions.
4 changes: 3 additions & 1 deletion gnes/router/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
'BaseMapRouter': 'base',
'BaseReduceRouter': 'base',
'BaseTopkReduceRouter': 'base',
'BaseEmbedReduceRouter': 'base',
'DocTopkReducer': 'reduce',
'ChunkTopkReducer': 'reduce',
'DocFillReducer': 'reduce',
'PublishRouter': 'map',
'DocBatchRouter': 'map',
'ConcatEmbedRouter': 'reduce'
'ConcatEmbedRouter': 'reduce',
'AvgEmbedRouter': 'reduce'
}

register_all_class(_cls2file_map, 'router')
29 changes: 28 additions & 1 deletion gnes/router/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from gnes.score_fn.base import CombinedScoreFn
from ..base import TrainableBase, CompositionalTrainableBase
from ..proto import gnes_pb2, merge_routes
from ..proto import gnes_pb2, merge_routes, array2blob


class BaseRouter(TrainableBase):
Expand Down Expand Up @@ -94,6 +94,33 @@ def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], *
super().apply(msg, accum_msgs)


class BaseEmbedReduceRouter(BaseReduceRouter):
    """Reduce router that rewrites every chunk embedding of a message.

    Subclasses decide *how* the embeddings found in the accumulated messages
    are combined by implementing :meth:`reduce_embedding`; :meth:`apply` takes
    care of walking the chunks of the current message and storing the result.
    """

    def reduce_embedding(self, accum_msgs: List['gnes_pb2.Message'], msg_type: str, chunk_idx: int, doc_idx: int):
        """
        combine the embeddings of one chunk across all accumulated messages
        :param accum_msgs: accumulated messages
        :param msg_type: ``'query'`` or ``'index'``, as passed by :meth:`apply`
        :param chunk_idx: position of the chunk inside the query / the document
        :param doc_idx: position of the document for ``'index'``; :meth:`apply` passes ``-1`` for ``'query'``
        :return: the reduced embedding; :meth:`apply` hands it to ``array2blob``
        """
        raise NotImplementedError

    def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], *args, **kwargs) -> None:
        """
        replace each chunk embedding in ``msg`` with the reduced embedding
        :param msg: the current message, modified in place
        :param accum_msgs: accumulated messages
        """
        # Resolve the concrete request class through the two nested 'body' oneofs.
        outer_body = getattr(msg, msg.WhichOneof('body'))
        inner_body = getattr(outer_body, outer_body.WhichOneof('body'))
        msg_type = type(inner_body)

        if msg_type == gnes_pb2.Request.QueryRequest:
            for chunk_pos, chunk in enumerate(msg.request.search.query.chunks):
                merged = self.reduce_embedding(accum_msgs, 'query', chunk_idx=chunk_pos, doc_idx=-1)
                chunk.embedding.CopyFrom(array2blob(merged))
        elif msg_type == gnes_pb2.Request.IndexRequest:
            for doc_pos, doc in enumerate(msg.request.index.docs):
                for chunk_pos, chunk in enumerate(doc.chunks):
                    merged = self.reduce_embedding(accum_msgs, 'index', chunk_idx=chunk_pos, doc_idx=doc_pos)
                    chunk.embedding.CopyFrom(array2blob(merged))
        else:
            # Unsupported request types are only logged; the message is still passed on below.
            self.logger.error('dont know how to handle %s' % msg_type)

        super().apply(msg, accum_msgs)


class PipelineRouter(CompositionalTrainableBase):
def apply(self, *args, **kwargs) -> None:
if not self.components:
Expand Down
49 changes: 28 additions & 21 deletions gnes/router/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from typing import List
import numpy as np

from .base import BaseReduceRouter, BaseTopkReduceRouter
from ..proto import gnes_pb2, blob2array, array2blob
from .base import BaseReduceRouter, BaseTopkReduceRouter, BaseEmbedReduceRouter
from ..proto import blob2array


class DocFillReducer(BaseReduceRouter):
Expand Down Expand Up @@ -78,30 +78,37 @@ def set_key(self, x: 'gnes_pb2.Response.QueryResponse.ScoredResult', k: str):
x.chunk.doc_id, x.chunk.offset = map(int, k.split('-'))


class ConcatEmbedRouter(BaseReduceRouter):
class ConcatEmbedRouter(BaseEmbedReduceRouter):
"""
Gather all embeddings from multiple encoders and concat them on a specific axis.
In default, concat will happen on the last axis.
chunk_idx, doc_idx denote index in for loop used in BaseEmbedReduceRouter
"""

def apply(self, msg: 'gnes_pb2.Message', accum_msgs: List['gnes_pb2.Message'], *args, **kwargs):
body = getattr(msg, msg.WhichOneof('body'))
msg_type = type(getattr(body, body.WhichOneof('body')))
if msg_type == gnes_pb2.Request.QueryRequest:
for i in range(len(msg.request.search.query.chunks)):
concat_embedding = array2blob(
np.concatenate([blob2array(m.request.search.query.chunks[i].embedding) for m in accum_msgs],
axis=1))
msg.request.search.query.chunks[i].embedding.CopyFrom(concat_embedding)

elif msg_type == gnes_pb2.Request.IndexRequest:
for i in range(len(msg.request.index.docs)):
for j in range(len(msg.request.index.docs[i].chunks)):
concat_embedding = array2blob(
np.concatenate(
[blob2array(m.request.index.docs[i].chunks[j].embedding) for m in accum_msgs], axis=1))
msg.request.index.docs[i].chunks[j].embedding.CopyFrom(concat_embedding)
def reduce_embedding(self, accum_msgs: List['gnes_pb2.Message'], msg_type: str, chunk_idx: int, doc_idx: int):
    """
    concatenate the embeddings of one chunk taken from every accumulated message
    :param accum_msgs: accumulated messages, one embedding is read from each
    :param msg_type: ``'query'`` reads from ``request.search.query``, ``'index'`` from ``request.index.docs``
    :param chunk_idx: position of the chunk to read
    :param doc_idx: position of the document, only used when ``msg_type`` is ``'index'``
    :return: the embeddings joined along axis 1; ``None`` (after logging an error) for any other ``msg_type``
    """
    if msg_type == 'query':
        blobs = [m.request.search.query.chunks[chunk_idx].embedding for m in accum_msgs]
    elif msg_type == 'index':
        blobs = [m.request.index.docs[doc_idx].chunks[chunk_idx].embedding for m in accum_msgs]
    else:
        self.logger.error('dont know how to handle %s' % msg_type)
        return None
    # axis=1 is the last axis for 2-D embeddings.
    return np.concatenate([blob2array(b) for b in blobs], axis=1)

super().apply(msg, accum_msgs)

class AvgEmbedRouter(BaseEmbedReduceRouter):
    """
    Gather the embeddings produced by multiple encoders and average them.
    The embeddings of one chunk are collected into a list, one per accumulated
    message, and the mean is taken over that list (``axis=0``).
    chunk_idx and doc_idx are the loop indices supplied by BaseEmbedReduceRouter.
    """

    def reduce_embedding(self, accum_msgs: List['gnes_pb2.Message'], msg_type: str, chunk_idx: int, doc_idx: int):
        """
        average the embeddings of one chunk over every accumulated message
        :param accum_msgs: accumulated messages, one embedding is read from each
        :param msg_type: ``'query'`` reads from ``request.search.query``, ``'index'`` from ``request.index.docs``
        :param chunk_idx: position of the chunk to read
        :param doc_idx: position of the document, only used when ``msg_type`` is ``'index'``
        :return: the mean over messages; ``None`` (after logging an error) for any other ``msg_type``
        """
        if msg_type not in ('query', 'index'):
            self.logger.error('dont know how to handle %s' % msg_type)
            return None

        if msg_type == 'query':
            embeddings = [blob2array(m.request.search.query.chunks[chunk_idx].embedding)
                          for m in accum_msgs]
        else:
            embeddings = [blob2array(m.request.index.docs[doc_idx].chunks[chunk_idx].embedding)
                          for m in accum_msgs]
        # NOTE(review): assumes all encoders emit same-shaped embeddings for a chunk - confirm.
        return np.mean(embeddings, axis=0)
42 changes: 42 additions & 0 deletions tests/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def setUp(self):
self.doc_router_yaml = 'DocFillReducer'
self.doc_sum_yaml = 'DocSumRouter'
self.concat_router_yaml = 'ConcatEmbedRouter'
self.avg_router_yaml = 'AvgEmbedRouter'

def test_service_empty(self):
args = set_router_parser().parse_args(['--yaml_path', 'BaseRouter'])
Expand Down Expand Up @@ -375,7 +376,48 @@ def test_concat_router(self):
for j in range(1, 4):
for i in range(10):
self.assertEqual(r.request.index.docs[j - 1].chunks[i].embedding.shape, [5, 6])

# End-to-end check of AvgEmbedRouter: send the same message three times through a
# RouterService and verify that the single reply keeps the per-chunk embedding shape,
# first for a query request and then for an index request.
def test_avg_router(self):
# Router under test, loaded from the 'AvgEmbedRouter' yaml shortcut set in setUp.
args = set_router_parser().parse_args([
'--yaml_path', self.avg_router_yaml,
'--socket_out', str(SocketType.PUSH_BIND)
])
# The client's ports mirror the router's: it sends to port_in and listens on port_out.
c_args = _set_client_parser().parse_args([
'--port_in', str(args.port_out),
'--port_out', str(args.port_in),
'--socket_in', str(SocketType.PULL_CONNECT)
])
# Query phase: 10 chunks, each chunk embedding has shape (5, 2)
with RouterService(args), ZmqClient(c_args) as c1:
msg = gnes_pb2.Message()
for i in range(10):
c = msg.request.search.query.chunks.add()
c.embedding.CopyFrom(array2blob(np.random.random([5, 2])))
# NOTE(review): the trailing 3 presumably tells the router to wait for three parts - confirm.
msg.envelope.num_part.extend([1, 3])
c1.send_message(msg)
c1.send_message(msg)
c1.send_message(msg)
# Only one receive is issued for the three sends.
r = c1.recv_message()
self.assertSequenceEqual(r.envelope.num_part, [1])
print(r.envelope.routes)
# The reduced embedding must keep the (5, 2) shape of its inputs.
for i in range(10):
self.assertEqual(r.request.search.query.chunks[i].embedding.shape, [5, 2])

# Index phase: reuse the same message object and add 3 docs of 10 chunks each.
# NOTE(review): writing to request.index presumably switches the request oneof away from search - confirm.
for j in range(1, 4):
d = msg.request.index.docs.add()
for k in range(10):
c = d.chunks.add()
c.embedding.CopyFrom(array2blob(np.random.random([5, 2])))

c1.send_message(msg)
c1.send_message(msg)
c1.send_message(msg)
r = c1.recv_message()
self.assertSequenceEqual(r.envelope.num_part, [1])
# j runs 1..3, so j - 1 addresses docs 0..2.
for j in range(1, 4):
for i in range(10):
self.assertEqual(r.request.index.docs[j - 1].chunks[i].embedding.shape, [5, 2])

def test_multimap_multireduce(self):
# p1 ->
# p21 ->
Expand Down

0 comments on commit 0eccfdc

Please sign in to comment.