From bf1f7ce66de85fdcfe713ff08005861f68b2d7f6 Mon Sep 17 00:00:00 2001 From: "yudong.cai" Date: Wed, 9 Oct 2019 19:35:43 +0800 Subject: [PATCH] MS-606 support result reduce parallel Former-commit-id: 8b3b1fdae57214bcfb30bb8365aabd9eabc957cd --- cpp/src/scheduler/task/SearchTask.cpp | 126 ++++++++++++++----- cpp/src/scheduler/task/SearchTask.h | 24 +++- cpp/unittest/db/test_search.cpp | 170 ++++++++++++++++++-------- 3 files changed, 235 insertions(+), 85 deletions(-) diff --git a/cpp/src/scheduler/task/SearchTask.cpp b/cpp/src/scheduler/task/SearchTask.cpp index 9925a8bcf8..b718fca7b6 100644 --- a/cpp/src/scheduler/task/SearchTask.cpp +++ b/cpp/src/scheduler/task/SearchTask.cpp @@ -33,8 +33,6 @@ namespace scheduler { static constexpr size_t PARALLEL_REDUCE_THRESHOLD = 10000; static constexpr size_t PARALLEL_REDUCE_BATCH = 1000; -std::mutex XSearchTask::merge_mutex_; - // TODO(wxyu): remove unused code // bool // NeedParallelReduce(uint64_t nq, uint64_t topk) { @@ -211,7 +209,7 @@ XSearchTask::Execute() { // step 3: pick up topk result auto spec_k = index_engine_->Count() < topk ? 
index_engine_->Count() : topk; - XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult()); + XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult()); span = rc.RecordSection(hdr + ", reduce topk"); // search_job->AccumReduceCost(span); @@ -220,7 +218,7 @@ XSearchTask::Execute() { // search_job->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed } - // step 5: notify to send result to client + // step 4: notify to send result to client search_job->SearchDone(index_id_); } @@ -230,37 +228,42 @@ XSearchTask::Execute() { index_engine_ = nullptr; } -Status -XSearchTask::TopkResult(const std::vector& input_ids, const std::vector& input_distance, - uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result) { - scheduler::ResultSet result_buf; - +void +XSearchTask::MergeTopkToResultSet(const std::vector& input_ids, + const std::vector& input_distance, + uint64_t input_k, + uint64_t nq, + uint64_t topk, + bool ascending, + scheduler::ResultSet& result) { if (result.empty()) { - result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0))); - for (auto i = 0; i < nq; ++i) { - auto& result_buf_i = result_buf[i]; + result.resize(nq); + } + + for (uint64_t i = 0; i < nq; i++) { + scheduler::Id2DistVec result_buf; + auto &result_i = result[i]; + + if (result[i].empty()) { + result_buf.resize(input_k, scheduler::IdDistPair(-1, 0.0)); uint64_t input_k_multi_i = input_k * i; for (auto k = 0; k < input_k; ++k) { uint64_t idx = input_k_multi_i + k; - auto& result_buf_item = result_buf_i[k]; + auto &result_buf_item = result_buf[k]; result_buf_item.first = input_ids[idx]; result_buf_item.second = input_distance[idx]; } - } - } else { - size_t tar_size = result[0].size(); - uint64_t output_k = std::min(topk, input_k + tar_size); - result_buf.resize(nq, scheduler::Id2DistVec(output_k, 
scheduler::IdDistPair(-1, 0.0))); - for (auto i = 0; i < nq; ++i) { + } else { + size_t tar_size = result_i.size(); + uint64_t output_k = std::min(topk, input_k + tar_size); + result_buf.resize(output_k, scheduler::IdDistPair(-1, 0.0)); size_t buf_k = 0, src_k = 0, tar_k = 0; uint64_t src_idx; - auto& result_i = result[i]; - auto& result_buf_i = result_buf[i]; uint64_t input_k_multi_i = input_k * i; while (buf_k < output_k && src_k < input_k && tar_k < tar_size) { src_idx = input_k_multi_i + src_k; - auto& result_buf_item = result_buf_i[buf_k]; - auto& result_item = result_i[tar_k]; + auto &result_buf_item = result_buf[buf_k]; + auto &result_item = result_i[tar_k]; if ((ascending && input_distance[src_idx] < result_item.second) || (!ascending && input_distance[src_idx] > result_item.second)) { result_buf_item.first = input_ids[src_idx]; @@ -273,11 +276,11 @@ XSearchTask::TopkResult(const std::vector& input_ids, const std::vector buf_k++; } - if (buf_k < topk) { + if (buf_k < output_k) { if (src_k < input_k) { while (buf_k < output_k && src_k < input_k) { src_idx = input_k_multi_i + src_k; - auto& result_buf_item = result_buf_i[buf_k]; + auto &result_buf_item = result_buf[buf_k]; result_buf_item.first = input_ids[src_idx]; result_buf_item.second = input_distance[src_idx]; src_k++;
@@ -285,18 +288,102 @@ XSearchTask::TopkResult(const std::vector& input_ids, const std::vector
                     }
                 } else {
                     while (buf_k < output_k && tar_k < tar_size) {
-                        result_buf_i[buf_k] = result_i[tar_k];
+                        result_buf[buf_k] = result_i[tar_k];
                         tar_k++;
                         buf_k++;
                     }
                 }
             }
         }
+
+        result_i.swap(result_buf);
+    }
+}
+
+// Merges the top-k candidates of `src` into `tar`, query by query.
+//
+// Both inputs are flat arrays holding nq consecutive rows: row i of the source
+// starts at src_input_k * i and row i of the target at tar_input_k * i. The
+// merge only compares the heads of the two rows, so each row is presumably
+// already ordered best-first (smallest distance first when `ascending`,
+// largest first otherwise) -- TODO confirm against the callers.
+//
+// On return tar_ids / tar_distance are replaced by buffers of nq * topk slots
+// whose rows have stride output_k = min(topk, tar_input_k + src_input_k), and
+// tar_input_k is set to that stride; slots past nq * output_k keep the fill
+// values (-1 / 0.0). Returns without touching the target when either source
+// array is empty.
+void
+XSearchTask::MergeTopkArray(std::vector& tar_ids,
+                            std::vector& tar_distance,
+                            uint64_t& tar_input_k,
+                            const std::vector& src_ids,
+                            const std::vector& src_distance,
+                            uint64_t src_input_k,
+                            uint64_t nq,
+                            uint64_t topk,
+                            bool ascending) {
+    if (src_ids.empty() || src_distance.empty()) return;
+
+    std::vector id_buf(nq*topk, -1);
+    std::vector dist_buf(nq*topk, 0.0);
+
+    uint64_t output_k = std::min(topk, tar_input_k + src_input_k);
+    uint64_t buf_k, src_k, tar_k;
+    uint64_t src_idx, tar_idx, buf_idx;
+    uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i;
+
+    for (uint64_t i = 0; i < nq; i++) {
+        src_input_k_multi_i = src_input_k * i;
+        tar_input_k_multi_i = tar_input_k * i;
+        buf_k_multi_i = output_k * i;
+        buf_k = src_k = tar_k = 0;
+        // Two-way merge while both rows still have candidates; ties keep the target entry.
+        while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) {
+            src_idx = src_input_k_multi_i + src_k;
+            tar_idx = tar_input_k_multi_i + tar_k;
+            buf_idx = buf_k_multi_i + buf_k;
+            if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) ||
+                (!ascending && src_distance[src_idx] > tar_distance[tar_idx])) {
+                id_buf[buf_idx] = src_ids[src_idx];
+                dist_buf[buf_idx] = src_distance[src_idx];
+                src_k++;
+            } else {
+                id_buf[buf_idx] = tar_ids[tar_idx];
+                dist_buf[buf_idx] = tar_distance[tar_idx];
+                tar_k++;
+            }
+            buf_k++;
+        }
+
+        // One row ran out: drain the other. Every index is recomputed per step;
+        // reusing the values left by the loop above would keep hitting one slot.
+        if (buf_k < output_k) {
+            if (src_k < src_input_k) {
+                while (buf_k < output_k && src_k < src_input_k) {
+                    src_idx = src_input_k_multi_i + src_k;
+                    buf_idx = buf_k_multi_i + buf_k;
+                    id_buf[buf_idx] = src_ids[src_idx];
+                    dist_buf[buf_idx] = src_distance[src_idx];
+                    src_k++;
+                    buf_k++;
+                }
+            } else {
+                while (buf_k < output_k && tar_k < tar_input_k) {
+                    tar_idx = tar_input_k_multi_i + tar_k;
+                    buf_idx = buf_k_multi_i + buf_k;
+                    id_buf[buf_idx] = tar_ids[tar_idx];
+                    dist_buf[buf_idx] = tar_distance[tar_idx];
+                    tar_k++;
+                    buf_k++;
+                }
+            }
+        }
     }
 
