diff --git a/core/.gitignore b/core/.gitignore
index 74e41dba6b..8db8df41db 100644
--- a/core/.gitignore
+++ b/core/.gitignore
@@ -9,3 +9,5 @@ output.info
 output_new.info
 server.info
 *.pyc
+src/grpc/python_gen.h
+src/grpc/python/
diff --git a/shards/mishards/connections.py b/shards/mishards/connections.py
index 618690a099..50e214ec9a 100644
--- a/shards/mishards/connections.py
+++ b/shards/mishards/connections.py
@@ -2,6 +2,7 @@ import logging
 import threading
 from functools import wraps
 from milvus import Milvus
+from milvus.client.hooks import BaseaSearchHook
 
 from mishards import (settings, exceptions)
 from utils import singleton
@@ -9,6 +10,12 @@ from utils import singleton
 logger = logging.getLogger(__name__)
 
 
+class SearchHook(BaseaSearchHook):
+
+    def on_response(self, *args, **kwargs):
+        return True
+
+
 class Connection:
     def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs):
         self.name = name
@@ -18,6 +25,9 @@ class Connection:
         self.conn = Milvus()
         self.error_handlers = [] if not error_handlers else error_handlers
         self.on_retry_func = kwargs.get('on_retry_func', None)
+
+        # define search hook
+        self.conn._set_hook(search_in_file=SearchHook())
         # self._connect()
 
     def __str__(self):
diff --git a/shards/mishards/service_handler.py b/shards/mishards/service_handler.py
index 2f19152ae6..fc0ee0fa2b 100644
--- a/shards/mishards/service_handler.py
+++ b/shards/mishards/service_handler.py
@@ -29,39 +29,71 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         self.router = router
         self.max_workers = max_workers
 
+    def _reduce(self, source_ids, ids, source_diss, diss, k, reverse):
+        if source_diss[k - 1] <= diss[0]:
+            return source_ids, source_diss
+        if diss[k - 1] <= source_diss[0]:
+            return ids, diss
+
+        source_diss.extend(diss)
+        diss_t = enumerate(source_diss)
+        diss_m_rst = sorted(diss_t, key=lambda x: x[1])[:k]
+        diss_m_out = [dis for _, dis in diss_m_rst]
+
+        source_ids.extend(ids)
+        id_m_out = [source_ids[i] for i, _ in diss_m_rst]
+
+        return id_m_out, diss_m_out
+
     def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs):
         status = status_pb2.Status(error_code=status_pb2.SUCCESS,
                                    reason="Success")
         if not files_n_topk_results:
-            return status, []
+            return status, [], []
 
-        request_results = defaultdict(list)
+        merge_id_results = []
+        merge_dis_results = []
 
         calc_time = time.time()
         for files_collection in files_n_topk_results:
             if isinstance(files_collection, tuple):
                 status, _ = files_collection
-                return status, []
-            for request_pos, each_request_results in enumerate(
-                    files_collection.topk_query_result):
-                request_results[request_pos].extend(
-                    each_request_results.query_result_arrays)
-                request_results[request_pos] = sorted(
-                    request_results[request_pos],
-                    key=lambda x: x.distance,
-                    reverse=reverse)[:topk]
+                return status, [], []
+
+            row_num = files_collection.row_num
+            ids = files_collection.ids
+            diss = files_collection.distances  # distance collections
+            # TODO: batch_len should equal topk; may need to validate against topk
+            batch_len = len(ids) // row_num
+
+            for row_index in range(row_num):
+                id_batch = ids[row_index * batch_len:(row_index + 1) * batch_len]
+                dis_batch = diss[row_index * batch_len:(row_index + 1) * batch_len]
+
+                if len(merge_id_results) < row_index:
+                    raise ValueError("merge error")
+                elif len(merge_id_results) == row_index:
+                    # TODO: possible bug here
+                    merge_id_results.append(id_batch)
+                    merge_dis_results.append(dis_batch)
+                else:
+                    merge_id_results[row_index], merge_dis_results[row_index] = \
+                        self._reduce(merge_id_results[row_index], id_batch,
+                                     merge_dis_results[row_index], dis_batch,
+                                     batch_len,
+                                     reverse)
 
         calc_time = time.time() - calc_time
         logger.info('Merge takes {}'.format(calc_time))
 
-        results = sorted(request_results.items())
-        topk_query_result = []
+        id_merge_list = []
+        dis_merge_list = []
 
-        for result in results:
-            query_result = TopKQueryResult(query_result_arrays=result[1])
-            topk_query_result.append(query_result)
+        for id_results, dis_results in zip(merge_id_results, merge_dis_results):
+            id_merge_list.extend(id_results)
+            dis_merge_list.extend(dis_results)
 
-        return status, topk_query_result
+        return status, id_merge_list, dis_merge_list
 
     def _do_query(self,
                   context,
@@ -109,8 +141,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
             file_ids=query_params['file_ids'],
             query_records=vectors,
             top_k=topk,
-            nprobe=nprobe,
-            lazy_=True)
+            nprobe=nprobe
+        )
 
         end = time.time()
         logger.info('search_vectors_in_files takes: {}'.format(end - start))
@@ -241,7 +273,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
         logger.info('Search {}: topk={} nprobe={}'.format(
             table_name, topk, nprobe))
 
-        metadata = {'resp_class': milvus_pb2.TopKQueryResultList}
+        metadata = {'resp_class': milvus_pb2.TopKQueryResult}
 
         if nprobe > self.MAX_NPROBE or nprobe <= 0:
             raise exceptions.InvalidArgumentError(
@@ -275,22 +307,24 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
             query_range_array.append(
                 Range(query_range.start_value, query_range.end_value))
 
-        status, results = self._do_query(context,
-                                         table_name,
-                                         table_meta,
-                                         query_record_array,
-                                         topk,
-                                         nprobe,
-                                         query_range_array,
-                                         metadata=metadata)
+        status, id_results, dis_results = self._do_query(context,
+                                                         table_name,
+                                                         table_meta,
+                                                         query_record_array,
+                                                         topk,
+                                                         nprobe,
+                                                         query_range_array,
+                                                         metadata=metadata)
 
         now = time.time()
         logger.info('SearchVector takes: {}'.format(now - start))
 
-        topk_result_list = milvus_pb2.TopKQueryResultList(
+        topk_result_list = milvus_pb2.TopKQueryResult(
             status=status_pb2.Status(error_code=status.error_code,
                                      reason=status.reason),
-            topk_query_result=results)
+            row_num=len(query_record_array),
+            ids=id_results,
+            distances=dis_results)
         return topk_result_list
 
     @mark_grpc_method
diff --git a/shards/requirements.txt b/shards/requirements.txt
index 14bdde2a06..426ee0b704 100644
--- a/shards/requirements.txt
+++ b/shards/requirements.txt
@@ -14,8 +14,7 @@ py==1.8.0
 pyasn1==0.4.7
 pyasn1-modules==0.2.6
 pylint==2.3.1
-pymilvus-test==0.2.28
-#pymilvus==0.2.0
+pymilvus==0.2.5
 pyparsing==2.4.0
 pytest==4.6.3
 pytest-level==0.1.1
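
Note: the sketch below illustrates, outside the patch, the per-row top-k merge
that the new _reduce/_do_merge pair performs. It assumes each partial result's
distances arrive sorted ascending (reverse=False; the reverse flag is threaded
into _reduce but not yet consulted there). The name merge_topk and the demo
values are hypothetical, not part of the patch.

def merge_topk(source_ids, source_diss, ids, diss, k):
    # Fast paths: one partial result already dominates the other entirely.
    if source_diss[k - 1] <= diss[0]:
        return source_ids, source_diss
    if diss[k - 1] <= source_diss[0]:
        return ids, diss

    # General case: concatenate, sort by distance, keep the k smallest,
    # and carry the matching ids along via their original positions.
    all_diss = source_diss + diss
    all_ids = source_ids + ids
    best = sorted(enumerate(all_diss), key=lambda x: x[1])[:k]
    return [all_ids[i] for i, _ in best], [dis for _, dis in best]

if __name__ == '__main__':
    ids_a, diss_a = [3, 7, 1], [0.1, 0.4, 0.9]
    ids_b, diss_b = [5, 2, 8], [0.2, 0.3, 0.5]
    print(merge_topk(ids_a, diss_a, ids_b, diss_b, 3))
    # ([3, 5, 2], [0.1, 0.2, 0.3])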