Merge pull request #316 from gnes-ai/fix-dump-interval
fix(indexer): fix empty chunk and dump_interval
mergify[bot] authored Oct 10, 2019
2 parents 49150fb + bca5b5b commit 1878564
Showing 5 changed files with 56 additions and 25 deletions.
6 changes: 6 additions & 0 deletions gnes/base/__init__.py
@@ -336,6 +336,12 @@ def _get_instance_from_yaml(cls, constructor, node, stop_on_import_error=False):
         data = ruamel.yaml.constructor.SafeConstructor.construct_mapping(
             constructor, node, deep=True)
 
+        _gnes_config = data.get('gnes_config', {})
+        for k, v in _gnes_config.items():
+            _gnes_config[k] = _expand_env_var(v)
+        if _gnes_config:
+            data['gnes_config'] = _gnes_config
+
         dump_path = cls._get_dump_path_from_config(data.get('gnes_config', {}))
         load_from_dump = False
         if dump_path:
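Why this matters: the added block expands environment variables inside gnes_config before the dump path is resolved. A minimal sketch of the idea, with _expand_env_var stubbed via os.path.expandvars (the stub is an assumption for illustration, not the actual GNES helper):

import os

# Toy stand-in for the behaviour the new block enables (not GNES code):
# string values in gnes_config may reference environment variables.
os.environ['GNES_WORKDIR'] = '/tmp/gnes'

def _expand_env_var(v):
    # assumption: the real helper expands $VAR/${VAR} in string values only
    return os.path.expandvars(v) if isinstance(v, str) else v

gnes_config = {'work_dir': '${GNES_WORKDIR}/idx', 'read_only': True}
gnes_config = {k: _expand_env_var(v) for k, v in gnes_config.items()}
print(gnes_config)  # {'work_dir': '/tmp/gnes/idx', 'read_only': True}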
9 changes: 8 additions & 1 deletion gnes/flow/__init__.py
@@ -1,4 +1,5 @@
 import copy
+import os
 from collections import OrderedDict, defaultdict
 from contextlib import ExitStack
 from functools import wraps
@@ -175,6 +176,9 @@ def query(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs):
 
     @_build_level(BuildLevel.RUNTIME)
     def _call_client(self, bytes_gen: Generator[bytes, None, None] = None, **kwargs):
+
+        os.unsetenv('http_proxy')
+        os.unsetenv('https_proxy')
         args, p_args = self._get_parsed_args(self, set_client_cli_parser, kwargs)
         p_args.grpc_port = self._service_nodes[self._frontend]['parsed_args'].grpc_port
         p_args.grpc_host = self._service_nodes[self._frontend]['parsed_args'].grpc_host
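A note on the two new proxy lines: os.unsetenv removes the variable from the C-level environment but does not update the os.environ mapping of the current process. A more thorough variant (a sketch under that caveat, not what the commit does) clears both casings through os.environ, which calls unsetenv internally:

import os

# Sketch: keep local gRPC traffic from being routed through an HTTP proxy.
# Deleting from os.environ also unsets the variable at the C level.
for var in ('http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY'):
    os.environ.pop(var, None)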
@@ -356,19 +360,20 @@ def _build_graph(self, copy_flow: bool) -> 'Flow':
             #
             # when a socket is BIND, then host must NOT be set, aka default host 0.0.0.0
             # host_in and host_out is only set when corresponding socket is CONNECT
-            e_pargs.port_in = s_pargs.port_out
 
             if len(edges_with_same_start) > 1 and len(edges_with_same_end) == 1:
                 s_pargs.socket_out = SocketType.PUB_BIND
                 s_pargs.host_out = BaseService.default_host
                 e_pargs.socket_in = SocketType.SUB_CONNECT
                 e_pargs.host_in = start_node
+                e_pargs.port_in = s_pargs.port_out
                 op_flow._service_edges[k] = 'PUB-sub'
             elif len(edges_with_same_end) > 1 and len(edges_with_same_start) == 1:
                 s_pargs.socket_out = SocketType.PUSH_CONNECT
                 s_pargs.host_out = end_node
                 e_pargs.socket_in = SocketType.PULL_BIND
                 e_pargs.host_in = BaseService.default_host
+                s_pargs.port_out = e_pargs.port_in
                 op_flow._service_edges[k] = 'push-PULL'
             elif len(edges_with_same_start) == 1 and len(edges_with_same_end) == 1:
                 # in this case, either side can be BIND
@@ -386,10 +391,12 @@ def _build_graph(self, copy_flow: bool) -> 'Flow':
                 if s_pargs.socket_out.is_bind:
                     s_pargs.host_out = BaseService.default_host
                     e_pargs.host_in = start_node
+                    e_pargs.port_in = s_pargs.port_out
                     op_flow._service_edges[k] = 'PUSH-pull'
                 elif e_pargs.socket_in.is_bind:
                     s_pargs.host_out = end_node
                     e_pargs.host_in = BaseService.default_host
+                    s_pargs.port_out = e_pargs.port_in
                     op_flow._service_edges[k] = 'push-PULL'
                 else:
                     raise FlowTopologyError('edge %s -> %s is ambiguous, at least one socket should be BIND')
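The pattern behind these four one-line additions: in ZeroMQ the port belongs to whichever socket binds, so the connecting side must copy the bound side's port. That is why the port assignment moved from one unconditional line into each branch, pointing the right way per topology. A minimal pyzmq sketch of the rule (illustrative only, not GNES code):

import zmq

ctx = zmq.Context()

bind_port = 5555                                 # the port is owned by the BIND side
pull = ctx.socket(zmq.PULL)
pull.bind('tcp://0.0.0.0:%d' % bind_port)        # BIND side: host stays 0.0.0.0

push = ctx.socket(zmq.PUSH)
push.connect('tcp://localhost:%d' % bind_port)   # CONNECT side adopts the same port

push.send(b'hello')
print(pull.recv())                               # b'hello'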
24 changes: 19 additions & 5 deletions gnes/service/base.py
@@ -156,8 +156,13 @@ def build_socket(ctx: 'zmq.Context', host: str, port: int, socket_type: 'SocketT
 
 class MessageHandler:
     def __init__(self, mh: 'MessageHandler' = None):
-        self.routes = {k: v for k, v in mh.routes.items()} if mh else {}
-        self.hooks = {k: v for k, v in mh.hooks.items()} if mh else {'pre': [], 'post': []}
+        self.routes = {}
+        self.hooks = {'pre': [], 'post': []}
+
+        if mh:
+            self.routes = copy.deepcopy(mh.routes)
+            self.hooks = copy.deepcopy(mh.hooks)
+
         self.logger = set_logger(self.__class__.__name__)
         self.service_context = None
@@ -329,6 +334,14 @@ def __init__(self, args):
             check_version=self.args.check_version,
             timeout=self.args.timeout,
             squeeze_pb=self.args.squeeze_pb)
+        # self._override_handler()
+
+    def _override_handler(self):
+        # replace the function name by the function itself
+        mh = MessageHandler()
+        mh.routes = {k: getattr(self, v) for k, v in self.handler.routes.items()}
+        mh.hooks = {k: [(getattr(self, vv[0]), vv[1]) for vv in v] for k, v in self.handler.hooks.items()}
+        self.handler = mh
 
     def run(self):
         try:
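The new (currently commented-out) _override_handler helper resolves route entries from method names to bound methods. A toy sketch of that getattr resolution, with hypothetical names chosen for illustration:

class Svc:
    def ping(self, msg):
        return 'pong: %s' % msg

routes = {'PingRequest': 'ping'}                          # routes store method *names*
svc = Svc()
bound = {k: getattr(svc, v) for k, v in routes.items()}   # names -> bound methods
print(bound['PingRequest']('hi'))                         # pong: hi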
@@ -341,9 +354,9 @@ def dump(self, respect_dump_interval: bool = True):
                 and self.args.dump_interval > 0
                 and self._model
                 and self.is_model_changed.is_set()
-                and (respect_dump_interval
-                     and (time.perf_counter() - self.last_dump_time) > self.args.dump_interval)
-                or not respect_dump_interval):
+                and ((respect_dump_interval
+                      and (time.perf_counter() - self.last_dump_time) > self.args.dump_interval)
+                     or not respect_dump_interval)):
             self.is_model_changed.clear()
             self.logger.info('dumping changes to the model, %3.0fs since last the dump'
                              % (time.perf_counter() - self.last_dump_time))
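This is the dump_interval fix named in the commit title. Because 'and' binds tighter than 'or', the old condition parsed as (all prerequisites and interval check) or (not respect_dump_interval), so a forced dump bypassed the model-loaded and model-changed checks entirely. A minimal sketch of the precedence difference:

model_changed = False            # prerequisite that must always hold
respect_interval = False         # a forced dump was requested
interval_elapsed = False

old = model_changed and (respect_interval and interval_elapsed) or not respect_interval
new = model_changed and ((respect_interval and interval_elapsed) or not respect_interval)

print(old)  # True  -- the forced dump fires even though nothing changed
print(new)  # False -- prerequisites still gate the forced dump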
@@ -385,6 +398,7 @@ def _hook_update_route_timestamp(self, msg: 'gnes_pb2.Message', *args, **kwargs)
     def _run(self, ctx):
         ctx.setsockopt(zmq.LINGER, 0)
         self.handler.service_context = self
+        # print('!!!! t_id: %d service_context: %r' % (threading.get_ident(), self.handler.service_context))
         self.logger.info('bind sockets...')
         in_sock, _ = build_socket(ctx, self.args.host_in, self.args.port_in, self.args.socket_in,
                                   self.args.identity)
20 changes: 12 additions & 8 deletions gnes/service/indexer.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import numpy as np
 
 from .base import BaseService as BS, MessageHandler, ServiceError
@@ -25,10 +24,16 @@ class IndexerService(BS):
 
     def post_init(self):
         from ..indexer.base import BaseIndexer
+        # print('id: %s, before: %r' % (threading.get_ident(), self._model))
         self._model = self.load_model(BaseIndexer)
+        # self._tmp_a = threading.get_ident()
+        # print('id: %s, after: %r, self._tmp_a: %r' % (threading.get_ident(), self._model, self._tmp_a))
 
     @handler.register(gnes_pb2.Request.IndexRequest)
     def _handler_index(self, msg: 'gnes_pb2.Message'):
+        # print('tid: %s, model: %r, self._tmp_a: %r' % (threading.get_ident(), self._model, self._tmp_a))
+        # if self._tmp_a != threading.get_ident():
+        #     print('tid: %s, tmp_a: %r !!! %r' % (threading.get_ident(), self._tmp_a, self._handler_index))
         from ..indexer.base import BaseChunkIndexer, BaseDocIndexer
         if isinstance(self._model, BaseChunkIndexer):
             self._handler_chunk_index(msg)
@@ -41,22 +46,21 @@ def _handler_index(self, msg: 'gnes_pb2.Message'):
         self.is_model_changed.set()
 
     def _handler_chunk_index(self, msg: 'gnes_pb2.Message'):
-        vecs, doc_ids, offsets, weights = [], [], [], []
+        embed_info = []
 
         for d in msg.request.index.docs:
             if not d.chunks:
                 self.logger.warning('document (doc_id=%s) contains no chunks!' % d.doc_id)
                 continue
 
-            vecs += [blob2array(c.embedding) for c in d.chunks]
-            doc_ids += [d.doc_id] * len(d.chunks)
-            offsets += [c.offset for c in d.chunks]
-            weights += [c.weight for c in d.chunks]
+            embed_info += [(blob2array(c.embedding), d.doc_id, c.offset, c.weight) for c in d.chunks if
+                           c.embedding.data]
 
-        if vecs:
+        if embed_info:
+            vecs, doc_ids, offsets, weights = zip(*embed_info)
             self._model.add(list(zip(doc_ids, offsets)), np.stack(vecs), weights)
         else:
-            self.logger.warning('chunks contain no embedded vectors, %the indexer will do nothing')
+            self.logger.warning('chunks contain no embedded vectors, the indexer will do nothing')
 
     def _handler_doc_index(self, msg: 'gnes_pb2.Message'):
         self._model.add([d.doc_id for d in msg.request.index.docs],
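Two things change in _handler_chunk_index: chunks whose embedding carries no bytes are now filtered out (the "empty chunk" in the commit title), and the four parallel lists are gathered as one tuple per chunk and unzipped once, which keeps the columns aligned by construction. A toy sketch of the collect-then-unzip step:

import numpy as np

# One tuple per chunk: (vector, doc_id, offset, weight); toy data, not GNES types.
embed_info = [(np.zeros(3), 0, 0, 1.0),
              (np.ones(3), 0, 1, 0.5)]

if embed_info:
    vecs, doc_ids, offsets, weights = zip(*embed_info)   # unzip into 4 aligned columns
    print(np.stack(vecs).shape)                          # (2, 3)
    print(list(zip(doc_ids, offsets)))                   # [(0, 0), (0, 1)]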
22 changes: 11 additions & 11 deletions tests/test_gnes_flow.py
@@ -120,29 +120,29 @@ def _test_index_flow(self):
 
         flow = (Flow(check_version=False, route_table=False)
                 .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor')
-                .add(gfs.Encoder, yaml_path='yaml/flow-transformer.yml')
-                .add(gfs.Indexer, name='vec_idx', yaml_path='yaml/flow-vecindex.yml')
-                .add(gfs.Indexer, name='doc_idx', yaml_path='yaml/flow-dictindex.yml',
+                .add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'))
+                .add(gfs.Indexer, name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml'))
+                .add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml'),
                      service_in='prep')
                 .add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter',
                      num_part=2, service_in=['vec_idx', 'doc_idx']))
 
-        with flow.build(backend='thread') as f:
+        with flow.build(backend='process') as f:
             f.index(txt_file=self.test_file, batch_size=20)
 
-        for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]:
+        for k in [self.indexer1_bin, self.indexer2_bin]:
             self.assertTrue(os.path.exists(k))
 
     def _test_query_flow(self):
         flow = (Flow(check_version=False, route_table=False)
                 .add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor')
-                .add(gfs.Encoder, yaml_path='yaml/flow-transformer.yml')
-                .add(gfs.Indexer, name='vec_idx', yaml_path='yaml/flow-vecindex.yml')
-                .add(gfs.Router, name='scorer', yaml_path='yaml/flow-score.yml')
-                .add(gfs.Indexer, name='doc_idx', yaml_path='yaml/flow-dictindex.yml'))
+                .add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'))
+                .add(gfs.Indexer, name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml'))
+                .add(gfs.Router, name='scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml'))
+                .add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml')))
 
-        with flow.build(backend='thread') as f:
-            f.query(txt_file=self.test_file)
+        with flow.build(backend='process') as f, open(self.test_file, encoding='utf8') as fp:
+            f.query(bytes_gen=[v.encode() for v in fp][:10])
 
     @unittest.SkipTest
     def test_index_query_flow(self):
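The test changes swap bare YAML paths for paths anchored at the test module's directory, and switch the backend from 'thread' to 'process'. Anchoring matters because relative paths resolve against the runner's working directory, not the test file; a short sketch of the pattern (generic, assuming self.dirname is set up this way):

import os

# Resolve a data file relative to this module, independent of os.getcwd().
dirname = os.path.dirname(os.path.abspath(__file__))
yaml_path = os.path.join(dirname, 'yaml/flow-vecindex.yml')
print(yaml_path)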