-    result.swap(result_buf);
-
-    return Status::OK();
+    tar_ids.swap(id_buf);
+    tar_distance.swap(dist_buf);
+    tar_input_k = output_k;
 }
 
 }  // namespace scheduler
diff --git a/cpp/src/scheduler/task/SearchTask.h b/cpp/src/scheduler/task/SearchTask.h index fd5c8a0d1d..3dd832d1da 100644 --- a/cpp/src/scheduler/task/SearchTask.h +++ b/cpp/src/scheduler/task/SearchTask.h @@ -38,9 +38,25 @@ class XSearchTask : public Task { Execute() override; public: - static Status - TopkResult(const std::vector& input_ids, const std::vector& input_distance, uint64_t input_k, - uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result);
+ static void + MergeTopkToResultSet(const std::vector& input_ids, + const std::vector& input_distance, + uint64_t input_k, + uint64_t nq, + uint64_t topk, + bool ascending, + scheduler::ResultSet& result); + + static void + MergeTopkArray(std::vector& tar_ids, + std::vector& tar_distance, + uint64_t& tar_input_k, + const std::vector& src_ids, + const std::vector& src_distance, + uint64_t src_input_k, + uint64_t nq, + uint64_t topk, + bool ascending); public: TableFileSchemaPtr file_; @@ -49,8 +65,6 @@ class XSearchTask : public Task { int index_type_ = 0; ExecutionEnginePtr index_engine_ = nullptr; bool metric_l2 = true; - - static std::mutex merge_mutex_; }; } // namespace scheduler diff --git a/cpp/unittest/db/test_search.cpp b/cpp/unittest/db/test_search.cpp index e17e06ac16..e858fc4c4f 100644 --- a/cpp/unittest/db/test_search.cpp +++ b/cpp/unittest/db/test_search.cpp @@ -21,6 +21,7 @@ #include "scheduler/task/SearchTask.h" #include "utils/TimeRecorder.h" +#include "utils/ThreadPool.h" namespace { @@ -91,42 +92,35 @@ TEST(DBSearchTest, TOPK_TEST) { bool ascending; std::vector ids1, ids2; std::vector dist1, dist2; - milvus::scheduler::ResultSet result; - milvus::Status status; + ms::ResultSet result; /* test1, id1/dist1 valid, id2/dist2 empty */ ascending = true; BuildResult(NQ, TOP_K, ascending, ids1, dist1); - status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test2, id1/dist1 valid, id2/dist2 valid */ BuildResult(NQ, TOP_K, ascending, ids2, dist2); - status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, 
TOP_K, ascending, result); /* test3, id1/dist1 small topk */ ids1.clear(); dist1.clear(); result.clear(); - BuildResult(NQ, TOP_K / 2, ascending, ids1, dist1); - status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); + BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test4, id1/dist1 small topk, id2/dist2 small topk */ ids2.clear(); dist2.clear(); result.clear(); - BuildResult(NQ, TOP_K / 3, ascending, ids2, dist2); - status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); + BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); ///////////////////////////////////////////////////////////////////////////////////////// @@ -139,36 +133,30 @@ TEST(DBSearchTest, TOPK_TEST) { /* test1, id1/dist1 valid, id2/dist2 empty */ BuildResult(NQ, TOP_K, ascending, ids1, dist1); - status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, 
ascending, result); /* test2, id1/dist1 valid, id2/dist2 valid */ BuildResult(NQ, TOP_K, ascending, ids2, dist2); - status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test3, id1/dist1 small topk */ ids1.clear(); dist1.clear(); result.clear(); - BuildResult(NQ, TOP_K / 2, ascending, ids1, dist1); - status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); + BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); /* test4, id1/dist1 small topk, id2/dist2 small topk */ ids2.clear(); dist2.clear(); result.clear(); - BuildResult(NQ, TOP_K / 3, ascending, ids2, dist2); - status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); - status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result); - ASSERT_TRUE(status.ok()); + BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); } @@ -177,32 +165,112 @@ TEST(DBSearchTest, REDUCE_PERF_TEST) { int32_t top_k = 1000; int32_t index_file_num = 478; /* sift1B dataset, 
index files num */ bool ascending = true; + std::vector> id_vec; + std::vector> dist_vec; + std::vector k_vec; std::vector input_ids; std::vector input_distance; - milvus::scheduler::ResultSet final_result; - milvus::Status status; + ms::ResultSet final_result, final_result_2, final_result_3; - double span, reduce_cost = 0.0; + int32_t i, k, step; + double reduce_cost = 0.0; milvus::TimeRecorder rc(""); - for (int32_t i = 0; i < index_file_num; i++) { + for (i = 0; i < index_file_num; i++) { BuildResult(nq, top_k, ascending, input_ids, input_distance); - - rc.RecordSection("do search for context: " + std::to_string(i)); - - // pick up topk result - status = milvus::scheduler::XSearchTask::TopkResult(input_ids, - input_distance, - top_k, - nq, - top_k, - ascending, - final_result); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(final_result.size(), nq); - - span = rc.RecordSection("reduce topk for context: " + std::to_string(i)); - reduce_cost += span; + id_vec.push_back(input_ids); + dist_vec.push_back(input_distance); + k_vec.push_back(top_k); + } + + rc.RecordSection("Method-1 result reduce start"); + + /* method-1 */ + for (i = 0; i < index_file_num; i++) { + ms::XSearchTask::MergeTopkToResultSet(id_vec[i], dist_vec[i], k_vec[i], nq, top_k, ascending, final_result); + ASSERT_EQ(final_result.size(), nq); + } + + reduce_cost = rc.RecordSection("Method-1 result reduce done"); + std::cout << "Method-1: total reduce time " << reduce_cost/1000 << " ms" << std::endl; + + /* method-2 */ + std::vector> id_vec_2(id_vec); + std::vector> dist_vec_2(dist_vec); + std::vector k_vec_2(k_vec); + + rc.RecordSection("Method-2 result reduce start"); + + for (step = 1; step < index_file_num; step *= 2) { + for (i = 0; i+step < index_file_num; i += step*2) { + ms::XSearchTask::MergeTopkArray(id_vec_2[i], dist_vec_2[i], k_vec_2[i], + id_vec_2[i+step], dist_vec_2[i+step], k_vec_2[i+step], + nq, top_k, ascending); + } + } + ms::XSearchTask::MergeTopkToResultSet(id_vec_2[0], dist_vec_2[0], 
k_vec_2[0], nq, top_k, ascending, final_result_2); + ASSERT_EQ(final_result_2.size(), nq); + + reduce_cost = rc.RecordSection("Method-2 result reduce done"); + std::cout << "Method-2: total reduce time " << reduce_cost/1000 << " ms" << std::endl; + + for (i = 0; i < nq; i++) { + ASSERT_EQ(final_result[i].size(), final_result_2[i].size()); + for (k = 0; k < final_result.size(); k++) { + ASSERT_EQ(final_result[i][k].first, final_result_2[i][k].first); + ASSERT_EQ(final_result[i][k].second, final_result_2[i][k].second); + } + } + + /* method-3 parallel */ + std::vector> id_vec_3(id_vec); + std::vector> dist_vec_3(dist_vec); + std::vector k_vec_3(k_vec); + + uint32_t max_thread_count = std::min(std::thread::hardware_concurrency() - 1, (uint32_t)MAX_THREADS_NUM); + milvus::ThreadPool threadPool(max_thread_count); + std::list> threads_list; + + rc.RecordSection("Method-3 parallel result reduce start"); + + for (step = 1; step < index_file_num; step *= 2) { + for (i = 0; i+step < index_file_num; i += step*2) { + threads_list.push_back( + threadPool.enqueue(ms::XSearchTask::MergeTopkArray, + std::ref(id_vec_3[i]), std::ref(dist_vec_3[i]), std::ref(k_vec_3[i]), + std::ref(id_vec_3[i+step]), std::ref(dist_vec_3[i+step]), std::ref(k_vec_3[i+step]), + nq, top_k, ascending)); + } + + while (threads_list.size() > 0) { + int nready = 0; + for (auto it = threads_list.begin(); it != threads_list.end(); it = it) { + auto &p = *it; + std::chrono::milliseconds span(0); + if (p.wait_for(span) == std::future_status::ready) { + threads_list.erase(it++); + ++nready; + } else { + ++it; + } + } + + if (nready == 0) { + std::this_thread::yield(); + } + } + } + ms::XSearchTask::MergeTopkToResultSet(id_vec_3[0], dist_vec_3[0], k_vec_3[0], nq, top_k, ascending, final_result_3); + ASSERT_EQ(final_result_3.size(), nq); + + reduce_cost = rc.RecordSection("Method-3 parallel result reduce done"); + std::cout << "Method-3 parallel: total reduce time " << reduce_cost/1000 << " ms" << std::endl; + + 
for (i = 0; i < nq; i++) { + ASSERT_EQ(final_result[i].size(), final_result_3[i].size()); + for (k = 0; k < final_result.size(); k++) { + ASSERT_EQ(final_result[i][k].first, final_result_3[i][k].first); + ASSERT_EQ(final_result[i][k].second, final_result_3[i][k].second); + } } - std::cout << "total reduce time: " << reduce_cost / 1000 << " ms" << std::endl; }