diff --git a/mishards/exception_codes.py b/mishards/exception_codes.py index 37492f25d4..ecb2469562 100644 --- a/mishards/exception_codes.py +++ b/mishards/exception_codes.py @@ -2,6 +2,8 @@ INVALID_CODE = -1 CONNECT_ERROR_CODE = 10001 CONNECTTION_NOT_FOUND_CODE = 10002 +DB_ERROR_CODE = 10003 TABLE_NOT_FOUND_CODE = 20001 -INVALID_ARGUMENT = 20002 +INVALID_ARGUMENT_CODE = 20002 +INVALID_DATE_RANGE_CODE = 20003 diff --git a/mishards/exception_handlers.py b/mishards/exception_handlers.py index 6207f2088c..2518b64b3e 100644 --- a/mishards/exception_handlers.py +++ b/mishards/exception_handlers.py @@ -29,6 +29,9 @@ def resp_handler(err, error_code): if resp_class == milvus_pb2.TableRowCount: return resp_class(status=status, table_row_count=-1) + if resp_class == milvus_pb2.TableName: + return resp_class(status=status, table_name=[]) + status.error_code = status_pb2.UNEXPECTED_ERROR return status @@ -41,3 +44,13 @@ def TableNotFoundErrorHandler(err): def InvalidArgumentErrorHandler(err): logger.error(err) return resp_handler(err, status_pb2.ILLEGAL_ARGUMENT) + +@server.errorhandler(exceptions.DBError) +def DBErrorHandler(err): + logger.error(err) + return resp_handler(err, status_pb2.UNEXPECTED_ERROR) + +@server.errorhandler(exceptions.InvalidRangeError) +def InvalidArgumentErrorHandler(err): + logger.error(err) + return resp_handler(err, status_pb2.ILLEGAL_RANGE) diff --git a/mishards/exceptions.py b/mishards/exceptions.py index 4686cf674f..2aa2b39eb9 100644 --- a/mishards/exceptions.py +++ b/mishards/exceptions.py @@ -13,8 +13,14 @@ class ConnectionConnectError(BaseException): class ConnectionNotFoundError(BaseException): code = codes.CONNECTTION_NOT_FOUND_CODE +class DBError(BaseException): + code = codes.DB_ERROR_CODE + class TableNotFoundError(BaseException): code = codes.TABLE_NOT_FOUND_CODE class InvalidArgumentError(BaseException): - code = codes.INVALID_ARGUMENT + code = codes.INVALID_ARGUMENT_CODE + +class InvalidRangeError(BaseException): + code = codes.INVALID_DATE_RANGE_CODE diff --git a/mishards/service_handler.py b/mishards/service_handler.py index 128667d9b6..536a17c4e3 100644 --- a/mishards/service_handler.py +++ b/mishards/service_handler.py @@ -5,10 +5,12 @@ from contextlib import contextmanager from collections import defaultdict from sqlalchemy import and_ +from sqlalchemy import exc as sqlalchemy_exc from concurrent.futures import ThreadPoolExecutor from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2 from milvus.grpc_gen.milvus_pb2 import TopKQueryResult +from milvus.client.Abstract import Range from milvus.client import types from mishards import (db, settings, exceptions) @@ -44,7 +46,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return ((start.year-1900)*10000 + (start.month-1)*100 + start.day , (end.year-1900)*10000 + (end.month-1)*100 + end.day) - def _range_to_date(self, range_obj): + def _range_to_date(self, range_obj, metadata=None): try: start = datetime.datetime.strptime(range_obj.start_date, '%Y-%m-%d') end = datetime.datetime.strptime(range_obj.end_date, '%Y-%m-%d') @@ -52,15 +54,19 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): except (ValueError, AssertionError): raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format( range_obj.start_date, range_obj.end_date - )) + ), metadata=metadata) return self._format_date(start, end) def _get_routing_file_ids(self, table_id, range_array, metadata=None): - table = db.Session.query(Tables).filter(and_( - Tables.table_id==table_id, - Tables.state!=Tables.TO_DELETE - )).first() + # PXU TODO: Implement Thread-local Context + try: + table = db.Session.query(Tables).filter(and_( + Tables.table_id==table_id, + Tables.state!=Tables.TO_DELETE + )).first() + except sqlalchemy_exc.SQLAlchemyError as e: + raise exceptions.DBError(message=str(e), metadata=metadata) if not table: raise exceptions.TableNotFoundError(table_id, metadata=metadata) @@ -111,8 +117,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return topk_query_result def _do_query(self, table_id, table_meta, vectors, topk, nprobe, range_array=None, **kwargs): - range_array = [self._range_to_date(r) for r in range_array] if range_array else None metadata = kwargs.get('metadata', None) + range_array = [self._range_to_date(r, metadata=metadata) for r in range_array] if range_array else None routing = self._get_routing_file_ids(table_id, range_array, metadata=metadata) logger.info('Routing: {}'.format(routing)) @@ -362,7 +368,10 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): @mark_grpc_method def ShowTables(self, request, context): logger.info('ShowTables') - _status, _results = self.connection.show_tables() + metadata = { + 'resp_class': milvus_pb2.TableName + } + _status, _results = self.connection(metadata=metadata).show_tables() if not _status.OK(): _results = []