diff --git a/internal/core/unittest/test_reduce.cpp b/internal/core/unittest/test_reduce.cpp index 01fdd17b62..1dfef70e82 100644 --- a/internal/core/unittest/test_reduce.cpp +++ b/internal/core/unittest/test_reduce.cpp @@ -24,7 +24,7 @@ using SubSearchResultUniq = std::unique_ptr; std::default_random_engine e(42); -std::unique_ptr +SubSearchResultUniq GenSubSearchResult(const int64_t nq, const int64_t topk, const knowhere::MetricType &metric_type, @@ -34,8 +34,8 @@ GenSubSearchResult(const int64_t nq, SubSearchResultUniq sub_result = std::make_unique(nq, topk, metric_type, round_decimal); std::vector ids; std::vector distances; - for (int n = 0; n < nq; ++n) { - for (int k = 0; k < topk; ++k) { + for (auto n = 0; n < nq; ++n) { + for (auto k = 0; k < topk; ++k) { auto gen_x = e() % limit; ids.push_back(gen_x); distances.push_back(gen_x); @@ -57,7 +57,7 @@ template void CheckSubSearchResult(const int64_t nq, const int64_t topk, - SubSearchResult& search_result, + SubSearchResult& result, std::vector& result_ref) { ASSERT_EQ(result_ref.size(), nq); for (int n = 0; n < nq; ++n) { @@ -66,8 +66,8 @@ CheckSubSearchResult(const int64_t nq, auto ref_x = result_ref[n].top(); result_ref[n].pop(); auto index = n * topk + topk - 1 - k; - auto id = search_result.get_seg_offsets()[index]; - auto distance = search_result.get_distances()[index]; + auto id = result.get_seg_offsets()[index]; + auto distance = result.get_distances()[index]; ASSERT_EQ(id, ref_x); ASSERT_EQ(distance, ref_x); } @@ -76,19 +76,19 @@ CheckSubSearchResult(const int64_t nq, template void -TestSubSearchResultMerge(const knowhere::MetricType& metric_type) { - int64_t num_queries = 16; - int64_t topk = 10; - int64_t iteration = 10; - int64_t round_decimal = 3; +TestSubSearchResultMerge(const knowhere::MetricType& metric_type, + const int64_t iteration, + const int64_t nq, + const int64_t topk) { + const int64_t round_decimal = 3; - std::vector result_ref(num_queries); + std::vector result_ref(nq); - SubSearchResult final_result(num_queries, topk, metric_type, round_decimal); + SubSearchResult final_result(nq, topk, metric_type, round_decimal); for (int i = 0; i < iteration; ++i) { - SubSearchResultUniq sub_result = GenSubSearchResult(num_queries, topk, metric_type, round_decimal); + SubSearchResultUniq sub_result = GenSubSearchResult(nq, topk, metric_type, round_decimal); auto ids = sub_result->get_ids(); - for (int n = 0; n < num_queries; ++n) { + for (int n = 0; n < nq; ++n) { for (int k = 0; k < topk; ++k) { int64_t x = ids[n * topk + k]; result_ref[n].push(x); @@ -99,12 +99,28 @@ TestSubSearchResultMerge(const knowhere::MetricType& metric_type) { } final_result.merge(*sub_result); } - CheckSubSearchResult(num_queries, topk, final_result, result_ref); + CheckSubSearchResult(nq, topk, final_result, result_ref); } TEST(Reduce, SubSearchResult) { using queue_type_l2 = std::priority_queue, std::less>; using queue_type_ip = std::priority_queue, std::greater>; - TestSubSearchResultMerge(knowhere::metric::L2); - TestSubSearchResultMerge(knowhere::metric::IP); + + TestSubSearchResultMerge(knowhere::metric::L2, 1, 1, 1); + TestSubSearchResultMerge(knowhere::metric::L2, 1, 1, 10); + TestSubSearchResultMerge(knowhere::metric::L2, 1, 16, 1); + TestSubSearchResultMerge(knowhere::metric::L2, 1, 16, 10); + TestSubSearchResultMerge(knowhere::metric::L2, 4, 1, 1); + TestSubSearchResultMerge(knowhere::metric::L2, 4, 1, 10); + TestSubSearchResultMerge(knowhere::metric::L2, 4, 16, 1); + TestSubSearchResultMerge(knowhere::metric::L2, 4, 16, 10); + + TestSubSearchResultMerge(knowhere::metric::IP, 1, 1, 1); + TestSubSearchResultMerge(knowhere::metric::IP, 1, 1, 10); + TestSubSearchResultMerge(knowhere::metric::IP, 1, 16, 1); + TestSubSearchResultMerge(knowhere::metric::IP, 1, 16, 10); + TestSubSearchResultMerge(knowhere::metric::IP, 4, 1, 1); + TestSubSearchResultMerge(knowhere::metric::IP, 4, 1, 10); + TestSubSearchResultMerge(knowhere::metric::IP, 4, 16, 1); + TestSubSearchResultMerge(knowhere::metric::IP, 4, 16, 10); } \ No newline at end of file