From 44b830639a51bf404b8d5c378fe612b7591736b1 Mon Sep 17 00:00:00 2001 From: "yudong.cai" Date: Mon, 14 Oct 2019 18:20:15 +0800 Subject: [PATCH] MS-606 fix result reduce bug Former-commit-id: 86bf350874b7a5ec72e77eea2be1f8a5a8548e19 --- core/src/scheduler/task/SearchTask.cpp | 9 +- core/unittest/db/test_search.cpp | 125 ++++++++++++++----------- 2 files changed, 76 insertions(+), 58 deletions(-) diff --git a/core/src/scheduler/task/SearchTask.cpp b/core/src/scheduler/task/SearchTask.cpp index b5f1599eba..b7a1e211d2 100644 --- a/core/src/scheduler/task/SearchTask.cpp +++ b/core/src/scheduler/task/SearchTask.cpp @@ -315,10 +315,10 @@ XSearchTask::MergeTopkArray(std::vector& tar_ids, std::vector& t return; } - std::vector id_buf(nq * topk, -1); - std::vector dist_buf(nq * topk, 0.0); - uint64_t output_k = std::min(topk, tar_input_k + src_input_k); + std::vector id_buf(nq * output_k, -1); + std::vector dist_buf(nq * output_k, 0.0); + uint64_t buf_k, src_k, tar_k; uint64_t src_idx, tar_idx, buf_idx; uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i; @@ -349,6 +349,7 @@ XSearchTask::MergeTopkArray(std::vector& tar_ids, std::vector& t if (src_k < src_input_k) { while (buf_k < output_k && src_k < src_input_k) { src_idx = src_input_k_multi_i + src_k; + buf_idx = buf_k_multi_i + buf_k; id_buf[buf_idx] = src_ids[src_idx]; dist_buf[buf_idx] = src_distance[src_idx]; src_k++; @@ -356,6 +357,8 @@ XSearchTask::MergeTopkArray(std::vector& tar_ids, std::vector& t } } else { while (buf_k < output_k && tar_k < tar_input_k) { + tar_idx = tar_input_k_multi_i + tar_k; + buf_idx = buf_k_multi_i + buf_k; id_buf[buf_idx] = tar_ids[tar_idx]; dist_buf[buf_idx] = tar_distance[tar_idx]; tar_k++; diff --git a/core/unittest/db/test_search.cpp b/core/unittest/db/test_search.cpp index b0ce9a28b6..dc393b7a26 100644 --- a/core/unittest/db/test_search.cpp +++ b/core/unittest/db/test_search.cpp @@ -110,87 +110,102 @@ CheckTopkResult(const std::vector& input_ids_1, } // namespace -TEST(DBSearchTest, TOPK_TEST) { - uint64_t NQ = 15; - uint64_t TOP_K = 64; - bool ascending; +void MergeTopkToResultSetTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) { std::vector ids1, ids2; std::vector dist1, dist2; ms::ResultSet result; + BuildResult(ids1, dist1, topk_1, nq, ascending); + BuildResult(ids2, dist2, topk_2, nq, ascending); + ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, topk_1, nq, topk, ascending, result); + ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, topk_2, nq, topk, ascending, result); + CheckTopkResult(ids1, dist1, ids2, dist2, topk, nq, ascending, result); +} + +TEST(DBSearchTest, MERGE_RESULT_SET_TEST) { + uint64_t NQ = 15; + uint64_t TOP_K = 64; /* test1, id1/dist1 valid, id2/dist2 empty */ - ascending = true; - BuildResult(ids1, dist1, TOP_K, NQ, ascending); - ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); - CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); + MergeTopkToResultSetTest(TOP_K, 0, NQ, TOP_K, true); + MergeTopkToResultSetTest(TOP_K, 0, NQ, TOP_K, false); /* test2, id1/dist1 valid, id2/dist2 valid */ - BuildResult(ids2, dist2, TOP_K, NQ, ascending); - ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); - CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); + MergeTopkToResultSetTest(TOP_K, TOP_K, NQ, TOP_K, true); + MergeTopkToResultSetTest(TOP_K, TOP_K, NQ, TOP_K, false); /* test3, id1/dist1 small topk */ - ids1.clear(); - dist1.clear(); - result.clear(); - BuildResult(ids1, dist1, TOP_K/2, NQ, ascending); - ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); - ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); - CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); + MergeTopkToResultSetTest(TOP_K/2, TOP_K, NQ, TOP_K, true); + MergeTopkToResultSetTest(TOP_K/2, TOP_K, NQ, TOP_K, false); /* test4, id1/dist1 small topk, id2/dist2 small topk */ - ids2.clear(); - dist2.clear(); - result.clear(); - BuildResult(ids2, dist2, TOP_K/3, NQ, ascending); - ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); - ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); - CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); + MergeTopkToResultSetTest(TOP_K/2, TOP_K/3, NQ, TOP_K, true); + MergeTopkToResultSetTest(TOP_K/2, TOP_K/3, NQ, TOP_K, false); +} -///////////////////////////////////////////////////////////////////////////////////////// - ascending = false; - ids1.clear(); - dist1.clear(); - ids2.clear(); - dist2.clear(); - result.clear(); +void MergeTopkArrayTest(uint64_t topk_1, uint64_t topk_2, uint64_t nq, uint64_t topk, bool ascending) { + std::vector ids1, ids2; + std::vector dist1, dist2; + ms::ResultSet result; + BuildResult(ids1, dist1, topk_1, nq, ascending); + BuildResult(ids2, dist2, topk_2, nq, ascending); + uint64_t result_topk = std::min(topk, topk_1 + topk_2); + ms::XSearchTask::MergeTopkArray(ids1, dist1, topk_1, ids2, dist2, topk_2, nq, topk, ascending); + if (ids1.size() != result_topk * nq) { + std::cout << ids1.size() << " " << result_topk * nq << std::endl; + } + ASSERT_TRUE(ids1.size() == result_topk * nq); + ASSERT_TRUE(dist1.size() == result_topk * nq); + for (uint64_t i = 0; i < nq; i++) { + for (uint64_t k = 1; k < result_topk; k++) { + if (ascending) { + if (dist1[i * result_topk + k] < dist1[i * result_topk + k - 1]) { + std::cout << dist1[i * result_topk + k - 1] << " " << dist1[i * result_topk + k] << std::endl; + } + ASSERT_TRUE(dist1[i * result_topk + k] >= dist1[i * result_topk + k - 1]); + } else { + if (dist1[i * result_topk + k] > dist1[i * result_topk + k - 1]) { + std::cout << dist1[i * result_topk + k - 1] << " " << dist1[i * result_topk + k] << std::endl; + } + ASSERT_TRUE(dist1[i * result_topk + k] <= dist1[i * result_topk + k - 1]); + } + } + } +} + +TEST(DBSearchTest, MERGE_ARRAY_TEST) { + uint64_t NQ = 15; + uint64_t TOP_K = 64; /* test1, id1/dist1 valid, id2/dist2 empty */ - BuildResult(ids1, dist1, TOP_K, NQ, ascending); - ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); - CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); + MergeTopkArrayTest(TOP_K, 0, NQ, TOP_K, true); + MergeTopkArrayTest(TOP_K, 0, NQ, TOP_K, false); + MergeTopkArrayTest(0, TOP_K, NQ, TOP_K, true); + MergeTopkArrayTest(0, TOP_K, NQ, TOP_K, false); /* test2, id1/dist1 valid, id2/dist2 valid */ - BuildResult(ids2, dist2, TOP_K, NQ, ascending); - ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); - CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); + MergeTopkArrayTest(TOP_K, TOP_K, NQ, TOP_K, true); + MergeTopkArrayTest(TOP_K, TOP_K, NQ, TOP_K, false); /* test3, id1/dist1 small topk */ - ids1.clear(); - dist1.clear(); - result.clear(); - BuildResult(ids1, dist1, TOP_K/2, NQ, ascending); - ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); - ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); - CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); + MergeTopkArrayTest(TOP_K/2, TOP_K, NQ, TOP_K, true); + MergeTopkArrayTest(TOP_K/2, TOP_K, NQ, TOP_K, false); + MergeTopkArrayTest(TOP_K, TOP_K/2, NQ, TOP_K, true); + MergeTopkArrayTest(TOP_K, TOP_K/2, NQ, TOP_K, false); /* test4, id1/dist1 small topk, id2/dist2 small topk */ - ids2.clear(); - dist2.clear(); - result.clear(); - BuildResult(ids2, dist2, TOP_K/3, NQ, ascending); - ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); - ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); - CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result); + MergeTopkArrayTest(TOP_K/2, TOP_K/3, NQ, TOP_K, true); + MergeTopkArrayTest(TOP_K/2, TOP_K/3, NQ, TOP_K, false); + MergeTopkArrayTest(TOP_K/3, TOP_K/2, NQ, TOP_K, true); + MergeTopkArrayTest(TOP_K/3, TOP_K/2, NQ, TOP_K, false); } TEST(DBSearchTest, REDUCE_PERF_TEST) { int32_t index_file_num = 478; /* sift1B dataset, index files num */ bool ascending = true; - std::vector thread_vec = {4, 8, 11}; - std::vector nq_vec = {1, 10, 100, 1000}; - std::vector topk_vec = {1, 4, 16, 64, 256, 1024}; + std::vector thread_vec = {4, 8}; + std::vector nq_vec = {1, 10, 100}; + std::vector topk_vec = {1, 4, 16, 64}; int32_t NQ = nq_vec[nq_vec.size()-1]; int32_t TOPK = topk_vec[topk_vec.size()-1];