optimize exception handlers

This commit is contained in:
peng.xu 2019-09-21 09:56:19 +08:00
parent 09d3e78449
commit eb9174f2d9
7 changed files with 72 additions and 11 deletions

View File

@ -17,3 +17,5 @@ discover = ServiceFounder(namespace=settings.SD_NAMESPACE,
from mishards.server import Server
grpc_server = Server(conn_mgr=connect_mgr)
from mishards import exception_handlers

View File

@ -4,3 +4,4 @@ CONNECT_ERROR_CODE = 10001
CONNECTTION_NOT_FOUND_CODE = 10002
TABLE_NOT_FOUND_CODE = 20001
INVALID_ARGUMENT = 20002

View File

@ -1,6 +1,6 @@
import logging
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from mishards import server, exceptions
from mishards import grpc_server as server, exceptions
logger = logging.getLogger(__name__)
@ -26,10 +26,18 @@ def resp_handler(err, error_code):
if resp_class == milvus_pb2.TopKQueryResultList:
return resp_class(status=status, topk_query_result=[])
if resp_class == milvus_pb2.TableRowCount:
return resp_class(status=status, table_row_count=-1)
status.error_code = status_pb2.UNEXPECTED_ERROR
return status
@server.error_handler(exceptions.TableNotFoundError)
@server.errorhandler(exceptions.TableNotFoundError)
def TableNotFoundErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.TABLE_NOT_EXISTS)
@server.errorhandler(exceptions.InvalidArgumentError)
def InvalidArgumentErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.ILLEGAL_ARGUMENT)

View File

@ -15,3 +15,6 @@ class ConnectionNotFoundError(BaseException):
class TableNotFoundError(BaseException):
code = codes.TABLE_NOT_FOUND_CODE
class InvalidArgumentError(BaseException):
code = codes.INVALID_ARGUMENT

View File

@ -0,0 +1,3 @@
def mark_grpc_method(func):
setattr(func, 'grpc_method', True)
return func

View File

