Optimize performance of reducing segments (#21722)

- Improve the performance of reducing from O(knlogn) to O(nlogk)

Signed-off-by: yah01 <yang.cen@zilliz.com>
This commit is contained in:
yah01 2023-01-16 15:25:42 +08:00 committed by GitHub
parent 5f3d3dc4fc
commit 836773f1a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 219 additions and 211 deletions

View File

@ -17,7 +17,6 @@
#include "Reduce.h"
#include "pkVisitor.h"
#include "SegmentInterface.h"
#include "ReduceStructure.h"
#include "Utils.h"
namespace milvus::segcore {
@ -157,7 +156,12 @@ ReduceHelper::FillEntryData() {
int64_t
ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offset) {
std::vector<SearchResultPair> result_pairs;
while (!heap_.empty()) {
heap_.pop();
}
pk_set_.clear();
pairs_.reserve(num_segments_);
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_[i];
auto offset_beg = search_result->topk_per_nq_prefix_sum_[qi];
@ -167,36 +171,39 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offs
}
auto primary_key = search_result->primary_keys_[offset_beg];
auto distance = search_result->distances_[offset_beg];
result_pairs.emplace_back(primary_key, distance, search_result, i, offset_beg, offset_end);
pairs_.emplace_back(primary_key, distance, search_result, i, offset_beg, offset_end);
heap_.push(&pairs_.back());
}
// nq has no results for all segments
if (result_pairs.size() == 0) {
if (heap_.size() == 0) {
return 0;
}
int64_t dup_cnt = 0;
std::unordered_set<milvus::PkType> pk_set;
int64_t prev_offset = offset;
while (offset - prev_offset < topk) {
std::sort(result_pairs.begin(), result_pairs.end(), std::greater<>());
auto& pilot = result_pairs[0];
auto index = pilot.segment_index_;
auto pk = pilot.primary_key_;
auto start = offset;
while (offset - start < topk) {
auto pilot = heap_.top();
heap_.pop();
auto index = pilot->segment_index_;
auto pk = pilot->primary_key_;
// no valid search result for this nq, break to next
if (pk == INVALID_PK) {
break;
}
// remove duplicates
if (pk_set.count(pk) == 0) {
pilot.search_result_->result_offsets_.push_back(offset++);
final_search_records_[index][qi].push_back(pilot.offset_);
pk_set.insert(pk);
if (pk_set_.count(pk) == 0) {
pilot->search_result_->result_offsets_.push_back(offset++);
final_search_records_[index][qi].push_back(pilot->offset_);
pk_set_.insert(pk);
} else {
// skip entity with same primary key
dup_cnt++;
}
pilot.reset();
pilot->advance();
heap_.push(pilot);
}
return dup_cnt;
}
@ -218,9 +225,9 @@ ReduceHelper::ReduceResultData() {
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
// reduce search results
int64_t result_offset = 0;
int64_t offset = 0;
for (int64_t qi = nq_begin; qi < nq_end; qi++) {
skip_dup_cnt += ReduceSearchResultForOneNQ(qi, slice_topKs_[slice_index], result_offset);
skip_dup_cnt += ReduceSearchResultForOneNQ(qi, slice_topKs_[slice_index], offset);
}
}
if (skip_dup_cnt > 0) {

View File

@ -15,11 +15,13 @@
#include <cstdint>
#include <memory>
#include <vector>
#include <queue>
#include "utils/Status.h"
#include "common/type_c.h"
#include "common/QueryResult.h"
#include "query/PlanImpl.h"
#include "ReduceStructure.h"
namespace milvus::segcore {
@ -95,6 +97,12 @@ class ReduceHelper {
// output
std::unique_ptr<SearchResultDataBlobs> search_result_data_blobs_;
// Used for merge results,
// define these here to avoid allocating them for each query
std::vector<SearchResultPair> pairs_;
std::priority_queue<SearchResultPair*, std::vector<SearchResultPair*>, SearchResultPairComparator> heap_;
std::unordered_set<milvus::PkType> pk_set_;
};
} // namespace milvus::segcore

View File

@ -9,12 +9,13 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <limits>
#include "common/Consts.h"
#include "common/Types.h"
#include "common/QueryResult.h"
#include "segcore/Reduce.h"
using milvus::SearchResult;
@ -38,19 +39,11 @@ struct SearchResultPair {
bool
operator>(const SearchResultPair& other) const {
if (this->primary_key_ == INVALID_PK) {
return false;
} else {
if (other.primary_key_ == INVALID_PK) {
return true;
} else {
return (distance_ > other.distance_);
}
}
return distance_ > other.distance_;
}
void
reset() {
advance() {
offset_++;
if (offset_ < offset_rb_) {
primary_key_ = search_result_->primary_keys_.at(offset_);
@ -61,3 +54,10 @@ struct SearchResultPair {
}
}
};
struct SearchResultPairComparator {
bool
operator()(const SearchResultPair* lhs, const SearchResultPair* rhs) const {
return *lhs > *rhs;
}
};

View File

@ -15,7 +15,6 @@
#include "common/QueryResult.h"
#include "exceptions/EasyAssert.h"
#include "query/Plan.h"
#include "segcore/ReduceStructure.h"
#include "segcore/reduce_c.h"
#include "segcore/Utils.h"

File diff suppressed because it is too large Load Diff

View File

@ -15,19 +15,11 @@
#include "segcore/ReduceStructure.h"
TEST(SearchResultPair, Greater) {
auto pair1 = SearchResultPair(0, 1.0, nullptr, 0, 0, 10);
auto pair1 = SearchResultPair(0, 1.0, nullptr, 0, 0, 1);
auto pair2 = SearchResultPair(1, 2.0, nullptr, 1, 0, 10);
ASSERT_EQ(pair1 > pair2, false);
pair1.primary_key_ = INVALID_PK;
pair2.primary_key_ = 1;
ASSERT_EQ(pair1 > pair2, false);
pair1.primary_key_ = 0;
pair2.primary_key_ = INVALID_PK;
pair1.advance();
ASSERT_EQ(pair1 > pair2, true);
pair1.primary_key_ = INVALID_PK;
pair2.primary_key_ = INVALID_PK;
ASSERT_EQ(pair1 > pair2, false);
ASSERT_EQ(pair1.primary_key_, INVALID_PK);
}