# Source code for gnes.client.http

#  Tencent is pleased to support the open source community by making GNES available.
#
#  Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.


import asyncio
import json
from concurrent.futures import ThreadPoolExecutor

import grpc
from google.protobuf.json_format import MessageToJson

from ..helper import set_logger
from ..proto import gnes_pb2_grpc, RequestGenerator


class HttpClient:
    """A lightweight HTTP gateway that bridges REST clients to a GNES gRPC service.

    Accepts multipart-encoded POST requests on ``/train``, ``/index`` and
    ``/query``, converts them into GNES protobuf requests via
    :class:`RequestGenerator`, forwards them over a streaming gRPC call, and
    returns the final response as JSON.
    """

    def __init__(self, args=None):
        """
        :param args: parsed CLI namespace; expected attributes include
            ``verbose``, ``max_workers``, ``batch_size``, ``top_k``,
            ``http_host``, ``http_port``, ``grpc_host``, ``grpc_port`` and
            ``max_message_size``.
        """
        self.args = args
        self.logger = set_logger(self.__class__.__name__, self.args.verbose)

    def start(self):
        """Start the aiohttp server and block forever serving requests.

        Requires the optional ``aiohttp`` dependency; logs an error and
        returns immediately if it is missing.
        """
        try:
            from aiohttp import web
        except ImportError:
            self.logger.error('can not import aiohttp, it is not installed correctly. please do '
                              '"pip install gnes[aiohttp]"')
            return

        loop = asyncio.get_event_loop()
        # blocking gRPC calls are offloaded here so the event loop stays responsive
        executor = ThreadPoolExecutor(max_workers=self.args.max_workers)

        async def general_handler(request, parser, *args, **kwargs):
            # Shared handler for every route: extract the uploaded payload,
            # build a protobuf request via `parser`, run the blocking gRPC
            # call in the thread pool and return the response as JSON.
            try:
                data = dict()
                # uploading via Multipart-Encoded File
                post_data = await request.post()
                if 'query' in post_data.keys():
                    _file = post_data.get('query')
                    self.logger.info('query request from input file: %s' % _file.filename)
                    data['query'] = _file.file.read()
                elif 'docs' in post_data.keys():
                    files = post_data.getall('docs')
                    self.logger.info('index request from input files: %d files' % len(files))
                    data['docs'] = [_file.file.read() for _file in files]
                self.logger.info('data received')
                # BUGFIX: the original used `hasattr(data, 'docs')`, which is
                # always False for a plain dict (dicts expose keys, not
                # attributes), so /train and /index uploads were forwarded as
                # an empty query. Use key membership instead.
                payload = list(data['docs']) if 'docs' in data else data.get('query')
                resp = await loop.run_in_executor(
                    executor, stub_call, parser(payload, *args, **kwargs))
                self.logger.info('send back to user')
                return web.Response(body=json.dumps({'result': resp, 'meta': None},
                                                    ensure_ascii=False),
                                    status=200,
                                    content_type='application/json')
            except Exception as ex:
                # BUGFIX: the original serialized `type(ex)` directly, which is
                # not JSON-serializable and made the error path itself raise a
                # TypeError; serialize its string form instead.
                return web.Response(body=json.dumps({'message': str(ex),
                                                     'type': str(type(ex))}),
                                    status=400,
                                    content_type='application/json')

        async def init(loop):
            # non-persistent connections: each request opens a fresh socket
            handler_args = {'tcp_keepalive': False, 'keepalive_timeout': 25}
            app = web.Application(loop=loop,
                                  client_max_size=10 ** 10,
                                  handler_args=handler_args)
            app.router.add_route('post', '/train',
                                 lambda x: general_handler(x, RequestGenerator.train,
                                                           batch_size=self.args.batch_size))
            app.router.add_route('post', '/index',
                                 lambda x: general_handler(x, RequestGenerator.index,
                                                           batch_size=self.args.batch_size))
            app.router.add_route('post', '/query',
                                 lambda x: general_handler(x, RequestGenerator.query,
                                                           top_k=self.args.top_k))
            srv = await loop.create_server(app.make_handler(),
                                           self.args.http_host,
                                           self.args.http_port)
            self.logger.info('http server listens at %s:%d' % (self.args.http_host,
                                                               self.args.http_port))
            return srv

        def stub_call(req):
            # StreamCall yields a stream of responses; only the last message
            # carries the final result. Convert it to a plain dict for JSON.
            res_f = list(stub.StreamCall(req))[-1]
            return json.loads(MessageToJson(res_f))

        with grpc.insecure_channel(
                '%s:%s' % (self.args.grpc_host, self.args.grpc_port),
                options=[('grpc.max_send_message_length', self.args.max_message_size),
                         ('grpc.max_receive_message_length', self.args.max_message_size),
                         ('grpc.keepalive_timeout_ms', 100 * 1000)]) as channel:
            stub = gnes_pb2_grpc.GnesRPCStub(channel)
            loop.run_until_complete(init(loop))
            loop.run_forever()