add grpc server

This commit is contained in:
peng.xu 2019-09-17 20:48:08 +08:00
parent 052d79a58d
commit 4fc6f0a520
15 changed files with 502 additions and 13 deletions

View File

@ -1 +0,0 @@
import settings

6
mishards/__init__.py Normal file
View File

@ -0,0 +1,6 @@
import settings
from connections import ConnectionMgr
connect_mgr = ConnectionMgr()
from server import Server
grpc_server = Server(conn_mgr=connect_mgr)

View File

@ -89,7 +89,7 @@ class ConnectionMgr:
threaded = {
threading.get_ident() : this_conn
}
c[name] = threaded
self.conns[name] = threaded
return this_conn
tid = threading.get_ident()

View File

View File

@ -0,0 +1,101 @@
from milvus import Status
from functools import wraps
def error_status(func):
@wraps(func)
def inner(*args, **kwargs):
try:
results = func(*args, **kwargs)
except Exception as e:
return Status(code=Status.UNEXPECTED_ERROR, message=str(e)), None
return Status(code=0, message="Success"), results
return inner
class GrpcArgsParser(object):
@classmethod
@error_status
def parse_proto_TableSchema(cls, param):
_table_schema = {
'table_name': param.table_name.table_name,
'dimension': param.dimension,
'index_file_size': param.index_file_size,
'metric_type': param.metric_type
}
return _table_schema
@classmethod
@error_status
def parse_proto_TableName(cls, param):
return param.table_name
@classmethod
@error_status
def parse_proto_Index(cls, param):
_index = {
'index_type': param.index_type,
'nlist': param.nlist
}
return _index
@classmethod
@error_status
def parse_proto_IndexParam(cls, param):
_table_name = param.table_name.table_name
_status, _index = cls.parse_proto_Index(param.index)
if not _status.OK():
raise Exception("Argument parse error")
return _table_name, _index
@classmethod
@error_status
def parse_proto_Command(cls, param):
_cmd = param.cmd
return _cmd
@classmethod
@error_status
def parse_proto_Range(cls, param):
_start_value = param.start_value
_end_value = param.end_value
return _start_value, _end_value
@classmethod
@error_status
def parse_proto_RowRecord(cls, param):
return list(param.vector_data)
@classmethod
@error_status
def parse_proto_SearchParam(cls, param):
_table_name = param.table_name
_topk = param.topk
_nprobe = param.nprobe
_status, _range = cls.parse_proto_Range(param.query_range_array)
if not _status.OK():
raise Exception("Argument parse error")
_row_record = param.query_record_array
return _table_name, _row_record, _range, _topk
@classmethod
@error_status
def parse_proto_DeleteByRangeParam(cls, param):
_table_name = param.table_name
_range = param.range
_start_value = _range.start_value
_end_value = _range.end_value
return _table_name, _start_value, _end_value

View File

@ -0,0 +1,4 @@
# class GrpcArgsWrapper(object):
# @classmethod
# def proto_TableName(cls):

14
mishards/main.py Normal file
View File

@ -0,0 +1,14 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import settings
from mishards import connect_mgr, grpc_server as server
def main():
connect_mgr.register('WOSERVER', settings.WOSERVER)
server.run(port=settings.SERVER_PORT)
return 0
if __name__ == '__main__':
sys.exit(main())

47
mishards/server.py Normal file
View File

@ -0,0 +1,47 @@
import logging
import grpc
import time
from concurrent import futures
from grpc._cython import cygrpc
from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server
from service_handler import ServiceHandler
import settings
logger = logging.getLogger(__name__)
class Server:
def __init__(self, conn_mgr, port=19530, max_workers=10, **kwargs):
self.exit_flag = False
self.port = int(port)
self.conn_mgr = conn_mgr
self.server_impl = grpc.server(
thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
(cygrpc.ChannelArgKey.max_receive_message_length, -1)]
)
def start(self, port=None):
add_MilvusServiceServicer_to_server(ServiceHandler(conn_mgr=self.conn_mgr), self.server_impl)
self.server_impl.add_insecure_port("[::]:{}".format(str(port or self._port)))
self.server_impl.start()
def run(self, port):
logger.info('Milvus server start ......')
port = port or self.port
self.start(port)
logger.info('Successfully')
logger.info('Listening on port {}'.format(port))
try:
while not self.exit_flag:
time.sleep(5)
except KeyboardInterrupt:
self.stop()
def stop(self):
logger.info('Server is shuting down ......')
self.exit_flag = True
self.server.stop(0)
logger.info('Server is closed')

