mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
add exception handler
This commit is contained in:
parent
5249b80b0d
commit
09d3e78449
@ -24,14 +24,14 @@ class Connection:
|
||||
def __str__(self):
|
||||
return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri)
|
||||
|
||||
def _connect(self):
|
||||
def _connect(self, metadata=None):
|
||||
try:
|
||||
self.conn.connect(uri=self.uri)
|
||||
except Exception as e:
|
||||
if not self.error_handlers:
|
||||
raise exceptions.ConnectionConnectError(e)
|
||||
raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata)
|
||||
for handler in self.error_handlers:
|
||||
handler(e)
|
||||
handler(e, metadata=metadata)
|
||||
|
||||
@property
|
||||
def can_retry(self):
|
||||
@ -47,14 +47,15 @@ class Connection:
|
||||
else:
|
||||
logger.warn('{} is retrying {}'.format(self, self.retried))
|
||||
|
||||
def on_connect(self):
|
||||
def on_connect(self, metadata=None):
|
||||
while not self.connected and self.can_retry:
|
||||
self.retried += 1
|
||||
self.on_retry()
|
||||
self._connect()
|
||||
self._connect(metadata=metadata)
|
||||
|
||||
if not self.can_retry and not self.connected:
|
||||
raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry))
|
||||
raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry,
|
||||
metadata=metadata))
|
||||
|
||||
self.retried = 0
|
||||
|
||||
@ -81,14 +82,15 @@ class ConnectionMgr:
|
||||
def conn_names(self):
|
||||
return set(self.metas.keys()) - set(['WOSERVER'])
|
||||
|
||||
def conn(self, name, throw=False):
|
||||
def conn(self, name, metadata, throw=False):
|
||||
c = self.conns.get(name, None)
|
||||
if not c:
|
||||
url = self.metas.get(name, None)
|
||||
if not url:
|
||||
if not throw:
|
||||
return None
|
||||
raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name))
|
||||
raise exceptions.ConnectionNotFoundError(message='Connection {} not found'.format(name),
|
||||
metadata=metadata)
|
||||
this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
|
||||
threaded = {
|
||||
threading.get_ident() : this_conn
|
||||
@ -103,7 +105,8 @@ class ConnectionMgr:
|
||||
if not url:
|
||||
if not throw:
|
||||
return None
|
||||
raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name))
|
||||
raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name),
|
||||
metadata=metadata)
|
||||
this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY)
|
||||
c[tid] = this_conn
|
||||
return this_conn
|
||||
|
||||
35
mishards/exception_handlers.py
Normal file
35
mishards/exception_handlers.py
Normal file
@ -0,0 +1,35 @@
|
||||
import logging
|
||||
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
|
||||
from mishards import server, exceptions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def resp_handler(err, error_code):
|
||||
if not isinstance(err, exceptions.BaseException):
|
||||
return status_pb2.Status(error_code=error_code, reason=str(err))
|
||||
|
||||
status = status_pb2.Status(error_code=error_code, reason=err.message)
|
||||
|
||||
if err.metadata is None:
|
||||
return status
|
||||
|
||||
resp_class = err.metadata.get('resp_class', None)
|
||||
if not resp_class:
|
||||
return status
|
||||
|
||||
if resp_class == milvus_pb2.BoolReply:
|
||||
return resp_class(status=status, bool_reply=False)
|
||||
|
||||
if resp_class == milvus_pb2.VectorIds:
|
||||
return resp_class(status=status, vector_id_array=[])
|
||||
|
||||
if resp_class == milvus_pb2.TopKQueryResultList:
|
||||
return resp_class(status=status, topk_query_result=[])
|
||||
|
||||
status.error_code = status_pb2.UNEXPECTED_ERROR
|
||||
return status
|
||||
|
||||
@server.error_handler(exceptions.TableNotFoundError)
|
||||
def TableNotFoundErrorHandler(err):
|
||||
logger.error(err)
|
||||
return resp_handler(err, status_pb2.TABLE_NOT_EXISTS)
|
||||
@ -3,8 +3,9 @@ import mishards.exception_codes as codes
|
||||
class BaseException(Exception):
|
||||
code = codes.INVALID_CODE
|
||||
message = 'BaseException'
|
||||
def __init__(self, message=''):
|
||||
def __init__(self, message='', metadata=None):
|
||||
self.message = self.__class__.__name__ if not message else message
|
||||
self.metadata = metadata
|
||||
|
||||
class ConnectionConnectError(BaseException):
|
||||
code = codes.CONNECT_ERROR_CODE
|
||||
|
||||
@ -2,6 +2,7 @@ import logging
|
||||
import grpc
|
||||
import time
|
||||
import socket
|
||||
import inspect
|
||||
from urllib.parse import urlparse
|
||||
from functools import wraps
|
||||
from concurrent import futures
|
||||
@ -16,6 +17,7 @@ logger = logging.getLogger(__name__)
|
||||
class Server:
|
||||
def __init__(self, conn_mgr, port=19530, max_workers=10, **kwargs):
|
||||
self.pre_run_handlers = set()
|
||||
self.error_handler = {}
|
||||
self.exit_flag = False
|
||||
self.port = int(port)
|
||||
self.conn_mgr = conn_mgr
|
||||
@ -40,6 +42,14 @@ class Server:
|
||||
self.pre_run_handlers.add(func)
|
||||
return func
|
||||
|
||||
def errorhandler(self, exception):
|
||||
if inspect.isclass(exception) and issubclass(exception, Exception):
|
||||
def wrapper(func):
|
||||
self.error_handlers[exception] = func
|
||||
return func
|
||||
return wrapper
|
||||
return exception
|
||||
|
||||
def on_pre_run(self):
|
||||
for handler in self.pre_run_handlers:
|
||||
handler()
|
||||
|
||||
@ -25,18 +25,17 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
self.conn_mgr = conn_mgr
|
||||
self.table_meta = {}
|
||||
|
||||
@property
|
||||
def connection(self):
|
||||
def connection(self, metadata=None):
|
||||
conn = self.conn_mgr.conn('WOSERVER')
|
||||
if conn:
|
||||
conn.on_connect()
|
||||
conn.on_connect(metadata=metadata)
|
||||
return conn.conn
|
||||
|
||||
def query_conn(self, name):
|
||||
conn = self.conn_mgr.conn(name)
|
||||
def query_conn(self, name, metadata=None):
|
||||
conn = self.conn_mgr.conn(name, metadata=metadata)
|
||||
if not conn:
|
||||
raise exceptions.ConnectionNotFoundError(name)
|
||||
conn.on_connect()
|
||||
raise exceptions.ConnectionNotFoundError(name, metadata=metadata)
|
||||
conn.on_connect(metadata=metadata)
|
||||
return conn.conn
|
||||
|
||||
def _format_date(self, start, end):
|
||||
@ -55,14 +54,14 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
|
||||
return self._format_date(start, end)
|
||||
|
||||
def _get_routing_file_ids(self, table_id, range_array):
|
||||
def _get_routing_file_ids(self, table_id, range_array, metadata=None):
|
||||
table = db.Session.query(Tables).filter(and_(
|
||||
Tables.table_id==table_id,
|
||||
Tables.state!=Tables.TO_DELETE
|
||||
)).first()
|
||||
|
||||
if not table:
|
||||
raise exceptions.TableNotFoundError(table_id)
|
||||
raise exceptions.TableNotFoundError(table_id, metadata=metadata)
|
||||
files = table.files_to_search(range_array)
|
||||
|
||||
servers = self.conn_mgr.conn_names
|
||||
@ -84,7 +83,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
|
||||
return routing
|
||||
|
||||
def _do_merge(self, files_n_topk_results, topk, reverse=False):
|
||||
def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
|
||||
if not files_n_topk_results:
|
||||
return []
|
||||
|
||||
@ -111,9 +110,11 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
|
||||
def _do_query(self, table_id, table_meta, vectors, topk, nprobe, range_array=None, **kwargs):
|
||||
range_array = [self._range_to_date(r) for r in range_array] if range_array else None
|
||||
routing = self._get_routing_file_ids(table_id, range_array)
|
||||
routing = self._get_routing_file_ids(table_id, range_array, metadata=metadata)
|
||||
logger.info('Routing: {}'.format(routing))
|
||||
|
||||
metadata = kwargs.get('metadata', None)
|
||||
|
||||
rs = []
|
||||
all_topk_results = []
|
||||
|
||||
@ -124,7 +125,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
addr, query_params, len(vectors), topk, nprobe
|
||||
))
|
||||
|
||||
conn = self.query_conn(addr)
|
||||
conn = self.query_conn(addr, metadata=metadata)
|
||||
start = time.time()
|
||||
ret = conn.search_vectors_in_files(table_name=query_params['table_id'],
|
||||
file_ids=query_params['file_ids'],
|
||||
@ -146,7 +147,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
res.result()
|
||||
|
||||
reverse = table_meta.metric_type == types.MetricType.IP
|
||||
return self._do_merge(all_topk_results, topk, reverse=reverse)
|
||||
return self._do_merge(all_topk_results, topk, reverse=reverse, metadata=metadata)
|
||||
|
||||
def CreateTable(self, request, context):
|
||||
_status, _table_schema = Parser.parse_proto_TableSchema(request)
|
||||
@ -156,7 +157,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
|
||||
logger.info('CreateTable {}'.format(_table_schema['table_name']))
|
||||
|
||||
_status = self.connection.create_table(_table_schema)
|
||||
_status = self.connection().create_table(_table_schema)
|
||||
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
@ -171,7 +172,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
|
||||
logger.info('HasTable {}'.format(_table_name))
|
||||
|
||||
_bool = self.connection.has_table(_table_name)
|
||||
_bool = self.connection(metadata={
|
||||
'resp_class': milvus_pb2.BoolReply
|
||||
}).has_table(_table_name)
|
||||
|
||||
return milvus_pb2.BoolReply(
|
||||
status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="OK"),
|
||||
@ -186,7 +189,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
|
||||
logger.info('DropTable {}'.format(_table_name))
|
||||
|
||||
_status = self.connection.delete_table(_table_name)
|
||||
_status = self.connection().delete_table(_table_name)
|
||||
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
@ -201,14 +204,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
logger.info('CreateIndex {}'.format(_table_name))
|
||||
|
||||
# TODO: interface create_table incompleted
|
||||
_status = self.connection.create_index(_table_name, _index)
|
||||
_status = self.connection().create_index(_table_name, _index)
|
||||
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
def Insert(self, request, context):
|
||||
logger.info('Insert')
|
||||
# TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array'
|
||||
_status, _ids = self.connection.add_vectors(None, None, insert_param=request)
|
||||
_status, _ids = self.connection(metadata={
|
||||
'resp_class': milvus_pb2.VectorIds
|
||||
}).add_vectors(None, None, insert_param=request)
|
||||
return milvus_pb2.VectorIds(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
vector_id_array=_ids
|
||||
@ -227,10 +232,14 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
raise exceptions.GRPCInvlidArgument('Invalid nprobe: {}'.format(nprobe))
|
||||
|
||||
table_meta = self.table_meta.get(table_name, None)
|
||||
|
||||
metadata = {
|
||||
'resp_class': milvus_pb2.TopKQueryResultList
|
||||
}
|
||||
if not table_meta:
|
||||
status, info = self.connection.describe_table(table_name)
|
||||
status, info = self.connection(metadata=metadata).describe_table(table_name)
|
||||
if not status.OK():
|
||||
raise exceptions.TableNotFoundError(table_name)
|
||||
raise exceptions.TableNotFoundError(table_name, metadata=metadata)
|
||||
|
||||
self.table_meta[table_name] = info
|
||||
table_meta = info
|
||||
@ -248,7 +257,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
Range(query_range.start_value, query_range.end_value))
|
||||
|
||||
results = self._do_query(table_name, table_meta, query_record_array, topk,
|
||||
nprobe, query_range_array)
|
||||
nprobe, query_range_array, metadata=metadata)
|
||||
|
||||
now = time.time()
|
||||
logger.info('SearchVector takes: {}'.format(now - start))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user