diff --git a/gnes/preprocessor/image/simple.py b/gnes/preprocessor/image/simple.py index 6b2c3bd3..83776ecb 100644 --- a/gnes/preprocessor/image/simple.py +++ b/gnes/preprocessor/image/simple.py @@ -34,6 +34,7 @@ def __init__(self, window_size: int = 64, self.stride_wide = stride_wide def apply(self, doc: 'gnes_pb2.Document'): + super().apply(doc) if doc.raw_bytes: img = np.array(Image.open(io.BytesIO(doc.raw_bytes))) image_set = self._get_all_sliding_window(img) @@ -87,7 +88,18 @@ def _get_all_chunks_weight(self, image_set) -> List[float]: class WeightedSlidingPreprocessor(BaseSlidingPreprocessor): def _get_all_chunks_weight(self, image_set) -> List[float]: - raise NotImplementedError + weight = np.zeros([len(image_set)]) + # n_channel is usually 3 for RGB images + n_channel = image_set[0].shape[-1] + for i in range(len(image_set)): + # calcualte the variance of histgram of pixels + weight[i] = sum([np.histogram(image_set[i][:, :, _])[0].var() + for _ in range(n_channel)]) + weight = weight / weight.sum() + + # normalized result + weight = np.exp(- weight * 10) + return weight / weight.sum() class SegmentPreprocessor(BaseImagePreprocessor): diff --git a/gnes/service/encoder.py b/gnes/service/encoder.py index 56ed8d64..9afb4d68 100644 --- a/gnes/service/encoder.py +++ b/gnes/service/encoder.py @@ -50,7 +50,7 @@ def _handler_train(self, msg: 'gnes_pb2.Message'): chunks = self.get_chunks_from_docs(msg.request.train.docs) self.train_data.extend(chunks) msg.response.train.status = gnes_pb2.Response.PENDING - raise BlockMessage + # raise BlockMessage if msg.request.train.flush: self._model.train(self.train_data) self.logger.info('%d samples is flushed for training' % len(self.train_data)) diff --git a/gnes/service/indexer.py b/gnes/service/indexer.py index 8829195d..7322ae44 100644 --- a/gnes/service/indexer.py +++ b/gnes/service/indexer.py @@ -77,9 +77,9 @@ def _handler_chunk_search(self, msg: 'gnes_pb2.Message'): @handler.register(gnes_pb2.Response.QueryResponse) def _handler_doc_search(self, msg: 'gnes_pb2.Message'): if msg.response.search.level == gnes_pb2.Response.QueryResponse.DOCUMENT_NOT_FILLED: - doc_ids = [r.doc.doc_id for r in msg.response.topk_results] + doc_ids = [r.doc.doc_id for r in msg.response.search.topk_results] docs = self._model.query(doc_ids) - for r, d in zip(msg.response.topk_results, docs): + for r, d in zip(msg.response.search.topk_results, docs): if d is not None: # fill in the doc if this shard returns non-empty r.doc.CopyFrom(d)