327
mishards/service_handler.py Normal file
View File

@ -0,0 +1,327 @@
import logging
from contextlib import contextmanager
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
logger = logging.getLogger(__name__)
class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
def __init__(self, conn_mgr, *args, **kwargs):
self.conn_mgr = conn_mgr
self.table_meta = {}
@property
def connection(self):
conn = self.conn_mgr.conn('WOSERVER')
if conn:
conn.on_connect()
return conn.conn
def CreateTable(self, request, context):
_status, _table_schema = Parser.parse_proto_TableSchema(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code, reason=_status.message)
logger.info('CreateTable {}'.format(_table_schema['table_name']))
_status = self.connection.create_table(_table_schema)
return status_pb2.Status(error_code=_status.code, reason=_status.message)
def HasTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return milvus_pb2.BoolReply(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
bool_reply=False
)
logger.info('HasTable {}'.format(_table_name))
_bool = self.connection.has_table(_table_name)
return milvus_pb2.BoolReply(
status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="OK"),
bool_reply=_bool
)
def DropTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code, reason=_status.message)
logger.info('DropTable {}'.format(_table_name))
_status = self.connection.delete_table(_table_name)
return status_pb2.Status(error_code=_status.code, reason=_status.message)
def CreateIndex(self, request, context):
_status, unpacks = Parser.parse_proto_IndexParam(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code, reason=_status.message)
_table_name, _index = unpacks
logger.info('CreateIndex {}'.format(_table_name))
# TODO: interface create_table incompleted
_status = self.connection.create_index(_table_name, _index)
return status_pb2.Status(error_code=_status.code, reason=_status.message)
def Insert(self, request, context):
logger.info('Insert')
# TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array'
_status, _ids = self.connection.add_vectors(None, None, insert_param=request)
return milvus_pb2.VectorIds(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
vector_id_array=_ids
)
def Search(self, request, context):
try:
table_name = request.table_name
topk = request.topk
nprobe = request.nprobe
logger.info('Search {}: topk={} nprobe={}'.format(table_name, topk, nprobe))
if nprobe > 2048 or nprobe <= 0:
raise exceptions.GRPCInvlidArgument('Invalid nprobe: {}'.format(nprobe))
table_meta = self.table_meta.get(table_name, None)
if not table_meta:
status, info = self.connection.describe_table(table_name)
if not status.OK():
raise TableNotFoundException(table_name)
self.table_meta[table_name] = info
table_meta = info
start = time.time()
query_record_array = []
for query_record in request.query_record_array:
query_record_array.append(list(query_record.vector_data))
query_range_array = []
for query_range in request.query_range_array:
query_range_array.append(
Range(query_range.start_value, query_range.end_value))
except (TableNotFoundException, exceptions.GRPCInvlidArgument) as exc:
return milvus_pb2.TopKQueryResultList(
status=status_pb2.Status(error_code=exc.code, reason=exc.message)
)
except Exception as e:
return milvus_pb2.TopKQueryResultList(
status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=str(e))
)
try:
results = workflow.query_vectors(table_name, table_meta, query_record_array, topk,
nprobe, query_range_array)
except (exceptions.GRPCQueryInvalidRangeException, TableNotFoundException) as exc:
return milvus_pb2.TopKQueryResultList(
status=status_pb2.Status(error_code=exc.code, reason=exc.message)
)
except exceptions.ServiceNotFoundException as exc:
return milvus_pb2.TopKQueryResultList(
status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=exc.message)
)
except Exception as e:
logger.error(e)
results = workflow.query_vectors(table_name, table_meta, query_record_array,
topk, nprobe, query_range_array)
now = time.time()
logger.info('SearchVector Ends @{}'.format(now))
logger.info('SearchVector takes: {}'.format(now - start))
topk_result_list = milvus_pb2.TopKQueryResultList(
status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="Success"),
topk_query_result=results
)
return topk_result_list
def SearchInFiles(self, request, context):
try:
file_id_array = list(request.file_id_array)
search_param = request.search_param
table_name = search_param.table_name
topk = search_param.topk
nprobe = search_param.nprobe
query_record_array = []
for query_record in search_param.query_record_array:
query_record_array.append(list(query_record))
query_range_array = []
for query_range in search_param.query_range_array:
query_range_array.append("")
except Exception as e:
milvus_pb2.TopKQueryResultList(
status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=str(e)),
)
res = search_vector_in_files.delay(table_name=table_name,
file_id_array=file_id_array,
query_record_array=query_record_array,
query_range_array=query_range_array,
topk=topk,
nprobe=nprobe)
status, result = res.get(timeout=1)
if not status.OK():
raise ThriftException(code=status.code, reason=status.message)
res = TopKQueryResult()
for top_k_query_results in result:
res.query_result_arrays.append([QueryResult(id=qr.id, distance=qr.distance)
for qr in top_k_query_results])
return res
def DescribeTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
table_name = milvus_pb2.TableName(
status=status_pb2.Status(error_code=_status.code, reason=_status.message)
)
return milvus_pb2.TableSchema(
table_name=table_name
)
logger.info('DescribeTable {}'.format(_table_name))
_status, _table = self.connection.describe_table(_table_name)
if _status.OK():
_grpc_table_name = milvus_pb2.TableName(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
table_name=_table.table_name
)
return milvus_pb2.TableSchema(
table_name=_grpc_table_name,
index_file_size=_table.index_file_size,
dimension=_table.dimension,
metric_type=_table.metric_type
)
return milvus_pb2.TableSchema(
table_name=milvus_pb2.TableName(
status=status_pb2.Status(error_code=_status.code, reason=_status.message)
)
)
def CountTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
status = status_pb2.Status(error_code=_status.code, reason=_status.message)
return milvus_pb2.TableRowCount(
status=status
)
logger.info('CountTable {}'.format(_table_name))
_status, _count = self.connection.get_table_row_count(_table_name)
return milvus_pb2.TableRowCount(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
table_row_count=_count if isinstance(_count, int) else -1)
def Cmd(self, request, context):
_status, _cmd = Parser.parse_proto_Command(request)
logger.info('Cmd: {}'.format(_cmd))
if not _status.OK():
return milvus_pb2.StringReply(
status_pb2.Status(error_code=_status.code, reason=_status.message)
)
if _cmd == 'version':
_status, _reply = self.connection.server_version()
else:
_status, _reply = self.connection.server_status()
return milvus_pb2.StringReply(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
string_reply=_reply
)
def ShowTables(self, request, context):
logger.info('ShowTables')
_status, _results = self.connection.show_tables()
if not _status.OK():
_results = []
for _result in _results:
yield milvus_pb2.TableName(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
table_name=_result
)
def DeleteByRange(self, request, context):
_status, unpacks = \
Parser.parse_proto_DeleteByRangeParam(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code, reason=_status.message)
_table_name, _start_date, _end_date = unpacks
logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date, _end_date))
_status = self.connection.delete_vectors_by_range(_table_name, _start_date, _end_date)
return status_pb2.Status(error_code=_status.code, reason=_status.message)
def PreloadTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code, reason=_status.message)
logger.info('PreloadTable {}'.format(_table_name))
_status = self.connection.preload_table(_table_name)
return status_pb2.Status(error_code=_status.code, reason=_status.message)
def DescribeIndex(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return milvus_pb2.IndexParam(
table_name=milvus_pb2.TableName(
status=status_pb2.Status(error_code=_status.code, reason=_status.message)
)
)
logger.info('DescribeIndex {}'.format(_table_name))
_status, _index_param = self.connection.describe_index(_table_name)
_index = milvus_pb2.Index(index_type=_index_param._index_type, nlist=_index_param._nlist)
_tablename = milvus_pb2.TableName(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
table_name=_table_name)
return milvus_pb2.IndexParam(table_name=_tablename, index=_index)
def DropIndex(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
if not _status.OK():
return status_pb2.Status(error_code=_status.code, reason=_status.message)
logger.info('DropIndex {}'.format(_table_name))
_status = self.connection.drop_index(_table_name)
return status_pb2.Status(error_code=_status.code, reason=_status.message)

View File

@ -22,6 +22,8 @@ config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE)
TIMEOUT = env.int('TIMEOUT', 60)
MAX_RETRY = env.int('MAX_RETRY', 3)
SERVER_PORT = env.int('SERVER_PORT', 19530)
WOSERVER = env.str('WOSERVER')
if __name__ == '__main__':
import logging

View File

@ -1,11 +0,0 @@
import logging
import grpco
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
logger = logging.getLogger(__name__)
class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
def __init__(self, connections, *args, **kwargs):
self.connections = self.connections