diff --git a/gnes/client/cli.py b/gnes/client/cli.py index e2d04c5e..b44d4d22 100644 --- a/gnes/client/cli.py +++ b/gnes/client/cli.py @@ -17,8 +17,7 @@ import sys import time import zipfile -from math import ceil -from typing import List +from typing import List, Generator from termcolor import colored @@ -29,25 +28,25 @@ class CLIClient(GrpcClient): def __init__(self, args): super().__init__(args) - getattr(self, self.args.mode)(self.read_all()) + getattr(self, self.args.mode)() self.close() - def train(self, all_bytes: List[bytes]): - with ProgressBar(all_bytes, self.args.batch_size, task_name=self.args.mode) as p_bar: - for _ in self._stub.StreamCall(RequestGenerator.train(all_bytes, + def train(self): + with ProgressBar(task_name=self.args.mode) as p_bar: + for _ in self._stub.StreamCall(RequestGenerator.train(self.bytes_generator, doc_id_start=self.args.start_doc_id, batch_size=self.args.batch_size)): p_bar.update() - def index(self, all_bytes: List[bytes]): - with ProgressBar(all_bytes, self.args.batch_size, task_name=self.args.mode) as p_bar: - for _ in self._stub.StreamCall(RequestGenerator.index(all_bytes, + def index(self): + with ProgressBar(task_name=self.args.mode) as p_bar: + for _ in self._stub.StreamCall(RequestGenerator.index(self.bytes_generator, doc_id_start=self.args.start_doc_id, batch_size=self.args.batch_size)): p_bar.update() - def query(self, all_bytes: List[bytes]): - for idx, q in enumerate(all_bytes): + def query(self): + for idx, q in enumerate(self.bytes_generator): for req in RequestGenerator.query(q, request_id_start=idx, top_k=self.args.top_k): resp = self._stub.Call(req) self.query_callback(req, resp) @@ -77,45 +76,51 @@ def read_all(self) -> List[bytes]: return all_bytes + @property + def bytes_generator(self) -> Generator[bytes]: + if self.args.txt_file: + all_bytes = (v.encode() for v in self.args.txt_file) + elif self.args.image_zip_file: + zipfile_ = zipfile.ZipFile(self.args.image_zip_file) + all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist()) + elif self.args.video_zip_file: + zipfile_ = zipfile.ZipFile(self.args.video_zip_file) + all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist()) + else: + raise AttributeError('--txt_file, --image_zip_file, --video_zip_file one must be given') + + return all_bytes + class ProgressBar: - def __init__(self, all_bytes: List[bytes], batch_size: int, bar_len: int = 20, task_name: str = ''): - self.all_bytes_len = [len(v) for v in all_bytes] - self.batch_size = batch_size - self.total_batch = ceil(len(self.all_bytes_len) / self.batch_size) + def __init__(self, bar_len: int = 20, task_name: str = ''): self.bar_len = bar_len self.task_name = task_name def update(self): - if self.num_batch > self.total_batch - 1: - return sys.stdout.write('\r') elapsed = time.perf_counter() - self.start_time elapsed_str = colored('elapsed', 'yellow') speed_str = colored('speed', 'yellow') - estleft_str = colored('left', 'yellow') - self.num_batch += 1 - percent = self.num_batch / self.total_batch - num_bytes = sum(self.all_bytes_len[((self.num_batch - 1) * self.batch_size):(self.num_batch * self.batch_size)]) + self.num_bars += 1 + if self.num_bars > self.bar_len: + self.num_bars -= self.bar_len + sys.stdout.write('\n') sys.stdout.write( - '{:>10} [{:<{}}] {:3.0f}% {:>8}: {:3.1f}s {:>8}: {:3.1f} bytes/s {:3.1f} batch/s {:>8}: {:3.1f}s'.format( + '{:>10} [{:<{}}] {:3.0f}% {:>8}: {:3.1f}s {:>8}: {:3.1f} batch/s'.format( colored(self.task_name, 'cyan'), - colored('=' * int(self.bar_len * percent), 'green'), + colored('=' * self.num_bars, 'green'), self.bar_len + 9, - percent * 100, elapsed_str, elapsed, speed_str, - num_bytes / elapsed, - self.num_batch / elapsed, - estleft_str, - (self.total_batch - self.num_batch) / ((self.num_batch + 0.0001) / elapsed) + self.num_bars / elapsed, )) sys.stdout.flush() def __enter__(self): self.start_time = time.perf_counter() - self.num_batch = -1 + self.num_bars = -1 sys.stdout.write('\n') self.update() return self diff --git a/gnes/proto/__init__.py b/gnes/proto/__init__.py index 9b4eff20..9e87ddb5 100644 --- a/gnes/proto/__init__.py +++ b/gnes/proto/__init__.py @@ -15,7 +15,7 @@ import ctypes import random -from typing import List +from typing import List, Iterator from typing import Optional import numpy as np @@ -30,7 +30,7 @@ class RequestGenerator: @staticmethod - def index(data: List[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.Document.TEXT, + def index(data: Iterator[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.Document.TEXT, doc_id_start: int = 0, request_id_start: int = 0, random_doc_id: bool = False, *args, **kwargs): @@ -49,7 +49,7 @@ def index(data: List[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.Docum request_id_start += 1 @staticmethod - def train(data: List[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.Document.TEXT, + def train(data: Iterator[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.Document.TEXT, doc_id_start: int = 0, request_id_start: int = 0, random_doc_id: bool = False, *args, **kwargs):