Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix: bugs for integrated test
Browse files Browse the repository at this point in the history
  • Loading branch information
Jem committed Jul 12, 2019
1 parent 72a8bd9 commit 8780a4d
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
14 changes: 13 additions & 1 deletion gnes/preprocessor/image/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, window_size: int = 64,
self.stride_wide = stride_wide

def apply(self, doc: 'gnes_pb2.Document'):
super().apply(doc)
if doc.raw_bytes:
img = np.array(Image.open(io.BytesIO(doc.raw_bytes)))
image_set = self._get_all_sliding_window(img)
Expand Down Expand Up @@ -87,7 +88,18 @@ def _get_all_chunks_weight(self, image_set) -> List[float]:
class WeightedSlidingPreprocessor(BaseSlidingPreprocessor):
    """Sliding-window preprocessor that weights windows by pixel-histogram variance.

    Only the chunk-weighting step is overridden here; everything else comes
    from ``BaseSlidingPreprocessor``.
    """

    def _get_all_chunks_weight(self, image_set) -> List[float]:
        """Return one normalized weight per window in ``image_set``.

        For every window the variance of the per-channel pixel histogram
        (``np.histogram`` with its default binning) is summed over the
        channels. The scores are scaled to sum to one, passed through
        ``exp(-10 * score)`` and normalized again, so windows with a larger
        histogram variance receive a smaller weight.

        :param image_set: sequence of image windows; each window is indexed
            as ``window[:, :, channel]`` (2-D windows are treated as having
            a single channel).
        :return: a 1-D ``numpy`` array of weights summing to one, or an
            empty array when ``image_set`` is empty.
        """
        # Nothing to weight: avoid indexing into an empty collection.
        if len(image_set) == 0:
            return np.zeros([0])

        weight = np.zeros([len(image_set)])
        for i, img in enumerate(image_set):
            # Promote (H, W) windows to (H, W, 1) so the channel loop below
            # also works for single-channel images.
            img = np.atleast_3d(img)
            # Calculate the variance of the histogram of pixels, per channel
            # (the channel count is usually 3 for RGB images).
            weight[i] = sum(np.histogram(img[:, :, c])[0].var()
                            for c in range(img.shape[-1]))

        # Scale the scores to sum to one; skip when the total is zero so the
        # division cannot produce NaNs (all windows then get equal weight).
        total = weight.sum()
        if total > 0:
            weight = weight / total

        # Normalized result: exp() is strictly positive, so the final sum is
        # never zero.
        weight = np.exp(-weight * 10)
        return weight / weight.sum()


class SegmentPreprocessor(BaseImagePreprocessor):
Expand Down
2 changes: 1 addition & 1 deletion gnes/service/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _handler_train(self, msg: 'gnes_pb2.Message'):
chunks = self.get_chunks_from_docs(msg.request.train.docs)
self.train_data.extend(chunks)
msg.response.train.status = gnes_pb2.Response.PENDING
raise BlockMessage
# raise BlockMessage
if msg.request.train.flush:
self._model.train(self.train_data)
self.logger.info('%d samples is flushed for training' % len(self.train_data))
Expand Down
4 changes: 2 additions & 2 deletions gnes/service/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def _handler_chunk_search(self, msg: 'gnes_pb2.Message'):
@handler.register(gnes_pb2.Response.QueryResponse)
def _handler_doc_search(self, msg: 'gnes_pb2.Message'):
if msg.response.search.level == gnes_pb2.Response.QueryResponse.DOCUMENT_NOT_FILLED:
doc_ids = [r.doc.doc_id for r in msg.response.topk_results]
doc_ids = [r.doc.doc_id for r in msg.response.search.topk_results]
docs = self._model.query(doc_ids)
for r, d in zip(msg.response.topk_results, docs):
for r, d in zip(msg.response.search.topk_results, docs):
if d is not None:
# fill in the doc if this shard returns non-empty
r.doc.CopyFrom(d)
Expand Down

0 comments on commit 8780a4d

Please sign in to comment.