From 7c9b627f2baea3b403263da30573a7e54dd47673 Mon Sep 17 00:00:00 2001 From: groot Date: Wed, 20 Nov 2019 17:54:45 +0800 Subject: [PATCH] #433 C++ SDK query result is not easy to use --- CHANGELOG.md | 1 + .../sdk/examples/partition/src/ClientTest.cpp | 1 + .../sdk/examples/simple/src/ClientTest.cpp | 1 + core/src/sdk/examples/utils/Utils.cpp | 35 +++++++++---------- core/src/sdk/grpc/ClientProxy.cpp | 17 +++++---- core/src/sdk/include/MilvusApi.h | 4 +-- 6 files changed, 33 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 62c82df007..c9a93774cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,7 @@ Please mark all change in change log and use the ticket from JIRA. - \#358 - Add more information in build.sh and install.md - \#404 - Add virtual method Init() in Pass abstract class - \#409 - Add a Fallback pass in optimizer +- \#433 - C++ SDK query result is not easy to use ## Task diff --git a/core/src/sdk/examples/partition/src/ClientTest.cpp b/core/src/sdk/examples/partition/src/ClientTest.cpp index 775e1f6d60..a12a7ff50e 100644 --- a/core/src/sdk/examples/partition/src/ClientTest.cpp +++ b/core/src/sdk/examples/partition/src/ClientTest.cpp @@ -148,6 +148,7 @@ ClientTest::Test(const std::string& address, const std::string& port) { } { // wait unit build index finish + milvus_sdk::TimeRecorder rc("Create index"); std::cout << "Wait until create all index done" << std::endl; milvus::IndexParam index1 = BuildIndexParam(); milvus_sdk::Utils::PrintIndexParam(index1); diff --git a/core/src/sdk/examples/simple/src/ClientTest.cpp b/core/src/sdk/examples/simple/src/ClientTest.cpp index dfa5e2219e..016c9eceac 100644 --- a/core/src/sdk/examples/simple/src/ClientTest.cpp +++ b/core/src/sdk/examples/simple/src/ClientTest.cpp @@ -150,6 +150,7 @@ ClientTest::Test(const std::string& address, const std::string& port) { } { // wait unit build index finish + milvus_sdk::TimeRecorder rc("Create index"); std::cout << "Wait until create all index done" << std::endl; milvus::IndexParam index1 = BuildIndexParam(); milvus_sdk::Utils::PrintIndexParam(index1); diff --git a/core/src/sdk/examples/utils/Utils.cpp b/core/src/sdk/examples/utils/Utils.cpp index da5e854e9b..fa373cd498 100644 --- a/core/src/sdk/examples/utils/Utils.cpp +++ b/core/src/sdk/examples/utils/Utils.cpp @@ -157,18 +157,20 @@ void Utils::PrintSearchResult(const std::vector>& search_record_array, const milvus::TopKQueryResult& topk_query_result) { BLOCK_SPLITER - size_t nq = topk_query_result.row_num; - size_t topk = topk_query_result.ids.size() / nq; - std::cout << "Returned result count: " << nq * topk << std::endl; + std::cout << "Returned result count: " << topk_query_result.size() << std::endl; - int32_t index = 0; - for (size_t i = 0; i < nq; i++) { - auto search_id = search_record_array[index].first; - index++; - std::cout << "No." << index << " vector " << search_id << " top " << topk << " search result:" << std::endl; + if (topk_query_result.size() != search_record_array.size()) { + std::cout << "ERROR: Returned result count dones equal nq" << std::endl; + return; + } + + for (size_t i = 0; i < topk_query_result.size(); i++) { + const milvus::QueryResult& one_result = topk_query_result[i]; + size_t topk = one_result.ids.size(); + auto search_id = search_record_array[i].first; + std::cout << "No." << i << " vector " << search_id << " top " << topk << " search result:" << std::endl; for (size_t j = 0; j < topk; j++) { - size_t idx = i * topk + j; - std::cout << "\t" << topk_query_result.ids[idx] << "\t" << topk_query_result.distances[idx] << std::endl; + std::cout << "\t" << one_result.ids[j] << "\t" << one_result.distances[j] << std::endl; } } BLOCK_SPLITER @@ -178,12 +180,11 @@ void Utils::CheckSearchResult(const std::vector>& search_record_array, const milvus::TopKQueryResult& topk_query_result) { BLOCK_SPLITER - size_t nq = topk_query_result.row_num; - size_t result_k = topk_query_result.ids.size() / nq; - int64_t index = 0; + size_t nq = topk_query_result.size(); for (size_t i = 0; i < nq; i++) { - auto result_id = topk_query_result.ids[i * result_k]; - auto search_id = search_record_array[index++].first; + const milvus::QueryResult& one_result = topk_query_result[i]; + auto search_id = search_record_array[i].first; + int64_t result_id = one_result.ids[0]; if (result_id != search_id) { std::cout << "The top 1 result is wrong: " << result_id << " vs. " << search_id << std::endl; } else { @@ -198,9 +199,7 @@ Utils::DoSearch(std::shared_ptr conn, const std::string& tab const std::vector& partiton_tags, int64_t top_k, int64_t nprobe, const std::vector>& search_record_array, milvus::TopKQueryResult& topk_query_result) { - topk_query_result.distances.clear(); - topk_query_result.ids.clear(); - topk_query_result.row_num = 0; + topk_query_result.clear(); std::vector query_range_array; milvus::Range rg; diff --git a/core/src/sdk/grpc/ClientProxy.cpp b/core/src/sdk/grpc/ClientProxy.cpp index 4a9c319b4d..fd19281343 100644 --- a/core/src/sdk/grpc/ClientProxy.cpp +++ b/core/src/sdk/grpc/ClientProxy.cpp @@ -250,12 +250,17 @@ ClientProxy::Search(const std::string& table_name, const std::vectorSearch(result, search_param); // step 4: convert result array - topk_query_result.row_num = result.row_num(); - topk_query_result.ids.resize(result.ids().size()); - memcpy(topk_query_result.ids.data(), result.ids().data(), result.ids().size() * sizeof(int64_t)); - topk_query_result.distances.resize(result.distances().size()); - memcpy(topk_query_result.distances.data(), result.distances().data(), - result.distances().size() * sizeof(float)); + topk_query_result.reserve(result.row_num()); + int64_t nq = result.row_num(); + int64_t topk = result.ids().size() / nq; + for (int64_t i = 0; i < result.row_num(); i++) { + milvus::QueryResult one_result; + one_result.ids.resize(topk); + one_result.distances.resize(topk); + memcpy(one_result.ids.data(), result.ids().data() + topk * i, topk * sizeof(int64_t)); + memcpy(one_result.distances.data(), result.distances().data() + topk * i, topk * sizeof(float)); + topk_query_result.emplace_back(one_result); + } return status; } catch (std::exception& ex) { diff --git a/core/src/sdk/include/MilvusApi.h b/core/src/sdk/include/MilvusApi.h index 0ec37fa9a4..9fa98deb40 100644 --- a/core/src/sdk/include/MilvusApi.h +++ b/core/src/sdk/include/MilvusApi.h @@ -81,11 +81,11 @@ struct RowRecord { /** * @brief TopK query result */ -struct TopKQueryResult { - int64_t row_num; +struct QueryResult { std::vector ids; std::vector distances; }; +using TopKQueryResult = std::vector; /** * @brief index parameters