From 1159880036a01b1b2750889592ad2826b31a470a Mon Sep 17 00:00:00 2001 From: yhz <413554850@qq.com> Date: Fri, 15 Nov 2019 17:00:53 +0800 Subject: [PATCH 1/4] ignore proto python --- core/.gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/.gitignore b/core/.gitignore index 74e41dba6b..8db8df41db 100644 --- a/core/.gitignore +++ b/core/.gitignore @@ -9,3 +9,5 @@ output.info output_new.info server.info *.pyc +src/grpc/python_gen.h +src/grpc/python/ From 15315158e790d2be24ba45696bd27b27705aff7f Mon Sep 17 00:00:00 2001 From: yhz <413554850@qq.com> Date: Sat, 16 Nov 2019 09:34:17 +0800 Subject: [PATCH 2/4] modify sdk --- .../sdk/examples/partition/src/ClientTest.cpp | 65 ++++++++++--------- core/src/sdk/examples/simple/main.cpp | 2 +- .../sdk/examples/simple/src/ClientTest.cpp | 3 +- core/src/sdk/grpc/ClientProxy.cpp | 30 ++++++--- core/src/sdk/grpc/ClientProxy.h | 2 +- 5 files changed, 58 insertions(+), 44 deletions(-) diff --git a/core/src/sdk/examples/partition/src/ClientTest.cpp b/core/src/sdk/examples/partition/src/ClientTest.cpp index 775e1f6d60..b67b1fe907 100644 --- a/core/src/sdk/examples/partition/src/ClientTest.cpp +++ b/core/src/sdk/examples/partition/src/ClientTest.cpp @@ -31,45 +31,45 @@ namespace { -const char* TABLE_NAME = milvus_sdk::Utils::GenTableName().c_str(); + const char *TABLE_NAME = milvus_sdk::Utils::GenTableName().c_str(); -constexpr int64_t TABLE_DIMENSION = 512; -constexpr int64_t TABLE_INDEX_FILE_SIZE = 1024; -constexpr milvus::MetricType TABLE_METRIC_TYPE = milvus::MetricType::L2; -constexpr int64_t BATCH_ROW_COUNT = 10000; -constexpr int64_t NQ = 5; -constexpr int64_t TOP_K = 10; -constexpr int64_t NPROBE = 32; -constexpr int64_t SEARCH_TARGET = 5000; // change this value, result is different -constexpr milvus::IndexType INDEX_TYPE = milvus::IndexType::IVFSQ8; -constexpr int32_t N_LIST = 15000; -constexpr int32_t PARTITION_COUNT = 5; -constexpr int32_t TARGET_PARTITION = 3; + constexpr int64_t TABLE_DIMENSION = 512; + constexpr int64_t TABLE_INDEX_FILE_SIZE = 1024; + constexpr milvus::MetricType TABLE_METRIC_TYPE = milvus::MetricType::L2; + constexpr int64_t BATCH_ROW_COUNT = 10000; + constexpr int64_t NQ = 5; + constexpr int64_t TOP_K = 10; + constexpr int64_t NPROBE = 32; + constexpr int64_t SEARCH_TARGET = 5000; // change this value, result is different + constexpr milvus::IndexType INDEX_TYPE = milvus::IndexType::IVFSQ8; + constexpr int32_t N_LIST = 15000; + constexpr int32_t PARTITION_COUNT = 5; + constexpr int32_t TARGET_PARTITION = 3; -milvus::TableSchema -BuildTableSchema() { - milvus::TableSchema tb_schema = {TABLE_NAME, TABLE_DIMENSION, TABLE_INDEX_FILE_SIZE, TABLE_METRIC_TYPE}; - return tb_schema; -} + milvus::TableSchema + BuildTableSchema() { + milvus::TableSchema tb_schema = {TABLE_NAME, TABLE_DIMENSION, TABLE_INDEX_FILE_SIZE, TABLE_METRIC_TYPE}; + return tb_schema; + } -milvus::PartitionParam -BuildPartitionParam(int32_t index) { - std::string tag = std::to_string(index); - std::string partition_name = std::string(TABLE_NAME) + "_" + tag; - milvus::PartitionParam partition_param = {TABLE_NAME, partition_name, tag}; - return partition_param; -} + milvus::PartitionParam + BuildPartitionParam(int32_t index) { + std::string tag = std::to_string(index); + std::string partition_name = std::string(TABLE_NAME) + "_" + tag; + milvus::PartitionParam partition_param = {TABLE_NAME, partition_name, tag}; + return partition_param; + } -milvus::IndexParam -BuildIndexParam() { - milvus::IndexParam index_param = {TABLE_NAME, INDEX_TYPE, N_LIST}; - return index_param; -} + milvus::IndexParam + BuildIndexParam() { + milvus::IndexParam index_param = {TABLE_NAME, INDEX_TYPE, N_LIST}; + return index_param; + } } // namespace void -ClientTest::Test(const std::string& address, const std::string& port) { +ClientTest::Test(const std::string &address, const std::string &port) { std::shared_ptr conn = milvus::Connection::Create(); milvus::Status stat; @@ -78,7 +78,7 @@ ClientTest::Test(const std::string& address, const std::string& port) { stat = conn->Connect(param); std::cout << "Connect function call status: " << stat.message() << std::endl; } - +#ifdef yhz { // create table milvus::TableSchema tb_schema = BuildTableSchema(); stat = conn->CreateTable(tb_schema); @@ -202,4 +202,5 @@ ClientTest::Test(const std::string& address, const std::string& port) { } milvus::Connection::Destroy(conn); +#endif } diff --git a/core/src/sdk/examples/simple/main.cpp b/core/src/sdk/examples/simple/main.cpp index c08741606c..d9b6194329 100644 --- a/core/src/sdk/examples/simple/main.cpp +++ b/core/src/sdk/examples/simple/main.cpp @@ -36,7 +36,7 @@ main(int argc, char* argv[]) { {nullptr, 0, nullptr, 0}}; int option_index = 0; - std::string address = "127.0.0.1", port = "19530"; + std::string address = "192.168.1.89", port = "19530"; app_name = argv[0]; int value; diff --git a/core/src/sdk/examples/simple/src/ClientTest.cpp b/core/src/sdk/examples/simple/src/ClientTest.cpp index dfa5e2219e..da43f2f2aa 100644 --- a/core/src/sdk/examples/simple/src/ClientTest.cpp +++ b/core/src/sdk/examples/simple/src/ClientTest.cpp @@ -68,7 +68,7 @@ ClientTest::Test(const std::string& address, const std::string& port) { stat = conn->Connect(param); std::cout << "Connect function call status: " << stat.message() << std::endl; } - +#ifdef yhz { // server version std::string version = conn->ServerVersion(); std::cout << "Server version: " << version << std::endl; @@ -206,4 +206,5 @@ ClientTest::Test(const std::string& address, const std::string& port) { std::string status = conn->ServerStatus(); std::cout << "Server status after disconnect: " << status << std::endl; } +#endif } diff --git a/core/src/sdk/grpc/ClientProxy.cpp b/core/src/sdk/grpc/ClientProxy.cpp index 4ec94cfa98..5e22904a08 100644 --- a/core/src/sdk/grpc/ClientProxy.cpp +++ b/core/src/sdk/grpc/ClientProxy.cpp @@ -43,16 +43,28 @@ Status ClientProxy::Connect(const ConnectParam& param) { std::string uri = param.ip_address + ":" + param.port; - channel_ = ::grpc::CreateChannel(uri, ::grpc::InsecureChannelCredentials()); - if (channel_ != nullptr) { - connected_ = true; - client_ptr_ = std::make_shared(channel_); - return Status::OK(); - } +// channel_ = ::grpc::CreateChannel(uri, ::grpc::InsecureChannelCredentials()); - std::string reason = "connect failed!"; - connected_ = false; - return Status(StatusCode::NotConnected, reason); +// channel_ = std::make_shared(grpc_insecure_channel_create(uri.c_str(), nullptr, nullptr)); +// channel_ = std::shared_ptr(grpc_insecure_channel_create(uri.c_str(), nullptr, nullptr)); + auto uri_str = uri.c_str(); + grpc_channel * channel = grpc_insecure_channel_create(uri_str, nullptr, nullptr); +// grpc_insecure_channel_create(uri.c_str(), nullptr, nullptr); + auto state = grpc_channel_check_connectivity_state(channel, true); + if (state == GRPC_CHANNEL_READY) { + std::cout << "Connect " << uri << " successfully"; + } else { + std::cout << "Connect " << uri << " failed."; + } +// if (channel_ != nullptr) { +// connected_ = true; +// client_ptr_ = std::make_shared(channel_); +// return Status::OK(); +// } + +// std::string reason = "connect failed!"; +// connected_ = false; +// return Status(StatusCode::NotConnected, reason); } Status diff --git a/core/src/sdk/grpc/ClientProxy.h b/core/src/sdk/grpc/ClientProxy.h index e332266acf..572b782769 100644 --- a/core/src/sdk/grpc/ClientProxy.h +++ b/core/src/sdk/grpc/ClientProxy.h @@ -105,7 +105,7 @@ class ClientProxy : public Connection { DropPartition(const PartitionParam& partition_param) override; private: - std::shared_ptr<::grpc::Channel> channel_; + std::shared_ptr channel_; private: std::shared_ptr client_ptr_; From 99deaf5c503b20b6af135549bc94a2b28c7d9da1 Mon Sep 17 00:00:00 2001 From: yhz <413554850@qq.com> Date: Tue, 19 Nov 2019 17:37:13 +0800 Subject: [PATCH 3/4] modify shards for v0.5.3 --- shards/mishards/connections.py | 10 +++ shards/mishards/service_handler.py | 107 +++++++++++++++++++++-------- 2 files changed, 89 insertions(+), 28 deletions(-) diff --git a/shards/mishards/connections.py b/shards/mishards/connections.py index 618690a099..50e214ec9a 100644 --- a/shards/mishards/connections.py +++ b/shards/mishards/connections.py @@ -2,6 +2,7 @@ import logging import threading from functools import wraps from milvus import Milvus +from milvus.client.hooks import BaseaSearchHook from mishards import (settings, exceptions) from utils import singleton @@ -9,6 +10,12 @@ from utils import singleton logger = logging.getLogger(__name__) +class Searchook(BaseaSearchHook): + + def on_response(self, *args, **kwargs): + return True + + class Connection: def __init__(self, name, uri, max_retry=1, error_handlers=None, **kwargs): self.name = name @@ -18,6 +25,9 @@ class Connection: self.conn = Milvus() self.error_handlers = [] if not error_handlers else error_handlers self.on_retry_func = kwargs.get('on_retry_func', None) + + # define search hook + self.conn._set_hook(search_in_file=Searchook()) # self._connect() def __str__(self): diff --git a/shards/mishards/service_handler.py b/shards/mishards/service_handler.py index 2f19152ae6..620f6213de 100644 --- a/shards/mishards/service_handler.py +++ b/shards/mishards/service_handler.py @@ -29,39 +29,88 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): self.router = router self.max_workers = max_workers + def _reduce(self, source_ids, ids, source_diss, diss, k, reverse): + if source_diss[k - 1] <= diss[0]: + return source_ids, source_diss + if diss[k - 1] <= source_diss[0]: + return ids, diss + + diss_t = enumerate(source_diss.extend(diss)) + diss_m_rst = sorted(diss_t, key=lambda x: x[1])[:k] + diss_m_out = [id_ for _, id_ in diss_m_rst] + + id_t = source_ids.extend(ids) + id_m_out = [id_t[i] for i, _ in diss_m_rst] + + return id_m_out, diss_m_out + def _do_merge(self, files_n_topk_results, topk, reverse=False, **kwargs): status = status_pb2.Status(error_code=status_pb2.SUCCESS, reason="Success") if not files_n_topk_results: return status, [] - request_results = defaultdict(list) + # request_results = defaultdict(list) + # row_num = files_n_topk_results[0].row_num + merge_id_results = [] + merge_dis_results = [] calc_time = time.time() for files_collection in files_n_topk_results: if isinstance(files_collection, tuple): status, _ = files_collection return status, [] - for request_pos, each_request_results in enumerate( - files_collection.topk_query_result): - request_results[request_pos].extend( - each_request_results.query_result_arrays) - request_results[request_pos] = sorted( - request_results[request_pos], - key=lambda x: x.distance, - reverse=reverse)[:topk] + + row_num = files_collection.row_num + ids = files_collection.ids + diss = files_collection.distances # distance collections + batch_len = len(ids) // row_num + + for row_index in range(row_num): + id_batch = ids[row_index * batch_len: (row_index + 1) * batch_len] + dis_batch = diss[row_index * batch_len: (row_index + 1) * batch_len] + + if len(merge_id_results) < row_index: + raise ValueError("merge error") + elif len(merge_id_results) == row_index: + # TODO: may bug here + merge_id_results.append(id_batch) + merge_dis_results.append(dis_batch) + else: + merge_id_results[row_index].extend(ids[row_index * batch_len, (row_index + 1) * batch_len]) + merge_dis_results[row_index].extend(diss[row_index * batch_len, (row_index + 1) * batch_len]) + # _reduce(_ids, _diss, k, reverse) + merge_id_results[row_index], merge_dis_results[row_index] = \ + self._reduce(merge_id_results[row_index], id_batch, + merge_dis_results[row_index], dis_batch, + batch_len, + reverse) + + # for request_pos, each_request_results in enumerate( + # files_collection.topk_query_result): + # request_results[request_pos].extend( + # each_request_results.query_result_arrays) + # request_results[request_pos] = sorted( + # request_results[request_pos], + # key=lambda x: x.distance, + # reverse=reverse)[:topk] calc_time = time.time() - calc_time logger.info('Merge takes {}'.format(calc_time)) - results = sorted(request_results.items()) - topk_query_result = [] + # results = sorted(request_results.items()) + id_mrege_list = [] + dis_mrege_list = [] - for result in results: - query_result = TopKQueryResult(query_result_arrays=result[1]) - topk_query_result.append(query_result) + for id_results, dis_results in zip(merge_id_results, merge_dis_results): + id_mrege_list.extend(id_results) + dis_mrege_list.extend(dis_results) - return status, topk_query_result + # for result in results: + # query_result = TopKQueryResult(query_result_arrays=result[1]) + # topk_query_result.append(query_result) + + return status, id_mrege_list, dis_mrege_list def _do_query(self, context, @@ -109,8 +158,8 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): file_ids=query_params['file_ids'], query_records=vectors, top_k=topk, - nprobe=nprobe, - lazy_=True) + nprobe=nprobe + ) end = time.time() logger.info('search_vectors_in_files takes: {}'.format(end - start)) @@ -241,7 +290,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): logger.info('Search {}: topk={} nprobe={}'.format( table_name, topk, nprobe)) - metadata = {'resp_class': milvus_pb2.TopKQueryResultList} + metadata = {'resp_class': milvus_pb2.TopKQueryResult} if nprobe > self.MAX_NPROBE or nprobe <= 0: raise exceptions.InvalidArgumentError( @@ -275,22 +324,24 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): query_range_array.append( Range(query_range.start_value, query_range.end_value)) - status, results = self._do_query(context, - table_name, - table_meta, - query_record_array, - topk, - nprobe, - query_range_array, - metadata=metadata) + status, id_results, dis_results = self._do_query(context, + table_name, + table_meta, + query_record_array, + topk, + nprobe, + query_range_array, + metadata=metadata) now = time.time() logger.info('SearchVector takes: {}'.format(now - start)) - topk_result_list = milvus_pb2.TopKQueryResultList( + topk_result_list = milvus_pb2.TopKQueryResult( status=status_pb2.Status(error_code=status.error_code, reason=status.reason), - topk_query_result=results) + row_num=len(query_record_array), + ids=id_results, + distances=dis_results) return topk_result_list @mark_grpc_method From 67605968b8977ddb9e29c7f0a2ea8d49a3e4f703 Mon Sep 17 00:00:00 2001 From: yhz <413554850@qq.com> Date: Tue, 19 Nov 2019 20:36:08 +0800 Subject: [PATCH 4/4] finish results reduce in mishards --- shards/mishards/service_handler.py | 28 ++++++---------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/shards/mishards/service_handler.py b/shards/mishards/service_handler.py index 620f6213de..640ae61ba8 100644 --- a/shards/mishards/service_handler.py +++ b/shards/mishards/service_handler.py @@ -34,13 +34,14 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): return source_ids, source_diss if diss[k - 1] <= source_diss[0]: return ids, diss - - diss_t = enumerate(source_diss.extend(diss)) + + source_diss.extend(diss) + diss_t = enumerate(source_diss) diss_m_rst = sorted(diss_t, key=lambda x: x[1])[:k] diss_m_out = [id_ for _, id_ in diss_m_rst] - id_t = source_ids.extend(ids) - id_m_out = [id_t[i] for i, _ in diss_m_rst] + source_ids.extend(ids) + id_m_out = [source_ids[i] for i, _ in diss_m_rst] return id_m_out, diss_m_out @@ -50,8 +51,6 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): if not files_n_topk_results: return status, [] - # request_results = defaultdict(list) - # row_num = files_n_topk_results[0].row_num merge_id_results = [] merge_dis_results = [] @@ -64,6 +63,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): row_num = files_collection.row_num ids = files_collection.ids diss = files_collection.distances # distance collections + # TODO: batch_len is equal to topk batch_len = len(ids) // row_num for row_index in range(row_num): @@ -77,28 +77,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): merge_id_results.append(id_batch) merge_dis_results.append(dis_batch) else: - merge_id_results[row_index].extend(ids[row_index * batch_len, (row_index + 1) * batch_len]) - merge_dis_results[row_index].extend(diss[row_index * batch_len, (row_index + 1) * batch_len]) - # _reduce(_ids, _diss, k, reverse) merge_id_results[row_index], merge_dis_results[row_index] = \ self._reduce(merge_id_results[row_index], id_batch, merge_dis_results[row_index], dis_batch, batch_len, reverse) - # for request_pos, each_request_results in enumerate( - # files_collection.topk_query_result): - # request_results[request_pos].extend( - # each_request_results.query_result_arrays) - # request_results[request_pos] = sorted( - # request_results[request_pos], - # key=lambda x: x.distance, - # reverse=reverse)[:topk] calc_time = time.time() - calc_time logger.info('Merge takes {}'.format(calc_time)) - # results = sorted(request_results.items()) id_mrege_list = [] dis_mrege_list = [] @@ -106,10 +94,6 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): id_mrege_list.extend(id_results) dis_mrege_list.extend(dis_results) - # for result in results: - # query_result = TopKQueryResult(query_result_arrays=result[1]) - # topk_query_result.append(query_result) - return status, id_mrege_list, dis_mrege_list def _do_query(self,