diff --git a/mishards/exception_codes.py b/mishards/exception_codes.py
index c8cfd81dab..32b29bdfab 100644
--- a/mishards/exception_codes.py
+++ b/mishards/exception_codes.py
@@ -2,3 +2,6 @@
 INVALID_CODE = -1
 CONNECT_ERROR_CODE = 10001
 CONNECTTION_NOT_FOUND_CODE = 10002
+
+TABLE_NOT_FOUND_CODE = 20001
+INVALID_RANGE_CODE = 20002
diff --git a/mishards/exceptions.py b/mishards/exceptions.py
index a25fb2c4ae..1445d18769 100644
--- a/mishards/exceptions.py
+++ b/mishards/exceptions.py
@@ -11,3 +11,9 @@ class ConnectionConnectError(BaseException):

 class ConnectionNotFoundError(BaseException):
     code = codes.CONNECTTION_NOT_FOUND_CODE
+
+class TableNotFoundError(BaseException):
+    code = codes.TABLE_NOT_FOUND_CODE
+
+class InvalidRangeError(BaseException):
+    code = codes.INVALID_RANGE_CODE
diff --git a/mishards/main.py b/mishards/main.py
index 0185e6ac1d..2ba3f14697 100644
--- a/mishards/main.py
+++ b/mishards/main.py
@@ -7,6 +7,7 @@
 from mishards import connect_mgr, grpc_server as server

 def main():
     connect_mgr.register('WOSERVER', settings.WOSERVER)
+    connect_mgr.register('TEST', 'tcp://127.0.0.1:19530')
     server.run(port=settings.SERVER_PORT)
     return 0
diff --git a/mishards/service_handler.py b/mishards/service_handler.py
index ead8d14d88..89ae2cd36c 100644
--- a/mishards/service_handler.py
+++ b/mishards/service_handler.py
@@ -1,13 +1,22 @@
 import logging
+import time
+import datetime

 from contextlib import contextmanager
-from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
+from milvus.grpc_gen.milvus_pb2 import TopKQueryResult
+from milvus.client import types
+
+import settings
 from grpc_utils.grpc_args_parser import GrpcArgsParser as Parser

 logger = logging.getLogger(__name__)


 class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
+    MAX_NPROBE = 2048
     def __init__(self, conn_mgr, *args, **kwargs):
         self.conn_mgr = conn_mgr
         self.table_meta = {}
@@ -19,6 +28,103 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         conn.on_connect()
         return conn.conn

+    def query_conn(self, name):
+        conn = self.conn_mgr.conn(name)
+        conn and conn.on_connect()
+        return conn.conn
+
+    def _format_date(self, start, end):
+        return ((start.year-1900)*10000 + (start.month-1)*100 + start.day
+                , (end.year-1900)*10000 + (end.month-1)*100 + end.day)
+
+    def _range_to_date(self, range_obj):
+        try:
+            start = datetime.datetime.strptime(range_obj.start_date, '%Y-%m-%d')
+            end = datetime.datetime.strptime(range_obj.end_date, '%Y-%m-%d')
+            assert start < end
+        except (ValueError, AssertionError):
+            raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format(
+                range_obj.start_date, range_obj.end_date
+            ))
+
+        return self._format_date(start, end)
+
+    def _get_routing_file_ids(self, table_id, range_array):
+        return {
+            'TEST': {
+                'table_id': table_id,
+                'file_ids': [123]
+            }
+        }
+
+    def _do_merge(self, files_n_topk_results, topk, reverse=False):
+        if not files_n_topk_results:
+            return []
+
+        request_results = defaultdict(list)
+
+        calc_time = time.time()
+        for files_collection in files_n_topk_results:
+            for request_pos, each_request_results in enumerate(files_collection.topk_query_result):
+                request_results[request_pos].extend(each_request_results.query_result_arrays)
+                request_results[request_pos] = sorted(request_results[request_pos], key=lambda x: x.distance,
+                                                      reverse=reverse)[:topk]
+
+        calc_time = time.time() - calc_time
+        logger.info('Merge takes {}'.format(calc_time))
+
+        results = sorted(request_results.items())
+        topk_query_result = []
+
+        for result in results:
+            query_result = TopKQueryResult(query_result_arrays=result[1])
+            topk_query_result.append(query_result)
+
+        return topk_query_result
+
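+    # Scatter-gather helper: _do_query below submits one search task per address
+    # returned by _get_routing_file_ids (currently a hard-coded 'TEST' stub, see
+    # main.py) to a thread pool of settings.SEARCH_WORKER_SIZE workers, then
+    # merges the per-server partial top-k results with _do_merge above.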
+    def _do_query(self, table_id, table_meta, vectors, topk, nprobe, range_array=None, **kwargs):
+        range_array = [self._range_to_date(r) for r in range_array] if range_array else None
+        routing = self._get_routing_file_ids(table_id, range_array)
+        logger.debug(routing)
+
+        rs = []
+        all_topk_results = []
+
+        workers = settings.SEARCH_WORKER_SIZE
+
+        def search(addr, query_params, vectors, topk, nprobe, **kwargs):
+            logger.info('Send Search Request: addr={};params={};nq={};topk={};nprobe={}'.format(
+                addr, query_params, len(vectors), topk, nprobe
+            ))
+
+            conn = self.query_conn(addr)
+            start = time.time()
+            ret = conn.search_vectors_in_files(table_name=query_params['table_id'],
+                                               file_ids=query_params['file_ids'],
+                                               query_records=vectors,
+                                               top_k=topk,
+                                               nprobe=nprobe,
+                                               lazy=True)
+            end = time.time()
+            logger.info('search_vectors_in_files takes: {}'.format(end - start))
+
+            all_topk_results.append(ret)
+
+        with ThreadPoolExecutor(max_workers=workers) as pool:
+            for addr, params in routing.items():
+                res = pool.submit(search, addr, params, vectors, topk, nprobe)
+                rs.append(res)
+
+        for res in rs:
+            res.result()
+
+        reverse = table_meta.metric_type == types.MetricType.IP
+        return self._do_merge(all_topk_results, topk, reverse=reverse)
+
     def CreateTable(self, request, context):

         _status, _table_schema = Parser.parse_proto_TableSchema(request)
@@ -87,64 +189,64 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):

     def Search(self, request, context):
-        try:
-            table_name = request.table_name
+        table_name = request.table_name

-            topk = request.topk
-            nprobe = request.nprobe
+        topk = request.topk
+        nprobe = request.nprobe

-            logger.info('Search {}: topk={} nprobe={}'.format(table_name, topk, nprobe))
+        logger.info('Search {}: topk={} nprobe={}'.format(table_name, topk, nprobe))

-            if nprobe > 2048 or nprobe <= 0:
-                raise exceptions.GRPCInvlidArgument('Invalid nprobe: {}'.format(nprobe))
+        if nprobe > self.MAX_NPROBE or nprobe <= 0:
+            raise exceptions.GRPCInvlidArgument('Invalid nprobe: {}'.format(nprobe))

-            table_meta = self.table_meta.get(table_name, None)
-            if not table_meta:
-                status, info = self.connection.describe_table(table_name)
-                if not status.OK():
-                    raise TableNotFoundException(table_name)
+        table_meta = self.table_meta.get(table_name, None)
+        if not table_meta:
+            status, info = self.connection.describe_table(table_name)
+            if not status.OK():
+                raise exceptions.TableNotFoundError(table_name)

-                self.table_meta[table_name] = info
-                table_meta = info
+            self.table_meta[table_name] = info
+            table_meta = info

-            start = time.time()
+        start = time.time()

-            query_record_array = []
+        query_record_array = []

-            for query_record in request.query_record_array:
-                query_record_array.append(list(query_record.vector_data))
+        for query_record in request.query_record_array:
+            query_record_array.append(list(query_record.vector_data))

-            query_range_array = []
-            for query_range in request.query_range_array:
-                query_range_array.append(
-                    Range(query_range.start_value, query_range.end_value))
-        except (TableNotFoundException, exceptions.GRPCInvlidArgument) as exc:
-            return milvus_pb2.TopKQueryResultList(
-                status=status_pb2.Status(error_code=exc.code, reason=exc.message)
-            )
-        except Exception as e:
-            return milvus_pb2.TopKQueryResultList(
-                status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=str(e))
-            )
+        query_range_array = []
+        for query_range in request.query_range_array:
+            query_range_array.append(
+                Range(query_range.start_value, query_range.end_value))
+        # except (TableNotFoundException, exceptions.GRPCInvlidArgument) as exc:
+        #     return milvus_pb2.TopKQueryResultList(
+        #         status=status_pb2.Status(error_code=exc.code, reason=exc.message)
+        #     )
+        # except Exception as e:
+        #     return milvus_pb2.TopKQueryResultList(
+        #         status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=str(e))
+        #     )

-        try:
-            results = workflow.query_vectors(table_name, table_meta, query_record_array, topk,
-                                             nprobe, query_range_array)
-        except (exceptions.GRPCQueryInvalidRangeException, TableNotFoundException) as exc:
-            return milvus_pb2.TopKQueryResultList(
-                status=status_pb2.Status(error_code=exc.code, reason=exc.message)
-            )
-        except exceptions.ServiceNotFoundException as exc:
-            return milvus_pb2.TopKQueryResultList(
-                status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=exc.message)
-            )
-        except Exception as e:
-            logger.error(e)
-            results = workflow.query_vectors(table_name, table_meta, query_record_array,
-                                             topk, nprobe, query_range_array)
+        results = self._do_query(table_name, table_meta, query_record_array, topk,
+                                 nprobe, query_range_array)
+        # try:
+        #     results = workflow.query_vectors(table_name, table_meta, query_record_array, topk,
+        #                                      nprobe, query_range_array)
+        # except (exceptions.GRPCQueryInvalidRangeException, TableNotFoundException) as exc:
+        #     return milvus_pb2.TopKQueryResultList(
+        #         status=status_pb2.Status(error_code=exc.code, reason=exc.message)
+        #     )
+        # except exceptions.ServiceNotFoundException as exc:
+        #     return milvus_pb2.TopKQueryResultList(
+        #         status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=exc.message)
+        #     )
+        # except Exception as e:
+        #     logger.error(e)
+        #     results = workflow.query_vectors(table_name, table_meta, query_record_array,
+        #                                      topk, nprobe, query_range_array)

         now = time.time()
-        logger.info('SearchVector Ends @{}'.format(now))
         logger.info('SearchVector takes: {}'.format(now - start))

         topk_result_list = milvus_pb2.TopKQueryResultList(
@@ -154,41 +256,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         return topk_result_list

     def SearchInFiles(self, request, context):
-        try:
-            file_id_array = list(request.file_id_array)
-            search_param = request.search_param
-            table_name = search_param.table_name
-            topk = search_param.topk
-            nprobe = search_param.nprobe
-
-            query_record_array = []
-
-            for query_record in search_param.query_record_array:
-                query_record_array.append(list(query_record))
-
-            query_range_array = []
-            for query_range in search_param.query_range_array:
-                query_range_array.append("")
-        except Exception as e:
-            milvus_pb2.TopKQueryResultList(
-                status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=str(e)),
-            )
-
-        res = search_vector_in_files.delay(table_name=table_name,
-                                           file_id_array=file_id_array,
-                                           query_record_array=query_record_array,
-                                           query_range_array=query_range_array,
-                                           topk=topk,
-                                           nprobe=nprobe)
-        status, result = res.get(timeout=1)
-
-        if not status.OK():
-            raise ThriftException(code=status.code, reason=status.message)
-        res = TopKQueryResult()
-        for top_k_query_results in result:
-            res.query_result_arrays.append([QueryResult(id=qr.id, distance=qr.distance)
-                                             for qr in top_k_query_results])
-        return res
+        raise NotImplementedError()

     def DescribeTable(self, request, context):
         _status, _table_name = Parser.parse_proto_TableName(request)
diff --git a/mishards/settings.py b/mishards/settings.py
index 0566cf066f..4d87e69fe3 100644
--- a/mishards/settings.py
+++ b/mishards/settings.py
@@ -21,6 +21,7 @@ config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE)

 TIMEOUT = env.int('TIMEOUT', 60)
 MAX_RETRY = env.int('MAX_RETRY', 3)
+SEARCH_WORKER_SIZE = env.int('SEARCH_WORKER_SIZE', 10)

 SERVER_PORT = env.int('SERVER_PORT', 19530)
 WOSERVER = env.str('WOSERVER')
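
For reviewers, a self-contained sketch of the merge strategy that the new ServiceHandler._do_merge implements (plain Python with invented names, no milvus imports): each downstream server returns a partial top-k hit list per request; the per-request lists are concatenated, re-sorted by distance, and truncated back to topk.

# Standalone sketch of the merge step; `Hit` and `merge_topk` are invented for
# illustration, the real code operates on milvus_pb2 TopKQueryResult protos.
from collections import defaultdict, namedtuple

Hit = namedtuple('Hit', ['id', 'distance'])

def merge_topk(partial_results, topk, reverse=False):
    # partial_results[server][request] -> [Hit, ...], one inner list per request
    merged = defaultdict(list)
    for server_result in partial_results:
        for request_pos, hits in enumerate(server_result):
            merged[request_pos].extend(hits)
            # keep only the best `topk` hits per request position
            merged[request_pos] = sorted(merged[request_pos],
                                         key=lambda h: h.distance,
                                         reverse=reverse)[:topk]
    return [merged[pos] for pos in sorted(merged)]

# Example: two servers each return partial top-2 hits for a single request.
server_a = [[Hit(1, 0.10), Hit(7, 0.80)]]
server_b = [[Hit(3, 0.05), Hit(9, 0.90)]]
print(merge_topk([server_a, server_b], topk=2))
# -> [[Hit(id=3, distance=0.05), Hit(id=1, distance=0.1)]]

Ascending order (reverse=False) suits distance metrics such as L2, while similarity scores such as inner product need reverse=True, which is why _do_query derives the reverse flag from table_meta.metric_type.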