@ -7,6 +7,7 @@ from urllib.parse import urlparse
from functools import wraps
from concurrent import futures
from grpc._cython import cygrpc
from grpc._channel import _Rendezvous, _UnaryUnaryMultiCallable
from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server
from mishards.service_handler import ServiceHandler
from mishards import settings, discover
@ -17,7 +18,8 @@ logger = logging.getLogger(__name__)
class Server:
def __init__(self, conn_mgr, port=19530, max_workers=10, **kwargs):
self.pre_run_handlers = set()
self.error_handler = {}
self.grpc_methods = set()
self.error_handlers = {}
self.exit_flag = False
self.port = int(port)
self.conn_mgr = conn_mgr
@ -42,6 +44,18 @@ class Server:
self.pre_run_handlers.add(func)
return func
def wrap_method_with_errorhandler(self, func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
if e.__class__ in self.error_handlers:
return self.error_handlers[e.__class__](e)
raise
return wrapper
def errorhandler(self, exception):
if inspect.isclass(exception) and issubclass(exception, Exception):
def wrapper(func):
@ -56,7 +70,8 @@ class Server:
discover.start()
def start(self, port=None):
add_MilvusServiceServicer_to_server(ServiceHandler(conn_mgr=self.conn_mgr), self.server_impl)
handler_class = self.add_error_handlers(ServiceHandler)
add_MilvusServiceServicer_to_server(handler_class(conn_mgr=self.conn_mgr), self.server_impl)
self.server_impl.add_insecure_port("[::]:{}".format(str(port or self._port)))
self.server_impl.start()
@ -80,3 +95,10 @@ class Server:
self.exit_flag = True
self.server_impl.stop(0)
logger.info('Server is closed')
def add_error_handlers(self, target):
for key, attr in target.__dict__.items():
is_grpc_method = getattr(attr, 'grpc_method', False)
if is_grpc_method:
setattr(target, key, self.wrap_method_with_errorhandler(attr))
return target

View File

@ -12,6 +12,7 @@ from milvus.grpc_gen.milvus_pb2 import TopKQueryResult
from milvus.client import types
from mishards import (db, settings, exceptions)
from mishards.grpc_utils import mark_grpc_method
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
from mishards.models import Tables, TableFiles
from mishards.hash_ring import HashRing
@ -24,9 +25,10 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
def __init__(self, conn_mgr, *args, **kwargs):
self.conn_mgr = conn_mgr
self.table_meta = {}
self.error_handlers = {}
def connection(self, metadata=None):
conn = self.conn_mgr.conn('WOSERVER')
conn = self.conn_mgr.conn('WOSERVER', metadata=metadata)
if conn:
conn.on_connect(metadata=metadata)
return conn.conn
@ -149,6 +151,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
reverse = table_meta.metric_type == types.MetricType.IP
return self._do_merge(all_topk_results, topk, reverse=reverse, metadata=metadata)
@mark_grpc_method
def CreateTable(self, request, context):
_status, _table_schema = Parser.parse_proto_TableSchema(request)
@ -161,6 +164,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return status_pb2.Status(error_code=_status.code, reason=_status.message)
@mark_grpc_method
def HasTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
@ -181,6 +185,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
bool_reply=_bool
)
@mark_grpc_method
def DropTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
@ -193,6 +198,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return status_pb2.Status(error_code=_status.code, reason=_status.message)
@mark_grpc_method
def CreateIndex(self, request, context):
_status, unpacks = Parser.parse_proto_IndexParam(request)
@ -208,6 +214,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return status_pb2.Status(error_code=_status.code, reason=_status.message)
@mark_grpc_method
def Insert(self, request, context):
logger.info('Insert')
# TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array'
@ -219,6 +226,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
vector_id_array=_ids
)
@mark_grpc_method
def Search(self, request, context):
table_name = request.table_name
@ -228,14 +236,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
logger.info('Search {}: topk={} nprobe={}'.format(table_name, topk, nprobe))
if nprobe > self.MAX_NPROBE or nprobe <= 0:
raise exceptions.GRPCInvlidArgument('Invalid nprobe: {}'.format(nprobe))
table_meta = self.table_meta.get(table_name, None)
metadata = {
'resp_class': milvus_pb2.TopKQueryResultList
}
if nprobe > self.MAX_NPROBE or nprobe <= 0:
raise exceptions.InvalidArgumentError(message='Invalid nprobe: {}'.format(nprobe),
metadata=metadata)
table_meta = self.table_meta.get(table_name, None)
if not table_meta:
status, info = self.connection(metadata=metadata).describe_table(table_name)
if not status.OK():
@ -268,9 +278,11 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
)
return topk_result_list
@mark_grpc_method
def SearchInFiles(self, request, context):
raise NotImplemented()
@mark_grpc_method
def DescribeTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
@ -304,6 +316,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
)
)
@mark_grpc_method
def CountTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
@ -316,12 +329,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
logger.info('CountTable {}'.format(_table_name))
_status, _count = self.connection.get_table_row_count(_table_name)
metadata = {
'resp_class': milvus_pb2.TableRowCount
}
_status, _count = self.connection(metadata=metadata).get_table_row_count(_table_name)
return milvus_pb2.TableRowCount(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
table_row_count=_count if isinstance(_count, int) else -1)
@mark_grpc_method
def Cmd(self, request, context):
_status, _cmd = Parser.parse_proto_Command(request)
logger.info('Cmd: {}'.format(_cmd))
@ -341,6 +358,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
string_reply=_reply
)
@mark_grpc_method
def ShowTables(self, request, context):
logger.info('ShowTables')
_status, _results = self.connection.show_tables()
@ -354,6 +372,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
table_name=_result
)
@mark_grpc_method
def DeleteByRange(self, request, context):
_status, unpacks = \
Parser.parse_proto_DeleteByRangeParam(request)
@ -367,6 +386,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
_status = self.connection.delete_vectors_by_range(_table_name, _start_date, _end_date)
return status_pb2.Status(error_code=_status.code, reason=_status.message)
@mark_grpc_method
def PreloadTable(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
@ -377,6 +397,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
_status = self.connection.preload_table(_table_name)
return status_pb2.Status(error_code=_status.code, reason=_status.message)
@mark_grpc_method
def DescribeIndex(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)
@ -397,6 +418,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return milvus_pb2.IndexParam(table_name=_tablename, index=_index)
@mark_grpc_method
def DropIndex(self, request, context):
_status, _table_name = Parser.parse_proto_TableName(request)