yah01 836773f1a0
Optimize performance of reducing segments (#21722)
- Improve the performance of reducing results across segments from O(k*n*log n) to O(n*log k)

Signed-off-by: yah01 <yang.cen@zilliz.com>
2023-01-16 15:25:42 +08:00

109 lines
2.9 KiB
C++

// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <algorithm>
#include <cstdint>
#include <memory>
#include <queue>
#include <unordered_set>
#include <vector>

#include "utils/Status.h"
#include "common/type_c.h"
#include "common/QueryResult.h"
#include "query/PlanImpl.h"
#include "ReduceStructure.h"
namespace milvus::segcore {
// Container for the marshaled (serialized) blobs of several
// `milvus::proto::schema::SearchResultData` messages; every element of
// `blobs` is one independent byte buffer.
struct SearchResultDataBlobs {
    std::vector<std::vector<char>> blobs{};  // one marshaled message per entry
};
// ReduceHelper reduces (merges) the per-segment `SearchResult`s of a search
// and marshals the outcome into a `SearchResultDataBlobs`.
//
// The request is described as `slice_num` slices, each with its own nq and
// top-k value (presumably slice i holds slice_nqs[i] queries, each reduced to
// slice_topKs[i] hits -- confirm against the .cpp).
//
// Ownership/lifetime: `search_results` is held by reference and `plan` by a
// raw pointer. Neither is owned, so both must outlive this helper. The two
// slice arrays are copied during construction.
class ReduceHelper {
 public:
    // Copies `slice_num` entries from `slice_nqs` and `slice_topKs`, then
    // calls Initialize() (defined out of line). Assumes `slice_num` is
    // non-negative and that both arrays hold at least `slice_num` elements.
    explicit ReduceHelper(std::vector<SearchResult*>& search_results,
                          milvus::query::Plan* plan,
                          int64_t* slice_nqs,
                          int64_t* slice_topKs,
                          int64_t slice_num)
        // Keep this list in member-declaration order. Members are always
        // initialized in declaration order no matter how the list is
        // written, and a mismatched list is flagged by -Wreorder.
        : slice_topKs_(slice_topKs, slice_topKs + slice_num),
          slice_nqs_(slice_nqs, slice_nqs + slice_num),
          plan_(plan),
          search_results_(search_results) {
        Initialize();
    }

    // Runs the reduce step (defined out of line).
    void
    Reduce();

    // Produces the marshaled output held in search_result_data_blobs_
    // (presumably one blob per slice; defined out of line).
    void
    Marshal();

    // Transfers ownership of the marshaled blobs to the caller as a
    // type-erased pointer. The caller must eventually delete it as a
    // `SearchResultDataBlobs*`. Because ownership is released, any later
    // call returns nullptr.
    void*
    GetSearchResultDataBlobs() {
        return search_result_data_blobs_.release();
    }

 private:
    void
    Initialize();

    void
    FilterInvalidSearchResult(SearchResult* search_result);

    void
    FillPrimaryKey();

    void
    RefreshSearchResult();

    void
    FillEntryData();

    // `result_offset` is taken by non-const reference, so it is an in/out
    // parameter (presumably advanced past the rows written for query `qi`).
    int64_t
    ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& result_offset);

    void
    ReduceResultData();

    std::vector<char>
    GetSearchResultDataSlice(int slice_index);

 private:
    std::vector<int64_t> slice_topKs_;
    std::vector<int64_t> slice_nqs_;
    // Default to 0 so these never hold indeterminate values. They are
    // presumably assigned in Initialize() (defined out of line).
    int64_t total_nq_ = 0;
    int64_t num_segments_ = 0;
    int64_t num_slices_ = 0;
    milvus::query::Plan* plan_ = nullptr;         // not owned
    std::vector<SearchResult*>& search_results_;  // not owned
    std::vector<int64_t> slice_nqs_prefix_sum_;
    // dim0: num_segments_; dim1: total_nq_; dim2: offset
    std::vector<std::vector<std::vector<int64_t>>> final_search_records_;

    // output
    std::unique_ptr<SearchResultDataBlobs> search_result_data_blobs_;

    // Scratch state for merging results. It is kept as members so it is not
    // reallocated for every query.
    std::vector<SearchResultPair> pairs_;
    // NOTE(review): presumably holds pointers into pairs_ during the merge.
    std::priority_queue<SearchResultPair*, std::vector<SearchResultPair*>, SearchResultPairComparator>
        heap_;
    // NOTE(review): presumably de-duplicates primary keys while merging.
    std::unordered_set<milvus::PkType> pk_set_;
};
} // namespace milvus::segcore