From a2520246a944549f634afcdeee57120f5524b7fc Mon Sep 17 00:00:00 2001 From: Larry Yan Date: Mon, 2 Sep 2019 19:53:39 +0800 Subject: [PATCH] fix (service): fix bug in encoder service --- gnes/encoder/video/incep_mixture.py | 5 +++- gnes/service/encoder.py | 37 ++++++++++++++++------------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/gnes/encoder/video/incep_mixture.py b/gnes/encoder/video/incep_mixture.py index 0a0be11b..58867cc2 100644 --- a/gnes/encoder/video/incep_mixture.py +++ b/gnes/encoder/video/incep_mixture.py @@ -116,7 +116,10 @@ def _encode1(self, data): feed_dict={self.inputs: data}) return end_points_[self.select_layer] - v = [_ for vi in _encode1(self, img) for _ in vi] + if len(img) <= self.batch_size: + v = [_ for _ in _encode1(self, img)] + else: + v = [_ for vi in _encode1(self, img) for _ in vi] v_input = [v[s:e] for s, e in zip(pos_start, pos_end)] v_input = [(vi + [[0.0] * self.input_size] * (max_len - len(vi)))[:max_len] for vi in v_input] diff --git a/gnes/service/encoder.py b/gnes/service/encoder.py index 0206fca1..43421e92 100644 --- a/gnes/service/encoder.py +++ b/gnes/service/encoder.py @@ -35,37 +35,39 @@ def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2. docs = [docs] contents = [] - chunks = [] + ids = [] embeds = None for d in docs: - if not d.chunks: - self.logger.warning('document (doc_id=%s) contains no chunks!' % d.doc_id) - continue - + ids.append(len(d.chunks)) for c in d.chunks: if d.doc_type == gnes_pb2.Document.TEXT: contents.append(c.text) - elif getattr(c, c.WhichOneof('content')) == 'blob': - contents.append(blob2array(c.blob)) else: - self.logger.warning( - 'chunk content is in type: %s, dont kow how to handle that, ignored' % c.WhichOneof('content')) - chunks.append(c) + contents.append(blob2array(c.blob)) - if do_encoding and contents: + if do_encoding: embeds = self._model.encode(contents) - if len(chunks) != embeds.shape[0]: + if sum(ids) != embeds.shape[0]: raise ServiceError( 'mismatched %d chunks and a %s shape embedding, ' - 'the first dimension must be the same' % (len(chunks), embeds.shape)) - for idx, c in enumerate(chunks): - c.embedding.CopyFrom(array2blob(embeds[idx])) + 'the first dimension must be the same' % (sum(ids), embeds.shape)) + idx = 0 + for d in docs: + for c in d.chunks: + c.embedding.CopyFrom(array2blob(embeds[idx])) + idx += 1 + return contents, embeds @handler.register(gnes_pb2.Request.IndexRequest) def _handler_index(self, msg: 'gnes_pb2.Message'): - self.embed_chunks_in_docs(msg.request.index.docs) + _, embeds = self.embed_chunks_in_docs(msg.request.index.docs) + idx = 0 + for d in msg.request.index.docs: + for c in d.chunks: + c.embedding.CopyFrom(array2blob(embeds[idx])) + idx += 1 @handler.register(gnes_pb2.Request.TrainRequest) def _handler_train(self, msg: 'gnes_pb2.Message'): @@ -83,4 +85,5 @@ def _handler_train(self, msg: 'gnes_pb2.Message'): @handler.register(gnes_pb2.Request.QueryRequest) def _handler_search(self, msg: 'gnes_pb2.Message'): - self.embed_chunks_in_docs(msg.request.search.query, is_input_list=False) + _, embeds = self.embed_chunks_in_docs(msg.request.search.query) + msg.request.search.query.chunk_embeddings.CopyFrom(array2blob(embeds))