add exception handler

This commit is contained in:
peng.xu 2019-09-19 19:41:20 +08:00
parent 5249b80b0d
commit 09d3e78449
5 changed files with 89 additions and 31 deletions

View File

@ -24,14 +24,14 @@ class Connection:
def __str__(self):
return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri)
def _connect(self):
def _connect(self, metadata=None):
try:
self.conn.connect(uri=self.uri)
except Exception as e:
if not self.error_handlers:
raise exceptions.ConnectionConnectError(e)
raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata)
for handler in self.error_handlers:
handler(e)
handler(e, metadata=metadata)
@property
def can_retry(self):
@ -47,14 +47,15 @@ class Connection:
else:
logger.warn('{} is retrying {}'.format(self, self.retried))
def on_connect(self):
def on_connect(self, metadata=None):
while not self.connected and self.can_retry:
self.retried += 1
self.on_retry()
self._connect()
self._connect(metadata=metadata)
if not self.can_retry and not self.connected:
raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry))
raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry,
metadata=metadata))
self.retried = 0
@ -81,14 +82,15 @@ class ConnectionMgr:
def conn_names(self):
return set(self.metas.keys()) - set(['WOSERVER'])
def conn(self, name, throw=False):
def conn(self, name, metadata, throw=False):
c = self.conns.get(name, None)
if not c:
url = self.metas.get(name, None)
if not url:
if not throw:
return None
raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name))
raise exceptions.ConnectionNotFoundError(message='Connection {} not found'.format(name),
metadata=metadata)
this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
threaded = {
threading.get_ident() : this_conn
@ -103,7 +105,8 @@ class ConnectionMgr:
if not url:
if not throw:
return None
raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name))
raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name),
metadata=metadata)
this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
c[tid] = this_conn
return this_conn

View File

@ -0,0 +1,35 @@
import logging
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from mishards import server, exceptions
logger = logging.getLogger(__name__)
def resp_handler(err, error_code):
if not isinstance(err, exceptions.BaseException):
return status_pb2.Status(error_code=error_code, reason=str(err))
status = status_pb2.Status(error_code=error_code, reason=err.message)
if err.metadata is None:
return status
resp_class = err.metadata.get('resp_class', None)
if not resp_class:
return status
if resp_class == milvus_pb2.BoolReply:
return resp_class(status=status, bool_reply=False)
if resp_class == milvus_pb2.VectorIds:
return resp_class(status=status, vector_id_array=[])
if resp_class == milvus_pb2.TopKQueryResultList:
return resp_class(status=status, topk_query_result=[])
status.error_code = status_pb2.UNEXPECTED_ERROR
return status
@server.error_handler(exceptions.TableNotFoundError)
def TableNotFoundErrorHandler(err):
logger.error(err)
return resp_handler(err, status_pb2.TABLE_NOT_EXISTS)

View File

@ -3,8 +3,9 @@ import mishards.exception_codes as codes
class BaseException(Exception):
code = codes.INVALID_CODE
message = 'BaseException'
def __init__(self, message=''):
def __init__(self, message='', metadata=None):
self.message = self.__class__.__name__ if not message else message
self.metadata = metadata
class ConnectionConnectError(BaseException):
code = codes.CONNECT_ERROR_CODE

View File

@ -2,6 +2,7 @@ import logging
import grpc
import time
import socket
import inspect
from urllib.parse import urlparse
from functools import wraps
from concurrent import futures
@ -16,6 +17,7 @@ logger = logging.getLogger(__name__)
class Server:
def __init__(self, conn_mgr, port=19530, max_workers=10, **kwargs):
self.pre_run_handlers = set()
self.error_handler = {}
self.exit_flag = False
self.port = int(port)
self.conn_mgr = conn_mgr
@ -40,6 +42,14 @@ class Server:
self.pre_run_handlers.add(func)
return func
def errorhandler(self, exception):
if inspect.isclass(exception) and issubclass(exception, Exception):
def wrapper(func):
self.error_handlers[exception] = func
return func
return wrapper
return exception
def on_pre_run(self):
for handler in self.pre_run_handlers:
handler()

View File

