Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix (service): fix bug in encoder service
Browse files Browse the repository at this point in the history
  • Loading branch information
Larryjianfeng committed Sep 2, 2019
1 parent 3a18111 commit a252024
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
5 changes: 4 additions & 1 deletion gnes/encoder/video/incep_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ def _encode1(self, data):
feed_dict={self.inputs: data})
return end_points_[self.select_layer]

v = [_ for vi in _encode1(self, img) for _ in vi]
if len(img) <= self.batch_size:
v = [_ for _ in _encode1(self, img)]
else:
v = [_ for vi in _encode1(self, img) for _ in vi]

v_input = [v[s:e] for s, e in zip(pos_start, pos_end)]
v_input = [(vi + [[0.0] * self.input_size] * (max_len - len(vi)))[:max_len] for vi in v_input]
Expand Down
37 changes: 20 additions & 17 deletions gnes/service/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,37 +35,39 @@ def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.
docs = [docs]

contents = []
chunks = []
ids = []
embeds = None

for d in docs:
if not d.chunks:
self.logger.warning('document (doc_id=%s) contains no chunks!' % d.doc_id)
continue

ids.append(len(d.chunks))
for c in d.chunks:
if d.doc_type == gnes_pb2.Document.TEXT:
contents.append(c.text)
elif getattr(c, c.WhichOneof('content')) == 'blob':
contents.append(blob2array(c.blob))
else:
self.logger.warning(
'chunk content is in type: %s, dont kow how to handle that, ignored' % c.WhichOneof('content'))
chunks.append(c)
contents.append(blob2array(c.blob))

if do_encoding and contents:
if do_encoding:
embeds = self._model.encode(contents)
if len(chunks) != embeds.shape[0]:
if sum(ids) != embeds.shape[0]:
raise ServiceError(
'mismatched %d chunks and a %s shape embedding, '
'the first dimension must be the same' % (len(chunks), embeds.shape))
for idx, c in enumerate(chunks):
c.embedding.CopyFrom(array2blob(embeds[idx]))
'the first dimension must be the same' % (sum(ids), embeds.shape))
idx = 0
for d in docs:
for c in d.chunks:
c.embedding.CopyFrom(array2blob(embeds[idx]))
idx += 1

return contents, embeds

@handler.register(gnes_pb2.Request.IndexRequest)
def _handler_index(self, msg: 'gnes_pb2.Message'):
self.embed_chunks_in_docs(msg.request.index.docs)
_, embeds = self.embed_chunks_in_docs(msg.request.index.docs)
idx = 0
for d in msg.request.index.docs:
for c in d.chunks:
c.embedding.CopyFrom(array2blob(embeds[idx]))
idx += 1

@handler.register(gnes_pb2.Request.TrainRequest)
def _handler_train(self, msg: 'gnes_pb2.Message'):
Expand All @@ -83,4 +85,5 @@ def _handler_train(self, msg: 'gnes_pb2.Message'):

@handler.register(gnes_pb2.Request.QueryRequest)
def _handler_search(self, msg: 'gnes_pb2.Message'):
self.embed_chunks_in_docs(msg.request.search.query, is_input_list=False)
_, embeds = self.embed_chunks_in_docs(msg.request.search.query)
msg.request.search.query.chunk_embeddings.CopyFrom(array2blob(embeds))

0 comments on commit a252024

Please sign in to comment.