From 0ae4e5265f36b57fead6e5340f3d020e06bbad45 Mon Sep 17 00:00:00 2001
From: liliu-z <105927039+liliu-z@users.noreply.github.com>
Date: Mon, 14 Nov 2022 21:05:07 +0800
Subject: [PATCH] Optimize some low efficient code (#20529)
Signed-off-by: Li Liu
Signed-off-by: Li Liu
Co-authored-by: Li Liu
---
internal/core/src/segcore/Reduce.cpp | 57 ++++++++++++++++------------
1 file changed, 33 insertions(+), 24 deletions(-)
diff --git a/internal/core/src/segcore/Reduce.cpp b/internal/core/src/segcore/Reduce.cpp
index e030632de8..ecc52b9dce 100644
--- a/internal/core/src/segcore/Reduce.cpp
+++ b/internal/core/src/segcore/Reduce.cpp
@@ -76,23 +76,24 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
AssertInfo(search_result->distances_.size() == nq * topK,
"wrong distances size, size = " + std::to_string(search_result->distances_.size()) +
", expected size = " + std::to_string(nq * topK));
- std::vector real_topks(nq);
- std::vector distances;
- std::vector seg_offsets;
- for (auto i = 0; i < nq; i++) {
- real_topks[i] = 0;
- for (auto j = 0; j < topK; j++) {
- auto offset = i * topK + j;
- if (search_result->seg_offsets_[offset] != INVALID_SEG_OFFSET) {
+ std::vector real_topks(nq, 0);
+ uint32_t valid_index = 0;
+ auto& offsets = search_result->seg_offsets_;
+ auto& distances = search_result->distances_;
+ for (auto i = 0; i < nq; ++i) {
+ for (auto j = 0; j < topK; ++j) {
+ auto index = i * topK + j;
+ if (offsets[index] != INVALID_SEG_OFFSET) {
real_topks[i]++;
- seg_offsets.push_back(search_result->seg_offsets_[offset]);
- distances.push_back(search_result->distances_[offset]);
+ offsets[valid_index] = offsets[index];
+ distances[valid_index] = distances[index];
+ valid_index++;
}
}
}
+ offsets.resize(valid_index);
+ distances.resize(valid_index);
- search_result->distances_.swap(distances);
- search_result->seg_offsets_.swap(seg_offsets);
search_result->topk_per_nq_prefix_sum_.resize(nq + 1);
std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
}
@@ -101,15 +102,18 @@ void
ReduceHelper::FillPrimaryKey() {
- std::vector valid_search_results;
// get primary keys for duplicates removal
- for (auto search_result : search_results_) {
+ // Compact search_results_ in place: entries whose total result count is still
+ // positive after filtering are moved to the front; valid_index counts them.
+ uint32_t valid_index = 0;
+ for (auto& search_result : search_results_) {
FilterInvalidSearchResult(search_result);
if (search_result->get_total_result_count() > 0) {
auto segment = static_cast(search_result->segment_);
segment->FillPrimaryKeys(plan_, *search_result);
- valid_search_results.emplace_back(search_result);
+ // valid_index never exceeds the current position, so only already-visited slots are overwritten.
+ search_results_[valid_index++] = search_result;
}
}
- search_results_.swap(valid_search_results);
+ search_results_.resize(valid_index);
num_segments_ = search_results_.size();
}
@@ -119,20 +121,27 @@ ReduceHelper::RefreshSearchResult() {
std::vector real_topks(total_nq_, 0);
auto search_result = search_results_[i];
if (search_result->result_offsets_.size() != 0) {
- std::vector primary_keys;
- std::vector distances;
- std::vector seg_offsets;
+ uint32_t size = 0;
+ for (int j = 0; j < total_nq_; j++) {
+ size += final_search_records_[i][j].size();
+ }
+ std::vector primary_keys(size);
+ std::vector distances(size);
+ std::vector seg_offsets(size);
+
+ uint32_t index = 0;
for (int j = 0; j < total_nq_; j++) {
for (auto offset : final_search_records_[i][j]) {
- primary_keys.push_back(search_result->primary_keys_[offset]);
- distances.push_back(search_result->distances_[offset]);
- seg_offsets.push_back(search_result->seg_offsets_[offset]);
+ primary_keys[index] = search_result->primary_keys_[offset];
+ distances[index] = search_result->distances_[offset];
+ seg_offsets[index] = search_result->seg_offsets_[offset];
+ index++;
real_topks[j]++;
}
}
- search_result->primary_keys_ = std::move(primary_keys);
- search_result->distances_ = std::move(distances);
- search_result->seg_offsets_ = std::move(seg_offsets);
+ search_result->primary_keys_.swap(primary_keys);
+ search_result->distances_.swap(distances);
+ search_result->seg_offsets_.swap(seg_offsets);
}
std::partial_sum(real_topks.begin(), real_topks.end(), search_result->topk_per_nq_prefix_sum_.begin() + 1);
}