diff --git a/cpp/src/scheduler/task/SearchTask.cpp b/cpp/src/scheduler/task/SearchTask.cpp index b7e27e4944..477a9959c3 100644 --- a/cpp/src/scheduler/task/SearchTask.cpp +++ b/cpp/src/scheduler/task/SearchTask.cpp @@ -34,8 +34,6 @@ namespace scheduler { static constexpr size_t PARALLEL_REDUCE_THRESHOLD = 10000; static constexpr size_t PARALLEL_REDUCE_BATCH = 1000; -std::mutex XSearchTask::merge_mutex_; - // TODO(wxyu): remove unused code // bool // NeedParallelReduce(uint64_t nq, uint64_t topk) { @@ -162,8 +160,8 @@ XSearchTask::Load(LoadType type, uint8_t device_id) { size_t file_size = index_engine_->PhysicalSize(); - std::string info = "Load file id:" + std::to_string(file_->id_) + - " file type:" + std::to_string(file_->file_type_) + " size:" + std::to_string(file_size) + + std::string info = "Load file id:" + std::to_string(file_->id_) + " file type:" + + std::to_string(file_->file_type_) + " size:" + std::to_string(file_size) + " bytes from location: " + file_->location_ + " totally cost"; double span = rc.ElapseFromBegin(info); // for (auto &context : search_contexts_) { @@ -221,7 +219,8 @@ XSearchTask::Execute() { // step 3: pick up topk result auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk; - XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult()); + XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2, + search_job->GetResult()); span = rc.RecordSection(hdr + ", reduce topk"); // search_job->AccumReduceCost(span); @@ -230,7 +229,7 @@ XSearchTask::Execute() { // search_job->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed } - // step 5: notify to send result to client + // step 4: notify to send result to client search_job->SearchDone(index_id_); } @@ -240,36 +239,37 @@ XSearchTask::Execute() { index_engine_ = nullptr; } -Status -XSearchTask::TopkResult(const std::vector& input_ids, const std::vector& input_distance, - uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result) { - scheduler::ResultSet result_buf; - +void +XSearchTask::MergeTopkToResultSet(const std::vector& input_ids, const std::vector& input_distance, + uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, + scheduler::ResultSet& result) { if (result.empty()) { - result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0))); - for (auto i = 0; i < nq; ++i) { - auto& result_buf_i = result_buf[i]; + result.resize(nq); + } + + for (uint64_t i = 0; i < nq; i++) { + scheduler::Id2DistVec result_buf; + auto& result_i = result[i]; + + if (result[i].empty()) { + result_buf.resize(input_k, scheduler::IdDistPair(-1, 0.0)); uint64_t input_k_multi_i = input_k * i; for (auto k = 0; k < input_k; ++k) { uint64_t idx = input_k_multi_i + k; - auto& result_buf_item = result_buf_i[k]; + auto& result_buf_item = result_buf[k]; result_buf_item.first = input_ids[idx]; result_buf_item.second = input_distance[idx]; } - } - } else { - size_t tar_size = result[0].size(); - uint64_t output_k = std::min(topk, input_k + tar_size); - result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0))); - for (auto i = 0; i < nq; ++i) { + } else { + size_t tar_size = result_i.size(); + uint64_t output_k = std::min(topk, input_k + tar_size); + result_buf.resize(output_k, scheduler::IdDistPair(-1, 0.0)); size_t buf_k = 0, src_k = 0, tar_k = 0; uint64_t src_idx; - auto& result_i = result[i]; - auto& result_buf_i = result_buf[i]; uint64_t input_k_multi_i = input_k * i; while (buf_k < output_k && src_k < input_k && tar_k < tar_size) { src_idx = input_k_multi_i + src_k; - auto& result_buf_item = result_buf_i[buf_k]; + auto& result_buf_item = result_buf[buf_k]; auto& result_item = result_i[tar_k]; if ((ascending && input_distance[src_idx] < result_item.second) || (!ascending && input_distance[src_idx] > result_item.second)) { @@ -283,11 +283,11 @@ XSearchTask::TopkResult(const std::vector& input_ids, const std::vector buf_k++; } - if (buf_k < topk) { + if (buf_k < output_k) { if (src_k < input_k) { while (buf_k < output_k && src_k < input_k) { src_idx = input_k_multi_i + src_k; - auto& result_buf_item = result_buf_i[buf_k]; + auto& result_buf_item = result_buf[buf_k]; result_buf_item.first = input_ids[src_idx]; result_buf_item.second = input_distance[src_idx]; src_k++; @@ -295,18 +295,79 @@ XSearchTask::TopkResult(const std::vector& input_ids, const std::vector } } else { while (buf_k < output_k && tar_k < tar_size) { - result_buf_i[buf_k] = result_i[tar_k]; + result_buf[buf_k] = result_i[tar_k]; tar_k++; buf_k++; } } } } + + result_i.swap(result_buf); + } +} + +void +XSearchTask::MergeTopkArray(std::vector& tar_ids, std::vector& tar_distance, uint64_t& tar_input_k, + const std::vector& src_ids, const std::vector& src_distance, + uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) { + if (src_ids.empty() || src_distance.empty()) { + return; } - result.swap(result_buf); + std::vector id_buf(nq * topk, -1); + std::vector dist_buf(nq * topk, 0.0); - return Status::OK(); + uint64_t output_k = std::min(topk, tar_input_k + src_input_k); + uint64_t buf_k, src_k, tar_k; + uint64_t src_idx, tar_idx, buf_idx; + uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i; + + for (uint64_t i = 0; i < nq; i++) { + src_input_k_multi_i = src_input_k * i; + tar_input_k_multi_i = tar_input_k * i; + buf_k_multi_i = output_k * i; + buf_k = src_k = tar_k = 0; + while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) { + src_idx = src_input_k_multi_i + src_k; + tar_idx = tar_input_k_multi_i + tar_k; + buf_idx = buf_k_multi_i + buf_k; + if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) || + (!ascending && src_distance[src_idx] > tar_distance[tar_idx])) { + id_buf[buf_idx] = src_ids[src_idx]; + dist_buf[buf_idx] = src_distance[src_idx]; + src_k++; + } else { + id_buf[buf_idx] = tar_ids[tar_idx]; + dist_buf[buf_idx] = tar_distance[tar_idx]; + tar_k++; + } + buf_k++; + } + + if (buf_k < output_k) { + if (src_k < src_input_k) { + while (buf_k < output_k && src_k < src_input_k) { + src_idx = src_input_k_multi_i + src_k; + id_buf[buf_idx] = src_ids[src_idx]; + dist_buf[buf_idx] = src_distance[src_idx]; + src_k++; + buf_k++; + } + } else { + while (buf_k < output_k && tar_k < tar_input_k) { + id_buf[buf_idx] = tar_ids[tar_idx]; + dist_buf[buf_idx] = tar_distance[tar_idx]; + tar_k++; + buf_k++; + } + } + } + } + + tar_ids.swap(id_buf); + tar_distance.swap(dist_buf); + tar_input_k = output_k; } } // namespace scheduler diff --git a/cpp/src/scheduler/task/SearchTask.h b/cpp/src/scheduler/task/SearchTask.h index fd5c8a0d1d..6a7381e0e6 100644 --- a/cpp/src/scheduler/task/SearchTask.h +++ b/cpp/src/scheduler/task/SearchTask.h @@ -38,9 +38,14 @@ class XSearchTask : public Task { Execute() override; public: - static Status - TopkResult(const std::vector& input_ids, const std::vector& input_distance, uint64_t input_k, - uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); + static void + MergeTopkToResultSet(const std::vector& input_ids, const std::vector& input_distance, + uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result); + + static void + MergeTopkArray(std::vector& tar_ids, std::vector& tar_distance, uint64_t& tar_input_k, + const std::vector& src_ids, const std::vector& src_distance, uint64_t src_input_k, + uint64_t nq, uint64_t topk, bool ascending); public: TableFileSchemaPtr file_; @@ -49,8 +54,6 @@ class XSearchTask : public Task { int index_type_ = 0; ExecutionEnginePtr index_engine_ = nullptr; bool metric_l2 = true; - - static std::mutex merge_mutex_; }; } // namespace scheduler diff --git a/cpp/unittest/db/test_search.cpp b/cpp/unittest/db/test_search.cpp index e17e06ac16..b0ce9a28b6 100644 --- a/cpp/unittest/db/test_search.cpp +++ b/cpp/unittest/db/test_search.cpp @@ -21,26 +21,51 @@ #include "scheduler/task/SearchTask.h" #include "utils/TimeRecorder.h" +#include "utils/ThreadPool.h" namespace { namespace ms = milvus::scheduler; void -BuildResult(uint64_t nq, +BuildResult(std::vector& output_ids, + std::vector& output_distance, uint64_t topk, - bool ascending, - std::vector& output_ids, - std::vector& output_distence) { + uint64_t nq, + bool ascending) { output_ids.clear(); output_ids.resize(nq * topk); - output_distence.clear(); - output_distence.resize(nq * topk); + output_distance.clear(); + output_distance.resize(nq * topk); for (uint64_t i = 0; i < nq; i++) { for (uint64_t j = 0; j < topk; j++) { output_ids[i * topk + j] = (int64_t)(drand48() * 100000); - output_distence[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48()); + output_distance[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48()); + } + } +} + +void +CopyResult(std::vector& output_ids, + std::vector& output_distance, + uint64_t output_topk, + std::vector& input_ids, + std::vector& input_distance, + uint64_t input_topk, + uint64_t nq) { + ASSERT_TRUE(input_ids.size() >= nq * input_topk); + ASSERT_TRUE(input_distance.size() >= nq * input_topk); + ASSERT_TRUE(output_topk <= input_topk); + output_ids.clear(); + output_ids.resize(nq * output_topk); + output_distance.clear(); + output_distance.resize(nq * output_topk); + + for (uint64_t i = 0; i < nq; i++) { + for (uint64_t j = 0; j < output_topk; j++) { + output_ids[i * output_topk + j] = input_ids[i * input_topk + j]; + output_distance[i * output_topk + j] = input_distance[i * input_topk + j]; } } } @@ -50,8 +75,8 @@ CheckTopkResult(const std::vector& input_ids_1, const std::vector& input_distance_1, const std::vector& input_ids_2, const std::vector& input_distance_2, - uint64_t nq, uint64_t topk, + uint64_t nq, bool ascending, const milvus::scheduler::ResultSet& result) { ASSERT_EQ(result.size(), nq); @@ -91,43 +116,36 @@ TEST(DBSearchTest, TOPK_TEST) { bool ascending; std::vector ids1, ids2; std::vector dist1, dist2; - milvus::scheduler::ResultSet result; - milvus::Status status; + ms::ResultSet result; /* test1, id1/dist1 valid, id2/dist2 empty */ ascending = true; - BuildResult(NQ, TOP_K, ascending, ids1, dist1); - status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + BuildResult(ids1, dist1, TOP_K, NQ, ascending); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); + CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); /* test2, id1/dist1 valid, id2/dist2 valid */ - BuildResult(NQ, TOP_K, ascending, ids2, dist2); - status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + BuildResult(ids2, dist2, TOP_K, NQ, ascending); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); /* test3, id1/dist1 small topk */ ids1.clear(); dist1.clear(); result.clear(); - BuildResult(NQ, TOP_K / 2, ascending, ids1, dist1); - status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + BuildResult(ids1, dist1, TOP_K/2, NQ, ascending); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); /* test4, id1/dist1 small topk, id2/dist2 small topk */ ids2.clear(); dist2.clear(); result.clear(); - BuildResult(NQ, TOP_K / 3, ascending, ids2, dist2); - status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + BuildResult(ids2, dist2, TOP_K/3, NQ, ascending); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); + CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); ///////////////////////////////////////////////////////////////////////////////////////// ascending = false; @@ -138,71 +156,199 @@ TEST(DBSearchTest, TOPK_TEST) { result.clear(); /* test1, id1/dist1 valid, id2/dist2 empty */ - BuildResult(NQ, TOP_K, ascending, ids1, dist1); - status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + BuildResult(ids1, dist1, TOP_K, NQ, ascending); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); + CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); /* test2, id1/dist1 valid, id2/dist2 valid */ - BuildResult(NQ, TOP_K, ascending, ids2, dist2); - status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + BuildResult(ids2, dist2, TOP_K, NQ, ascending); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); /* test3, id1/dist1 small topk */ ids1.clear(); dist1.clear(); result.clear(); - BuildResult(NQ, TOP_K / 2, ascending, ids1, dist1); - status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + BuildResult(ids1, dist1, TOP_K/2, NQ, ascending); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); /* test4, id1/dist1 small topk, id2/dist2 small topk */ ids2.clear(); dist2.clear(); result.clear(); - BuildResult(NQ, TOP_K / 3, ascending, ids2, dist2); - status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + BuildResult(ids2, dist2, TOP_K/3, NQ, ascending); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); + CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); } TEST(DBSearchTest, REDUCE_PERF_TEST) { - int32_t nq = 100; - int32_t top_k = 1000; int32_t index_file_num = 478; /* sift1B dataset, index files num */ bool ascending = true; + + std::vector thread_vec = {4, 8, 11}; + std::vector nq_vec = {1, 10, 100, 1000}; + std::vector topk_vec = {1, 4, 16, 64, 256, 1024}; + int32_t NQ = nq_vec[nq_vec.size()-1]; + int32_t TOPK = topk_vec[topk_vec.size()-1]; + + std::vector> id_vec; + std::vector> dist_vec; std::vector input_ids; std::vector input_distance; - milvus::scheduler::ResultSet final_result; - milvus::Status status; + int32_t i, k, step; - double span, reduce_cost = 0.0; - milvus::TimeRecorder rc(""); - - for (int32_t i = 0; i < index_file_num; i++) { - BuildResult(nq, top_k, ascending, input_ids, input_distance); - - rc.RecordSection("do search for context: " + std::to_string(i)); - - // pick up topk result - status = milvus::scheduler::XSearchTask::TopkResult(input_ids, - input_distance, - top_k, - nq, - top_k, - ascending, - final_result); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(final_result.size(), nq); - - span = rc.RecordSection("reduce topk for context: " + std::to_string(i)); - reduce_cost += span; + /* generate testing data */ + for (i = 0; i < index_file_num; i++) { + BuildResult(input_ids, input_distance, TOPK, NQ, ascending); + id_vec.push_back(input_ids); + dist_vec.push_back(input_distance); + } + + for (int32_t max_thread_num : thread_vec) { + milvus::ThreadPool threadPool(max_thread_num); + std::list> threads_list; + + for (int32_t nq : nq_vec) { + for (int32_t top_k : topk_vec) { + ms::ResultSet final_result, final_result_2, final_result_3; + + std::vector> id_vec_1(index_file_num); + std::vector> dist_vec_1(index_file_num); + for (i = 0; i < index_file_num; i++) { + CopyResult(id_vec_1[i], dist_vec_1[i], top_k, id_vec[i], dist_vec[i], TOPK, nq); + } + + std::string str1 = "Method-1 " + std::to_string(max_thread_num) + " " + + std::to_string(nq) + " " + std::to_string(top_k); + milvus::TimeRecorder rc1(str1); + + /////////////////////////////////////////////////////////////////////////////////////// + /* method-1 */ + for (i = 0; i < index_file_num; i++) { + ms::XSearchTask::MergeTopkToResultSet(id_vec_1[i], + dist_vec_1[i], + top_k, + nq, + top_k, + ascending, + final_result); + ASSERT_EQ(final_result.size(), nq); + } + + rc1.RecordSection("reduce done"); + + /////////////////////////////////////////////////////////////////////////////////////// + /* method-2 */ + std::vector> id_vec_2(index_file_num); + std::vector> dist_vec_2(index_file_num); + std::vector k_vec_2(index_file_num); + for (i = 0; i < index_file_num; i++) { + CopyResult(id_vec_2[i], dist_vec_2[i], top_k, id_vec[i], dist_vec[i], TOPK, nq); + k_vec_2[i] = top_k; + } + + std::string str2 = "Method-2 " + std::to_string(max_thread_num) + " " + + std::to_string(nq) + " " + std::to_string(top_k); + milvus::TimeRecorder rc2(str2); + + for (step = 1; step < index_file_num; step *= 2) { + for (i = 0; i + step < index_file_num; i += step * 2) { + ms::XSearchTask::MergeTopkArray(id_vec_2[i], dist_vec_2[i], k_vec_2[i], + id_vec_2[i + step], dist_vec_2[i + step], k_vec_2[i + step], + nq, top_k, ascending); + } + } + ms::XSearchTask::MergeTopkToResultSet(id_vec_2[0], + dist_vec_2[0], + k_vec_2[0], + nq, + top_k, + ascending, + final_result_2); + ASSERT_EQ(final_result_2.size(), nq); + + rc2.RecordSection("reduce done"); + + for (i = 0; i < nq; i++) { + ASSERT_EQ(final_result[i].size(), final_result_2[i].size()); + for (k = 0; k < final_result[i].size(); k++) { + if (final_result[i][k].first != final_result_2[i][k].first) { + std::cout << i << " " << k << std::endl; + } + ASSERT_EQ(final_result[i][k].first, final_result_2[i][k].first); + ASSERT_EQ(final_result[i][k].second, final_result_2[i][k].second); + } + } + + /////////////////////////////////////////////////////////////////////////////////////// + /* method-3 parallel */ + std::vector> id_vec_3(index_file_num); + std::vector> dist_vec_3(index_file_num); + std::vector k_vec_3(index_file_num); + for (i = 0; i < index_file_num; i++) { + CopyResult(id_vec_3[i], dist_vec_3[i], top_k, id_vec[i], dist_vec[i], TOPK, nq); + k_vec_3[i] = top_k; + } + + std::string str3 = "Method-3 " + std::to_string(max_thread_num) + " " + + std::to_string(nq) + " " + std::to_string(top_k); + milvus::TimeRecorder rc3(str3); + + for (step = 1; step < index_file_num; step *= 2) { + for (i = 0; i + step < index_file_num; i += step * 2) { + threads_list.push_back( + threadPool.enqueue(ms::XSearchTask::MergeTopkArray, + std::ref(id_vec_3[i]), + std::ref(dist_vec_3[i]), + std::ref(k_vec_3[i]), + std::ref(id_vec_3[i + step]), + std::ref(dist_vec_3[i + step]), + std::ref(k_vec_3[i + step]), + nq, + top_k, + ascending)); + } + + while (threads_list.size() > 0) { + int nready = 0; + for (auto it = threads_list.begin(); it != threads_list.end(); it = it) { + auto &p = *it; + std::chrono::milliseconds span(0); + if (p.wait_for(span) == std::future_status::ready) { + threads_list.erase(it++); + ++nready; + } else { + ++it; + } + } + + if (nready == 0) { + std::this_thread::yield(); + } + } + } + ms::XSearchTask::MergeTopkToResultSet(id_vec_3[0], + dist_vec_3[0], + k_vec_3[0], + nq, + top_k, + ascending, + final_result_3); + ASSERT_EQ(final_result_3.size(), nq); + + rc3.RecordSection("reduce done"); + + for (i = 0; i < nq; i++) { + ASSERT_EQ(final_result[i].size(), final_result_3[i].size()); + for (k = 0; k < final_result[i].size(); k++) { + ASSERT_EQ(final_result[i][k].first, final_result_3[i][k].first); + ASSERT_EQ(final_result[i][k].second, final_result_3[i][k].second); + } + } + } + } } - std::cout << "total reduce time: " << reduce_cost / 1000 << " ms" << std::endl; }