mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
optimize exception handlers
This commit is contained in:
parent
09d3e78449
commit
eb9174f2d9
@ -17,3 +17,5 @@ discover = ServiceFounder(namespace=settings.SD_NAMESPACE,
|
||||
|
||||
from mishards.server import Server
|
||||
grpc_server = Server(conn_mgr=connect_mgr)
|
||||
|
||||
from mishards import exception_handlers
|
||||
|
||||
@ -4,3 +4,4 @@ CONNECT_ERROR_CODE = 10001
|
||||
CONNECTTION_NOT_FOUND_CODE = 10002
|
||||
|
||||
TABLE_NOT_FOUND_CODE = 20001
|
||||
INVALID_ARGUMENT = 20002
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
|
||||
from mishards import server, exceptions
|
||||
from mishards import grpc_server as server, exceptions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -26,10 +26,18 @@ def resp_handler(err, error_code):
|
||||
if resp_class == milvus_pb2.TopKQueryResultList:
|
||||
return resp_class(status=status, topk_query_result=[])
|
||||
|
||||
if resp_class == milvus_pb2.TableRowCount:
|
||||
return resp_class(status=status, table_row_count=-1)
|
||||
|
||||
status.error_code = status_pb2.UNEXPECTED_ERROR
|
||||
return status
|
||||
|
||||
@server.error_handler(exceptions.TableNotFoundError)
|
||||
@server.errorhandler(exceptions.TableNotFoundError)
|
||||
def TableNotFoundErrorHandler(err):
|
||||
logger.error(err)
|
||||
return resp_handler(err, status_pb2.TABLE_NOT_EXISTS)
|
||||
|
||||
@server.errorhandler(exceptions.InvalidArgumentError)
|
||||
def InvalidArgumentErrorHandler(err):
|
||||
logger.error(err)
|
||||
return resp_handler(err, status_pb2.ILLEGAL_ARGUMENT)
|
||||
|
||||
@ -15,3 +15,6 @@ class ConnectionNotFoundError(BaseException):
|
||||
|
||||
class TableNotFoundError(BaseException):
|
||||
code = codes.TABLE_NOT_FOUND_CODE
|
||||
|
||||
class InvalidArgumentError(BaseException):
|
||||
code = codes.INVALID_ARGUMENT
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
def mark_grpc_method(func):
|
||||
setattr(func, 'grpc_method', True)
|
||||
return func
|
||||
@ -7,6 +7,7 @@ from urllib.parse import urlparse
|
||||
from functools import wraps
|
||||
from concurrent import futures
|
||||
from grpc._cython import cygrpc
|
||||
from grpc._channel import _Rendezvous, _UnaryUnaryMultiCallable
|
||||
from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server
|
||||
from mishards.service_handler import ServiceHandler
|
||||
from mishards import settings, discover
|
||||
@ -17,7 +18,8 @@ logger = logging.getLogger(__name__)
|
||||
class Server:
|
||||
def __init__(self, conn_mgr, port=19530, max_workers=10, **kwargs):
|
||||
self.pre_run_handlers = set()
|
||||
self.error_handler = {}
|
||||
self.grpc_methods = set()
|
||||
self.error_handlers = {}
|
||||
self.exit_flag = False
|
||||
self.port = int(port)
|
||||
self.conn_mgr = conn_mgr
|
||||
@ -42,6 +44,18 @@ class Server:
|
||||
self.pre_run_handlers.add(func)
|
||||
return func
|
||||
|
||||
def wrap_method_with_errorhandler(self, func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if e.__class__ in self.error_handlers:
|
||||
return self.error_handlers[e.__class__](e)
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
def errorhandler(self, exception):
|
||||
if inspect.isclass(exception) and issubclass(exception, Exception):
|
||||
def wrapper(func):
|
||||
@ -56,7 +70,8 @@ class Server:
|
||||
discover.start()
|
||||
|
||||
def start(self, port=None):
|
||||
add_MilvusServiceServicer_to_server(ServiceHandler(conn_mgr=self.conn_mgr), self.server_impl)
|
||||
handler_class = self.add_error_handlers(ServiceHandler)
|
||||
add_MilvusServiceServicer_to_server(handler_class(conn_mgr=self.conn_mgr), self.server_impl)
|
||||
self.server_impl.add_insecure_port("[::]:{}".format(str(port or self._port)))
|
||||
self.server_impl.start()
|
||||
|
||||
@ -80,3 +95,10 @@ class Server:
|
||||
self.exit_flag = True
|
||||
self.server_impl.stop(0)
|
||||
logger.info('Server is closed')
|
||||
|
||||
def add_error_handlers(self, target):
|
||||
for key, attr in target.__dict__.items():
|
||||
is_grpc_method = getattr(attr, 'grpc_method', False)
|
||||
if is_grpc_method:
|
||||
setattr(target, key, self.wrap_method_with_errorhandler(attr))
|
||||
return target
|
||||
|
||||
@ -12,6 +12,7 @@ from milvus.grpc_gen.milvus_pb2 import TopKQueryResult
|
||||
from milvus.client import types
|
||||
|
||||
from mishards import (db, settings, exceptions)
|
||||
from mishards.grpc_utils import mark_grpc_method
|
||||
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
|
||||
from mishards.models import Tables, TableFiles
|
||||
from mishards.hash_ring import HashRing
|
||||
@ -24,9 +25,10 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
def __init__(self, conn_mgr, *args, **kwargs):
|
||||
self.conn_mgr = conn_mgr
|
||||
self.table_meta = {}
|
||||
self.error_handlers = {}
|
||||
|
||||
def connection(self, metadata=None):
|
||||
conn = self.conn_mgr.conn('WOSERVER')
|
||||
conn = self.conn_mgr.conn('WOSERVER', metadata=metadata)
|
||||
if conn:
|
||||
conn.on_connect(metadata=metadata)
|
||||
return conn.conn
|
||||
@ -149,6 +151,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
reverse = table_meta.metric_type == types.MetricType.IP
|
||||
return self._do_merge(all_topk_results, topk, reverse=reverse, metadata=metadata)
|
||||
|
||||
@mark_grpc_method
|
||||
def CreateTable(self, request, context):
|
||||
_status, _table_schema = Parser.parse_proto_TableSchema(request)
|
||||
|
||||
@ -161,6 +164,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
@mark_grpc_method
|
||||
def HasTable(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
@ -181,6 +185,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
bool_reply=_bool
|
||||
)
|
||||
|
||||
@mark_grpc_method
|
||||
def DropTable(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
@ -193,6 +198,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
@mark_grpc_method
|
||||
def CreateIndex(self, request, context):
|
||||
_status, unpacks = Parser.parse_proto_IndexParam(request)
|
||||
|
||||
@ -208,6 +214,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
@mark_grpc_method
|
||||
def Insert(self, request, context):
|
||||
logger.info('Insert')
|
||||
# TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array'
|
||||
@ -219,6 +226,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
vector_id_array=_ids
|
||||
)
|
||||
|
||||
@mark_grpc_method
|
||||
def Search(self, request, context):
|
||||
|
||||
table_name = request.table_name
|
||||
@ -228,14 +236,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
|
||||
logger.info('Search {}: topk={} nprobe={}'.format(table_name, topk, nprobe))
|
||||
|
||||
if nprobe > self.MAX_NPROBE or nprobe <= 0:
|
||||
raise exceptions.GRPCInvlidArgument('Invalid nprobe: {}'.format(nprobe))
|
||||
|
||||
table_meta = self.table_meta.get(table_name, None)
|
||||
|
||||
metadata = {
|
||||
'resp_class': milvus_pb2.TopKQueryResultList
|
||||
}
|
||||
|
||||
if nprobe > self.MAX_NPROBE or nprobe <= 0:
|
||||
raise exceptions.InvalidArgumentError(message='Invalid nprobe: {}'.format(nprobe),
|
||||
metadata=metadata)
|
||||
|
||||
table_meta = self.table_meta.get(table_name, None)
|
||||
|
||||
if not table_meta:
|
||||
status, info = self.connection(metadata=metadata).describe_table(table_name)
|
||||
if not status.OK():
|
||||
@ -268,9 +278,11 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
)
|
||||
return topk_result_list
|
||||
|
||||
@mark_grpc_method
|
||||
def SearchInFiles(self, request, context):
|
||||
raise NotImplemented()
|
||||
|
||||
@mark_grpc_method
|
||||
def DescribeTable(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
@ -304,6 +316,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
)
|
||||
)
|
||||
|
||||
@mark_grpc_method
|
||||
def CountTable(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
@ -316,12 +329,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
|
||||
logger.info('CountTable {}'.format(_table_name))
|
||||
|
||||
_status, _count = self.connection.get_table_row_count(_table_name)
|
||||
metadata = {
|
||||
'resp_class': milvus_pb2.TableRowCount
|
||||
}
|
||||
_status, _count = self.connection(metadata=metadata).get_table_row_count(_table_name)
|
||||
|
||||
return milvus_pb2.TableRowCount(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
table_row_count=_count if isinstance(_count, int) else -1)
|
||||
|
||||
@mark_grpc_method
|
||||
def Cmd(self, request, context):
|
||||
_status, _cmd = Parser.parse_proto_Command(request)
|
||||
logger.info('Cmd: {}'.format(_cmd))
|
||||
@ -341,6 +358,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
string_reply=_reply
|
||||
)
|
||||
|
||||
@mark_grpc_method
|
||||
def ShowTables(self, request, context):
|
||||
logger.info('ShowTables')
|
||||
_status, _results = self.connection.show_tables()
|
||||
@ -354,6 +372,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
table_name=_result
|
||||
)
|
||||
|
||||
@mark_grpc_method
|
||||
def DeleteByRange(self, request, context):
|
||||
_status, unpacks = \
|
||||
Parser.parse_proto_DeleteByRangeParam(request)
|
||||
@ -367,6 +386,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
_status = self.connection.delete_vectors_by_range(_table_name, _start_date, _end_date)
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
@mark_grpc_method
|
||||
def PreloadTable(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
@ -377,6 +397,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
_status = self.connection.preload_table(_table_name)
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
@mark_grpc_method
|
||||
def DescribeIndex(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
@ -397,6 +418,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
|
||||
return milvus_pb2.IndexParam(table_name=_tablename, index=_index)
|
||||
|
||||
@mark_grpc_method
|
||||
def DropIndex(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user