@ -25,18 +25,17 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
self.conn_mgr = conn_mgr
self.table_meta = {}
@property
def connection(self):
def connection(self, metadata=None):
conn = self.conn_mgr.conn('WOSERVER')
if conn:
conn.on_connect()
conn.on_connect(metadata=metadata)
return conn.conn
def query_conn(self, name):
conn = self.conn_mgr.conn(name)
def query_conn(self, name, metadata=None):
conn = self.conn_mgr.conn(name, metadata=metadata)
if not conn:
raise exceptions.ConnectionNotFoundError(name)
conn.on_connect()
raise exceptions.ConnectionNotFoundError(name, metadata=metadata)
conn.on_connect(metadata=metadata)
return conn.conn
def _format_date(self, start, end):
@ -55,14 +54,14 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return self._format_date(start, end)
def _get_routing_file_ids(self, table_id, range_array):
def _get_routing_file_ids(self, table_id, range_array, metadata=None):
table = db.Session.query(Tables).filter(and_(
Tables.table_id==table_id,
Tables.state!=Tables.TO_DELETE
)).first()
if not table:
raise exceptions.TableNotFoundError(table_id)
raise exceptions.TableNotFoundError(table_id, metadata=metadata)
files = table.files_to_search(range_array)
servers = self.conn_mgr.conn_names
@ -84,7 +83,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return routing
def _do_merge(self, files_n_topk_results, topk, reverse=False):
def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
if not files_n_topk_results:
return []
@ -111,9 +110,11 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
def _do_query(self, table_id, table_meta, vectors, topk, nprobe, range_array=None, **kwargs):
range_array = [self._range_to_date(r) for r in range_array] if range_array else None
routing = self._get_routing_file_ids(table_id, range_array)
routing = self._get_routing_file_ids(table_id, range_array, metadata=metadata)
logger.info('Routing: {}'.format(routing))
metadata = kwargs.get('metadata', None)
rs = []
all_topk_results = []
@ -124,7 +125,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
addr, query_params, len(vectors), topk, nprobe
))
conn = self.query_conn(addr)
conn = self.query_conn(addr, metadata=metadata)
start = time.time()
ret = conn.search_vectors_in_files(table_name=query_params['table_id'],
file_ids=query_params['file_ids'],
@ -146,7 +147,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
res.result()
reverse = table_meta.metric_type == types.MetricType.IP
return self._do_merge(all_topk_results, topk, reverse=reverse)
return self._do_merge(all_topk_results, topk, reverse=reverse, metadata=metadata)
def CreateTable(self, request, context):
_status, _table_schema = Parser.parse_proto_TableSchema(request)
@ -156,7 +157,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
logger.info('CreateTable {}'.format(_table_schema['table_name']))
_status = self.connection.create_table(_table_schema)
_status = self.connection().create_table(_table_schema)
return status_pb2.Status(error_code=_status.code, reason=_status.message)
@ -171,7 +172,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
logger.info('HasTable {}'.format(_table_name))
_bool = self.connection.has_table(_table_name)
_bool = self.connection(metadata={
'resp_class': milvus_pb2.BoolReply
}).has_table(_table_name)
return milvus_pb2.BoolReply(
status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="OK"),
@ -186,7 +189,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
logger.info('DropTable {}'.format(_table_name))
_status = self.connection.delete_table(_table_name)
_status = self.connection().delete_table(_table_name)
return status_pb2.Status(error_code=_status.code, reason=_status.message)
@ -201,14 +204,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
logger.info('CreateIndex {}'.format(_table_name))
# TODO: interface create_table incompleted
_status = self.connection.create_index(_table_name, _index)
_status = self.connection().create_index(_table_name, _index)
return status_pb2.Status(error_code=_status.code, reason=_status.message)
def Insert(self, request, context):
logger.info('Insert')
# TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array'
_status, _ids = self.connection.add_vectors(None, None, insert_param=request)
_status, _ids = self.connection(metadata={
'resp_class': milvus_pb2.VectorIds
}).add_vectors(None, None, insert_param=request)
return milvus_pb2.VectorIds(
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
vector_id_array=_ids
@ -227,10 +232,14 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
raise exceptions.GRPCInvlidArgument('Invalid nprobe: {}'.format(nprobe))
table_meta = self.table_meta.get(table_name, None)
metadata = {
'resp_class': milvus_pb2.TopKQueryResultList
}
if not table_meta:
status, info = self.connection.describe_table(table_name)
status, info = self.connection(metadata=metadata).describe_table(table_name)
if not status.OK():
raise exceptions.TableNotFoundError(table_name)
raise exceptions.TableNotFoundError(table_name, metadata=metadata)
self.table_meta[table_name] = info
table_meta = info
@ -248,7 +257,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
Range(query_range.start_value, query_range.end_value))
results = self._do_query(table_name, table_meta, query_record_array, topk,
nprobe, query_range_array)
nprobe, query_range_array, metadata=metadata)
now = time.time()
logger.info('SearchVector takes: {}'.format(now - start))