diff --git a/mishards/exception_codes.py b/mishards/exception_codes.py index ecb2469562..bdd4572dd5 100644 --- a/mishards/exception_codes.py +++ b/mishards/exception_codes.py @@ -7,3 +7,4 @@ DB_ERROR_CODE = 10003 TABLE_NOT_FOUND_CODE = 20001 INVALID_ARGUMENT_CODE = 20002 INVALID_DATE_RANGE_CODE = 20003 +INVALID_TOPK_CODE = 20004 diff --git a/mishards/exception_handlers.py b/mishards/exception_handlers.py index 1e5ffb3529..c79a6db5a3 100644 --- a/mishards/exception_handlers.py +++ b/mishards/exception_handlers.py @@ -58,6 +58,12 @@ def TableNotFoundErrorHandler(err): return resp_handler(err, status_pb2.TABLE_NOT_EXISTS) +@server.errorhandler(exceptions.InvalidTopKError) +def InvalidTopKErrorHandler(err): + logger.error(err) + return resp_handler(err, status_pb2.ILLEGAL_TOPK) + + @server.errorhandler(exceptions.InvalidArgumentError) def InvalidArgumentErrorHandler(err): logger.error(err) diff --git a/mishards/exceptions.py b/mishards/exceptions.py index acd9372d6a..72839f88d2 100644 --- a/mishards/exceptions.py +++ b/mishards/exceptions.py @@ -26,6 +26,10 @@ class TableNotFoundError(BaseException): code = codes.TABLE_NOT_FOUND_CODE +class InvalidTopKError(BaseException): + code = codes.INVALID_TOPK_CODE + + class InvalidArgumentError(BaseException): code = codes.INVALID_ARGUMENT_CODE diff --git a/mishards/service_handler.py b/mishards/service_handler.py index 44e1d8cf7b..5e91c14f14 100644 --- a/mishards/service_handler.py +++ b/mishards/service_handler.py @@ -20,6 +20,7 @@ logger = logging.getLogger(__name__) class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): MAX_NPROBE = 2048 + MAX_TOPK = 2048 def __init__(self, tracer, router, max_workers=multiprocessing.cpu_count(), **kwargs): self.table_meta = {} @@ -246,6 +247,10 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): raise exceptions.InvalidArgumentError( message='Invalid nprobe: {}'.format(nprobe), metadata=metadata) + if topk > self.MAX_TOPK or topk <= 0: + raise exceptions.InvalidTopKError( + message='Invalid topk: {}'.format(topk), metadata=metadata) + table_meta = self.table_meta.get(table_name, None) if not table_meta: diff --git a/mishards/utilities.py b/mishards/utilities.py index c08d0d42df..42e982b5f1 100644 --- a/mishards/utilities.py +++ b/mishards/utilities.py @@ -2,12 +2,12 @@ import datetime from mishards import exceptions -def format_date(self, start, end): +def format_date(start, end): return ((start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day, (end.year - 1900) * 10000 + (end.month - 1) * 100 + end.day) -def range_to_date(self, range_obj, metadata=None): +def range_to_date(range_obj, metadata=None): try: start = datetime.datetime.strptime(range_obj.start_date, '%Y-%m-%d') end = datetime.datetime.strptime(range_obj.end_date, '%Y-%m-%d') @@ -17,4 +17,4 @@ def range_to_date(self, range_obj, metadata=None): range_obj.start_date, range_obj.end_date), metadata=metadata) - return self.format_date(start, end) + return format_date(start, end)