diff --git a/mishards/connections.py b/mishards/connections.py index c6323f66f8..365dc60125 100644 --- a/mishards/connections.py +++ b/mishards/connections.py @@ -24,14 +24,14 @@ class Connection: def __str__(self): return 'Connection:name=\"{}\";uri=\"{}\"'.format(self.name, self.uri) - def _connect(self): + def _connect(self, metadata=None): try: self.conn.connect(uri=self.uri) except Exception as e: if not self.error_handlers: - raise exceptions.ConnectionConnectError(e) + raise exceptions.ConnectionConnectError(message=str(e), metadata=metadata) for handler in self.error_handlers: - handler(e) + handler(e, metadata=metadata) @property def can_retry(self): @@ -47,14 +47,15 @@ class Connection: else: logger.warn('{} is retrying {}'.format(self, self.retried)) - def on_connect(self): + def on_connect(self, metadata=None): while not self.connected and self.can_retry: self.retried += 1 self.on_retry() - self._connect() + self._connect(metadata=metadata) if not self.can_retry and not self.connected: - raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry)) + raise exceptions.ConnectionConnectError(message='Max retry {} reached!'.format(self.max_retry, + metadata=metadata)) self.retried = 0 @@ -81,14 +82,15 @@ class ConnectionMgr: def conn_names(self): return set(self.metas.keys()) - set(['WOSERVER']) - def conn(self, name, throw=False): + def conn(self, name, metadata, throw=False): c = self.conns.get(name, None) if not c: url = self.metas.get(name, None) if not url: if not throw: return None - raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name)) + raise exceptions.ConnectionNotFoundError(message='Connection {} not found'.format(name), + metadata=metadata) this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY) threaded = { threading.get_ident() : this_conn @@ -103,7 +105,8 @@ class ConnectionMgr: if not url: if not throw: return None - raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name)) + raise exceptions.ConnectionNotFoundError('Connection {} not found'.format(name), + metadata=metadata) this_conn = Connection(name=name, uri=url, max_retry=settings.MAX_RETRY) c[tid] = this_conn return this_conn diff --git a/mishards/exception_handlers.py b/mishards/exception_handlers.py new file mode 100644 index 0000000000..3de0918be4 --- /dev/null +++ b/mishards/exception_handlers.py @@ -0,0 +1,35 @@ +import logging +from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2 +from mishards import server, exceptions + +logger = logging.getLogger(__name__) + +def resp_handler(err, error_code): + if not isinstance(err, exceptions.BaseException): + return status_pb2.Status(error_code=error_code, reason=str(err)) + + status = status_pb2.Status(error_code=error_code, reason=err.message) + + if err.metadata is None: + return status + + resp_class = err.metadata.get('resp_class', None) + if not resp_class: + return status + + if resp_class == milvus_pb2.BoolReply: + return resp_class(status=status, bool_reply=False) + + if resp_class == milvus_pb2.VectorIds: + return resp_class(status=status, vector_id_array=[]) + + if resp_class == milvus_pb2.TopKQueryResultList: + return resp_class(status=status, topk_query_result=[]) + + status.error_code = status_pb2.UNEXPECTED_ERROR + return status + +@server.error_handler(exceptions.TableNotFoundError) +def TableNotFoundErrorHandler(err): + logger.error(err) + return resp_handler(err, status_pb2.TABLE_NOT_EXISTS) diff --git a/mishards/exceptions.py b/mishards/exceptions.py index 0f89ecb52d..1579fefcf4 100644 --- a/mishards/exceptions.py +++ b/mishards/exceptions.py @@ -3,8 +3,9 @@ import mishards.exception_codes as codes class BaseException(Exception): code = codes.INVALID_CODE message = 'BaseException' - def __init__(self, message=''): + def __init__(self, message='', metadata=None): self.message = self.__class__.__name__ if not message else message + self.metadata = metadata class ConnectionConnectError(BaseException): code = codes.CONNECT_ERROR_CODE diff --git a/mishards/server.py b/mishards/server.py index 9966360d47..b000016e29 100644 --- a/mishards/server.py +++ b/mishards/server.py @@ -2,6 +2,7 @@ import logging import grpc import time import socket +import inspect from urllib.parse import urlparse from functools import wraps from concurrent import futures @@ -16,6 +17,7 @@ logger = logging.getLogger(__name__) class Server: def __init__(self, conn_mgr, port=19530, max_workers=10, **kwargs): self.pre_run_handlers = set() + self.error_handler = {} self.exit_flag = False self.port = int(port) self.conn_mgr = conn_mgr @@ -40,6 +42,14 @@ class Server: self.pre_run_handlers.add(func) return func + def errorhandler(self, exception): + if inspect.isclass(exception) and issubclass(exception, Exception): + def wrapper(func): + self.error_handlers[exception] = func + return func + return wrapper + return exception + def on_pre_run(self): for handler in self.pre_run_handlers: handler() diff --git a/mishards/service_handler.py b/mishards/service_handler.py index f88655d2d6..5346be91d8 100644 --- a/mishards/service_handler.py +++ b/mishards/service_handler.py @@ -25,18 +25,17 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): self.conn_mgr = conn_mgr self.table_meta = {} - @property - def connection(self): + def connection(self, metadata=None): conn = self.conn_mgr.conn('WOSERVER') if conn: - conn.on_connect() + conn.on_connect(metadata=metadata) return conn.conn - def query_conn(self, name): - conn = self.conn_mgr.conn(name) + def query_conn(self, name, metadata=None): + conn = self.conn_mgr.conn(name, metadata=metadata) if not conn: - raise exceptions.ConnectionNotFoundError(name) - conn.on_connect() + raise exceptions.ConnectionNotFoundError(name, metadata=metadata) + conn.on_connect(metadata=metadata) return conn.conn def _format_date(self, start, end): @@ -55,14 +54,14 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return self._format_date(start, end) - def _get_routing_file_ids(self, table_id, range_array): + def _get_routing_file_ids(self, table_id, range_array, metadata=None): table = db.Session.query(Tables).filter(and_( Tables.table_id==table_id, Tables.state!=Tables.TO_DELETE )).first() if not table: - raise exceptions.TableNotFoundError(table_id) + raise exceptions.TableNotFoundError(table_id, metadata=metadata) files = table.files_to_search(range_array) servers = self.conn_mgr.conn_names @@ -84,7 +83,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return routing - def _do_merge(self, files_n_topk_results, topk, reverse=False): + def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs): if not files_n_topk_results: return [] @@ -111,9 +110,11 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): def _do_query(self, table_id, table_meta, vectors, topk, nprobe, range_array=None, **kwargs): range_array = [self._range_to_date(r) for r in range_array] if range_array else None - routing = self._get_routing_file_ids(table_id, range_array) + routing = self._get_routing_file_ids(table_id, range_array, metadata=metadata) logger.info('Routing: {}'.format(routing)) + metadata = kwargs.get('metadata', None) + rs = [] all_topk_results = [] @@ -124,7 +125,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): addr, query_params, len(vectors), topk, nprobe )) - conn = self.query_conn(addr) + conn = self.query_conn(addr, metadata=metadata) start = time.time() ret = conn.search_vectors_in_files(table_name=query_params['table_id'], file_ids=query_params['file_ids'], @@ -146,7 +147,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): res.result() reverse = table_meta.metric_type == types.MetricType.IP - return self._do_merge(all_topk_results, topk, reverse=reverse) + return self._do_merge(all_topk_results, topk, reverse=reverse, metadata=metadata) def CreateTable(self, request, context): _status, _table_schema = Parser.parse_proto_TableSchema(request) @@ -156,7 +157,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): logger.info('CreateTable {}'.format(_table_schema['table_name'])) - _status = self.connection.create_table(_table_schema) + _status = self.connection().create_table(_table_schema) return status_pb2.Status(error_code=_status.code, reason=_status.message) @@ -171,7 +172,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): logger.info('HasTable {}'.format(_table_name)) - _bool = self.connection.has_table(_table_name) + _bool = self.connection(metadata={ + 'resp_class': milvus_pb2.BoolReply + }).has_table(_table_name) return milvus_pb2.BoolReply( status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="OK"), @@ -186,7 +189,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): logger.info('DropTable {}'.format(_table_name)) - _status = self.connection.delete_table(_table_name) + _status = self.connection().delete_table(_table_name) return status_pb2.Status(error_code=_status.code, reason=_status.message) @@ -201,14 +204,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): logger.info('CreateIndex {}'.format(_table_name)) # TODO: interface create_table incompleted - _status = self.connection.create_index(_table_name, _index) + _status = self.connection().create_index(_table_name, _index) return status_pb2.Status(error_code=_status.code, reason=_status.message) def Insert(self, request, context): logger.info('Insert') # TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array' - _status, _ids = self.connection.add_vectors(None, None, insert_param=request) + _status, _ids = self.connection(metadata={ + 'resp_class': milvus_pb2.VectorIds + }).add_vectors(None, None, insert_param=request) return milvus_pb2.VectorIds( status=status_pb2.Status(error_code=_status.code, reason=_status.message), vector_id_array=_ids @@ -227,10 +232,14 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): raise exceptions.GRPCInvlidArgument('Invalid nprobe: {}'.format(nprobe)) table_meta = self.table_meta.get(table_name, None) + + metadata = { + 'resp_class': milvus_pb2.TopKQueryResultList + } if not table_meta: - status, info = self.connection.describe_table(table_name) + status, info = self.connection(metadata=metadata).describe_table(table_name) if not status.OK(): - raise exceptions.TableNotFoundError(table_name) + raise exceptions.TableNotFoundError(table_name, metadata=metadata) self.table_meta[table_name] = info table_meta = info @@ -248,7 +257,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): Range(query_range.start_value, query_range.end_value)) results = self._do_query(table_name, table_meta, query_record_array, topk, - nprobe, query_range_array) + nprobe, query_range_array, metadata=metadata) now = time.time() logger.info('SearchVector takes: {}'.format(now - start))