From 0ae4e5265f36b57fead6e5340f3d020e06bbad45 Mon Sep 17 00:00:00 2001 From: liliu-z <105927039+liliu-z@users.noreply.github.com> Date: Mon, 14 Nov 2022 21:05:07 +0800 Subject: [PATCH] Optimize some low efficient code (#20529) Signed-off-by: Li Liu Signed-off-by: Li Liu Co-authored-by: Li Liu --- internal/core/src/segcore/Reduce.cpp | 57 ++++++++++++++++------------ 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/internal/core/src/segcore/Reduce.cpp b/internal/core/src/segcore/Reduce.cpp index e030632de8..ecc52b9dce 100644 --- a/internal/core/src/segcore/Reduce.cpp +++ b/internal/core/src/segcore/Reduce.cpp @@ -76,23 +76,24 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) { AssertInfo(search_result->distances_.size() == nq * topK, "wrong distances size, size = " + std::to_string(search_result->distances_.size()) + ", expected size = " + std::to_string(nq * topK)); - std::vector real_topks(nq); - std::vector distances; - std::vector seg_offsets; - for (auto i = 0; i < nq; i++) { - real_topks[i] = 0; - for (auto j = 0; j < topK; j++) { - auto offset = i * topK + j; - if (search_result->seg_offsets_[offset] != INVALID_SEG_OFFSET) { + std::vector real_topks(nq, 0); + uint32_t valid_index = 0; + auto& offsets = search_result->seg_offsets_; + auto& distances = search_result->distances_; + for (auto i = 0; i < nq; ++i) { + for (auto j = 0; j < topK; ++j) { + auto index = i * topK + j; + if (offsets[index] != INVALID_SEG_OFFSET) { real_topks[i]++; - seg_offsets.push_back(search_result->seg_offsets_[offset]); - distances.push_back(search_result->distances_[offset]); + offsets[valid_index] = offsets[index]; + distances[valid_index] = distances[index]; + valid_index++; } } } + offsets.resize(valid_index); + distances.resize(valid_index); - search_result->distances_.swap(distances); - search_result->seg_offsets_.swap(seg_offsets); search_result->topk_per_nq_prefix_sum_.resize(nq + 1); std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1); } @@ -101,15 +102,16 @@ void ReduceHelper::FillPrimaryKey() { std::vector valid_search_results; // get primary keys for duplicates removal - for (auto search_result : search_results_) { + uint32_t valid_index = 0; + for (auto& search_result : search_results_) { FilterInvalidSearchResult(search_result); if (search_result->get_total_result_count() > 0) { auto segment = static_cast(search_result->segment_); segment->FillPrimaryKeys(plan_, *search_result); - valid_search_results.emplace_back(search_result); + search_results_[valid_index++] = search_result; } } - search_results_.swap(valid_search_results); + search_results_.resize(valid_index); num_segments_ = search_results_.size(); } @@ -119,20 +121,27 @@ ReduceHelper::RefreshSearchResult() { std::vector real_topks(total_nq_, 0); auto search_result = search_results_[i]; if (search_result->result_offsets_.size() != 0) { - std::vector primary_keys; - std::vector distances; - std::vector seg_offsets; + uint32_t size = 0; + for (int j = 0; j < total_nq_; j++) { + size += final_search_records_[i][j].size(); + } + std::vector primary_keys(size); + std::vector distances(size); + std::vector seg_offsets(size); + + uint32_t index = 0; for (int j = 0; j < total_nq_; j++) { for (auto offset : final_search_records_[i][j]) { - primary_keys.push_back(search_result->primary_keys_[offset]); - distances.push_back(search_result->distances_[offset]); - seg_offsets.push_back(search_result->seg_offsets_[offset]); + primary_keys[index] = search_result->primary_keys_[offset]; + distances[index] = search_result->distances_[offset]; + seg_offsets[index] = search_result->seg_offsets_[offset]; + index++; real_topks[j]++; } } - search_result->primary_keys_ = std::move(primary_keys); - search_result->distances_ = std::move(distances); - search_result->seg_offsets_ = std::move(seg_offsets); + search_result->primary_keys_.swap(primary_keys); + search_result->distances_.swap(distances); + search_result->seg_offsets_.swap(seg_offsets); } std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1); }