# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import zipfile
from typing import Iterator, Tuple
from termcolor import colored
from .base import GrpcClient
from ..proto import RequestGenerator
[docs]class CLIClient(GrpcClient):
def __init__(self, args, start_at_init: bool = True):
super().__init__(args)
self._bytes_generator = self._get_bytes_generator_from_args(args)
if start_at_init:
self.start()
@staticmethod
def _get_bytes_generator_from_args(args):
if args.txt_file:
all_bytes = (v.encode() for v in args.txt_file)
elif args.image_zip_file:
zipfile_ = zipfile.ZipFile(args.image_zip_file)
all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist())
elif args.video_zip_file:
zipfile_ = zipfile.ZipFile(args.video_zip_file)
all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist())
else:
all_bytes = None
return all_bytes
[docs] def start(self):
try:
getattr(self, self.args.mode)()
except Exception as ex:
self.logger.error(ex)
finally:
self.close()
[docs] def train(self) -> None:
with ProgressBar(task_name=self.args.mode) as p_bar:
for _ in self._stub.StreamCall(RequestGenerator.train(self.bytes_generator,
doc_id_start=self.args.start_doc_id,
batch_size=self.args.batch_size)):
p_bar.update()
[docs] def index(self) -> None:
with ProgressBar(task_name=self.args.mode) as p_bar:
for _ in self._stub.StreamCall(RequestGenerator.index(self.bytes_generator,
doc_id_start=self.args.start_doc_id,
batch_size=self.args.batch_size)):
p_bar.update()
[docs] def query(self) -> Iterator[Tuple]:
for idx, q in enumerate(self.bytes_generator):
for req in RequestGenerator.query(q, request_id_start=idx, top_k=self.args.top_k):
resp = self._stub.Call(req)
yield (req, resp)
@property
def bytes_generator(self) -> Iterator[bytes]:
if self._bytes_generator:
return self._bytes_generator
else:
raise ValueError('bytes_generator is empty or not set')
@bytes_generator.setter
def bytes_generator(self, bytes_gen: Iterator[bytes]):
if self._bytes_generator:
self.logger.warning('bytes_generator is not empty, overrided')
self._bytes_generator = bytes_gen
[docs]class ProgressBar:
def __init__(self, bar_len: int = 20, task_name: str = ''):
self.bar_len = bar_len
self.task_name = task_name
[docs] def update(self):
self.num_bars += 1
sys.stdout.write('\r')
elapsed = time.perf_counter() - self.start_time
elapsed_str = colored('elapsed', 'yellow')
speed_str = colored('speed', 'yellow')
num_bars = self.num_bars % self.bar_len
num_bars = self.bar_len if not num_bars and self.num_bars else max(num_bars, 1)
sys.stdout.write(
'{:>10} [{:<{}}] {:>8}: {:3.1f}s {:>8}: {:3.1f} batch/s'.format(
colored(self.task_name, 'cyan'),
colored('=' * num_bars, 'green'),
self.bar_len + 9,
elapsed_str,
elapsed,
speed_str,
self.num_bars / elapsed,
))
if num_bars == self.bar_len:
sys.stdout.write('\n')
sys.stdout.flush()
def __enter__(self):
self.start_time = time.perf_counter()
self.num_bars = -1
self.update()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
sys.stdout.write('\t%s\n' % colored('done!', 'green'))