From eb9174f2d91355c218c4e256a7361d68e776b79e Mon Sep 17 00:00:00 2001 From: "peng.xu" Date: Sat, 21 Sep 2019 09:56:19 +0800 Subject: [PATCH] optimize exception handlers --- mishards/__init__.py | 2 ++ mishards/exception_codes.py | 1 + mishards/exception_handlers.py | 12 +++++++++-- mishards/exceptions.py | 3 +++ mishards/grpc_utils/__init__.py | 3 +++ mishards/server.py | 26 ++++++++++++++++++++++-- mishards/service_handler.py | 36 ++++++++++++++++++++++++++------- 7 files changed, 72 insertions(+), 11 deletions(-) diff --git a/mishards/__init__.py b/mishards/__init__.py index a792cd5ce9..8105e7edc8 100644 --- a/mishards/__init__.py +++ b/mishards/__init__.py @@ -17,3 +17,5 @@ discover = ServiceFounder(namespace=settings.SD_NAMESPACE, from mishards.server import Server grpc_server = Server(conn_mgr=connect_mgr) + +from mishards import exception_handlers diff --git a/mishards/exception_codes.py b/mishards/exception_codes.py index 32b29bdfab..37492f25d4 100644 --- a/mishards/exception_codes.py +++ b/mishards/exception_codes.py @@ -4,3 +4,4 @@ CONNECT_ERROR_CODE = 10001 CONNECTTION_NOT_FOUND_CODE = 10002 TABLE_NOT_FOUND_CODE = 20001 +INVALID_ARGUMENT = 20002 diff --git a/mishards/exception_handlers.py b/mishards/exception_handlers.py index 3de0918be4..6207f2088c 100644 --- a/mishards/exception_handlers.py +++ b/mishards/exception_handlers.py @@ -1,6 +1,6 @@ import logging from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2 -from mishards import server, exceptions +from mishards import grpc_server as server, exceptions logger = logging.getLogger(__name__) @@ -26,10 +26,18 @@ def resp_handler(err, error_code): if resp_class == milvus_pb2.TopKQueryResultList: return resp_class(status=status, topk_query_result=[]) + if resp_class == milvus_pb2.TableRowCount: + return resp_class(status=status, table_row_count=-1) + status.error_code = status_pb2.UNEXPECTED_ERROR return status -@server.error_handler(exceptions.TableNotFoundError) +@server.errorhandler(exceptions.TableNotFoundError) def TableNotFoundErrorHandler(err): logger.error(err) return resp_handler(err, status_pb2.TABLE_NOT_EXISTS) + +@server.errorhandler(exceptions.InvalidArgumentError) +def InvalidArgumentErrorHandler(err): + logger.error(err) + return resp_handler(err, status_pb2.ILLEGAL_ARGUMENT) diff --git a/mishards/exceptions.py b/mishards/exceptions.py index 1579fefcf4..4686cf674f 100644 --- a/mishards/exceptions.py +++ b/mishards/exceptions.py @@ -15,3 +15,6 @@ class ConnectionNotFoundError(BaseException): class TableNotFoundError(BaseException): code = codes.TABLE_NOT_FOUND_CODE + +class InvalidArgumentError(BaseException): + code = codes.INVALID_ARGUMENT diff --git a/mishards/grpc_utils/__init__.py b/mishards/grpc_utils/__init__.py index e69de29bb2..959d5549c7 100644 --- a/mishards/grpc_utils/__init__.py +++ b/mishards/grpc_utils/__init__.py @@ -0,0 +1,3 @@ +def mark_grpc_method(func): + setattr(func, 'grpc_method', True) + return func diff --git a/mishards/server.py b/mishards/server.py index b000016e29..9cca096b6b 100644 --- a/mishards/server.py +++ b/mishards/server.py @@ -7,6 +7,7 @@ from urllib.parse import urlparse from functools import wraps from concurrent import futures from grpc._cython import cygrpc +from grpc._channel import _Rendezvous, _UnaryUnaryMultiCallable from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server from mishards.service_handler import ServiceHandler from mishards import settings, discover @@ -17,7 +18,8 @@ logger = logging.getLogger(__name__) class Server: def __init__(self, conn_mgr, port=19530, max_workers=10, **kwargs): self.pre_run_handlers = set() - self.error_handler = {} + self.grpc_methods = set() + self.error_handlers = {} self.exit_flag = False self.port = int(port) self.conn_mgr = conn_mgr @@ -42,6 +44,18 @@ class Server: self.pre_run_handlers.add(func) return func + def wrap_method_with_errorhandler(self, func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + if e.__class__ in self.error_handlers: + return self.error_handlers[e.__class__](e) + raise + + return wrapper + def errorhandler(self, exception): if inspect.isclass(exception) and issubclass(exception, Exception): def wrapper(func): @@ -56,7 +70,8 @@ class Server: discover.start() def start(self, port=None): - add_MilvusServiceServicer_to_server(ServiceHandler(conn_mgr=self.conn_mgr), self.server_impl) + handler_class = self.add_error_handlers(ServiceHandler) + add_MilvusServiceServicer_to_server(handler_class(conn_mgr=self.conn_mgr), self.server_impl) self.server_impl.add_insecure_port("[::]:{}".format(str(port or self._port))) self.server_impl.start() @@ -80,3 +95,10 @@ class Server: self.exit_flag = True self.server_impl.stop(0) logger.info('Server is closed') + + def add_error_handlers(self, target): + for key, attr in target.__dict__.items(): + is_grpc_method = getattr(attr, 'grpc_method', False) + if is_grpc_method: + setattr(target, key, self.wrap_method_with_errorhandler(attr)) + return target diff --git a/mishards/service_handler.py b/mishards/service_handler.py index 5346be91d8..acc04c5eee 100644 --- a/mishards/service_handler.py +++ b/mishards/service_handler.py @@ -12,6 +12,7 @@ from milvus.grpc_gen.milvus_pb2 import TopKQueryResult from milvus.client import types from mishards import (db, settings, exceptions) +from mishards.grpc_utils import mark_grpc_method from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser from mishards.models import Tables, TableFiles from mishards.hash_ring import HashRing @@ -24,9 +25,10 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): def __init__(self, conn_mgr, *args, **kwargs): self.conn_mgr = conn_mgr self.table_meta = {} + self.error_handlers = {} def connection(self, metadata=None): - conn = self.conn_mgr.conn('WOSERVER') + conn = self.conn_mgr.conn('WOSERVER', metadata=metadata) if conn: conn.on_connect(metadata=metadata) return conn.conn @@ -149,6 +151,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): reverse = table_meta.metric_type == types.MetricType.IP return self._do_merge(all_topk_results, topk, reverse=reverse, metadata=metadata) + @mark_grpc_method def CreateTable(self, request, context): _status, _table_schema = Parser.parse_proto_TableSchema(request) @@ -161,6 +164,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return status_pb2.Status(error_code=_status.code, reason=_status.message) + @mark_grpc_method def HasTable(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request) @@ -181,6 +185,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): bool_reply=_bool ) + @mark_grpc_method def DropTable(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request) @@ -193,6 +198,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return status_pb2.Status(error_code=_status.code, reason=_status.message) + @mark_grpc_method def CreateIndex(self, request, context): _status, unpacks = Parser.parse_proto_IndexParam(request) @@ -208,6 +214,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return status_pb2.Status(error_code=_status.code, reason=_status.message) + @mark_grpc_method def Insert(self, request, context): logger.info('Insert') # TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array' @@ -219,6 +226,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): vector_id_array=_ids ) + @mark_grpc_method def Search(self, request, context): table_name = request.table_name @@ -228,14 +236,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): logger.info('Search {}: topk={} nprobe={}'.format(table_name, topk, nprobe)) - if nprobe > self.MAX_NPROBE or nprobe <= 0: - raise exceptions.GRPCInvlidArgument('Invalid nprobe: {}'.format(nprobe)) - - table_meta = self.table_meta.get(table_name, None) - metadata = { 'resp_class': milvus_pb2.TopKQueryResultList } + + if nprobe > self.MAX_NPROBE or nprobe <= 0: + raise exceptions.InvalidArgumentError(message='Invalid nprobe: {}'.format(nprobe), + metadata=metadata) + + table_meta = self.table_meta.get(table_name, None) + if not table_meta: status, info = self.connection(metadata=metadata).describe_table(table_name) if not status.OK(): @@ -268,9 +278,11 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): ) return topk_result_list + @mark_grpc_method def SearchInFiles(self, request, context): raise NotImplemented() + @mark_grpc_method def DescribeTable(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request) @@ -304,6 +316,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): ) ) + @mark_grpc_method def CountTable(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request) @@ -316,12 +329,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): logger.info('CountTable {}'.format(_table_name)) - _status, _count = self.connection.get_table_row_count(_table_name) + metadata = { + 'resp_class': milvus_pb2.TableRowCount + } + _status, _count = self.connection(metadata=metadata).get_table_row_count(_table_name) return milvus_pb2.TableRowCount( status=status_pb2.Status(error_code=_status.code, reason=_status.message), table_row_count=_count if isinstance(_count, int) else -1) + @mark_grpc_method def Cmd(self, request, context): _status, _cmd = Parser.parse_proto_Command(request) logger.info('Cmd: {}'.format(_cmd)) @@ -341,6 +358,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): string_reply=_reply ) + @mark_grpc_method def ShowTables(self, request, context): logger.info('ShowTables') _status, _results = self.connection.show_tables() @@ -354,6 +372,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): table_name=_result ) + @mark_grpc_method def DeleteByRange(self, request, context): _status, unpacks = \ Parser.parse_proto_DeleteByRangeParam(request) @@ -367,6 +386,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): _status = self.connection.delete_vectors_by_range(_table_name, _start_date, _end_date) return status_pb2.Status(error_code=_status.code, reason=_status.message) + @mark_grpc_method def PreloadTable(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request) @@ -377,6 +397,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): _status = self.connection.preload_table(_table_name) return status_pb2.Status(error_code=_status.code, reason=_status.message) + @mark_grpc_method def DescribeIndex(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request) @@ -397,6 +418,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return milvus_pb2.IndexParam(table_name=_tablename, index=_index) + @mark_grpc_method def DropIndex(self, request, context): _status, _table_name = Parser.parse_proto_TableName(request)