Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(service): fix training logic in encoderservice
Browse files: browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Aug 29, 2019
1 parent 5828d20 commit c618396
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions gnes/service/encoder.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -28,12 +28,14 @@ def post_init(self):
self._model = self.load_model(BaseEncoder)
self.train_data = []

def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.Document']):
def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.Document'],
do_encoding: bool = True):
if not isinstance(docs, list):
docs = [docs]

contents = []
chunks = []
embeds = None

for d in docs:
for c in d.chunks:
Expand All @@ -46,14 +48,15 @@ def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.
raise ServiceError(
'chunk content is in type: %s, dont kow how to handle that' % c.WhichOneof('content'))

embeds = self._model.encode(contents)
if len(chunks) != embeds.shape[0]:
raise ServiceError(
'mismatched %d chunks and a %s shape embedding, '
'the first dimension must be the same' % (len(chunks), embeds.shape))
for idx, c in enumerate(chunks):
c.embedding.CopyFrom(array2blob(embeds[idx]))
return embeds
if do_encoding:
embeds = self._model.encode(contents)
if len(chunks) != embeds.shape[0]:
raise ServiceError(
'mismatched %d chunks and a %s shape embedding, '
'the first dimension must be the same' % (len(chunks), embeds.shape))
for idx, c in enumerate(chunks):
c.embedding.CopyFrom(array2blob(embeds[idx]))
return contents, embeds

@handler.register(gnes_pb2.Request.IndexRequest)
def _handler_index(self, msg: 'gnes_pb2.Message'):
Expand All @@ -62,7 +65,7 @@ def _handler_index(self, msg: 'gnes_pb2.Message'):
@handler.register(gnes_pb2.Request.TrainRequest)
def _handler_train(self, msg: 'gnes_pb2.Message'):
if msg.request.train.docs:
_, contents = self.embed_chunks_in_docs(msg.request.train.docs)
contents, _ = self.embed_chunks_in_docs(msg.request.train.docs, do_encoding=False)
self.train_data.extend(contents)
msg.response.train.status = gnes_pb2.Response.PENDING
# raise BlockMessage
Expand Down

0 comments on commit c618396

Please sign in to comment.