// Copyright (C) 2019-2024 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License

#include "query/CachedSearchIterator.h"
|
|
#include "query/SearchBruteForce.h"
|
|
#include <algorithm>
|
|
|
|
namespace milvus::query {

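// Constructs an iterator over an index-backed segment: knowhere range-search
// parameters are adjusted for the requested batch size, then per-query
// iterators are obtained from the vector index.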
CachedSearchIterator::CachedSearchIterator(
    const milvus::index::VectorIndex& index,
    const knowhere::DataSetPtr& query_ds,
    const SearchInfo& search_info,
    const BitsetView& bitset) {
    if (query_ds == nullptr) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Query dataset is nullptr, cannot initialize iterator");
    }
    nq_ = query_ds->GetRows();
    Init(search_info);

    auto search_json = index.PrepareSearchParams(search_info);
    index::CheckAndUpdateKnowhereRangeSearchParam(
        search_info, batch_size_, index.GetMetricType(), search_json);

    auto expected_iterators =
        index.VectorIterators(query_ds, search_json, bitset);
    if (expected_iterators.has_value()) {
        iterators_ = std::move(expected_iterators.value());
    } else {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Failed to create iterators from index");
    }
}

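// Constructs an iterator that brute-force searches a single contiguous raw
// dataset (no vector index involved).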
CachedSearchIterator::CachedSearchIterator(
    const dataset::SearchDataset& query_ds,
    const dataset::RawDataset& raw_ds,
    const SearchInfo& search_info,
    const std::map<std::string, std::string>& index_info,
    const BitsetView& bitset,
    const milvus::DataType& data_type) {
    nq_ = query_ds.num_queries;
    Init(search_info);

    auto expected_iterators = GetBruteForceSearchIterators(
        query_ds, raw_ds, search_info, index_info, bitset, data_type);
    if (expected_iterators.has_value()) {
        iterators_ = std::move(expected_iterators.value());
    } else {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Failed to create iterators from index");
    }
}

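// Builds brute-force iterators chunk by chunk using the supplied callback to
// fetch each chunk's data pointer and row count. Iterators are appended in
// chunk-major order, so the iterator for (query_idx, chunk_id) ends up at
// index query_idx + chunk_id * nq_ (see MergeChunksResults).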
void
CachedSearchIterator::InitializeChunkedIterators(
    const dataset::SearchDataset& query_ds,
    const SearchInfo& search_info,
    const std::map<std::string, std::string>& index_info,
    const BitsetView& bitset,
    const milvus::DataType& data_type,
    const GetChunkDataFunc& get_chunk_data) {
    int64_t offset = 0;
    chunked_heaps_.resize(nq_);
    for (int64_t chunk_id = 0; chunk_id < num_chunks_; ++chunk_id) {
        auto [chunk_data, chunk_size] = get_chunk_data(chunk_id);
        auto sub_data = query::dataset::RawDataset{
            offset, query_ds.dim, chunk_size, chunk_data};

        auto expected_iterators = GetBruteForceSearchIterators(
            query_ds, sub_data, search_info, index_info, bitset, data_type);
        if (expected_iterators.has_value()) {
            auto& chunk_iterators = expected_iterators.value();
            iterators_.insert(iterators_.end(),
                              std::make_move_iterator(chunk_iterators.begin()),
                              std::make_move_iterator(chunk_iterators.end()));
        } else {
            PanicInfo(ErrorCode::UnexpectedError,
                      "Failed to create iterators from index");
        }
        offset += chunk_size;
    }
}

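// Constructs an iterator over chunked in-memory vector data held by a
// segcore::VectorBase (e.g. a growing segment's insert buffer), creating one
// brute-force iterator per (query, chunk) pair.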
CachedSearchIterator::CachedSearchIterator(
    const dataset::SearchDataset& query_ds,
    const segcore::VectorBase* vec_data,
    const int64_t row_count,
    const SearchInfo& search_info,
    const std::map<std::string, std::string>& index_info,
    const BitsetView& bitset,
    const milvus::DataType& data_type) {
    if (vec_data == nullptr) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Vector data is nullptr, cannot initialize iterator");
    }

    if (row_count <= 0) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Number of rows is 0, cannot initialize iterator");
    }

    const int64_t vec_size_per_chunk = vec_data->get_size_per_chunk();
    num_chunks_ = upper_div(row_count, vec_size_per_chunk);
    nq_ = query_ds.num_queries;
    Init(search_info);

    iterators_.reserve(nq_ * num_chunks_);
    InitializeChunkedIterators(
        query_ds,
        search_info,
        index_info,
        bitset,
        data_type,
        [&vec_data, vec_size_per_chunk, row_count](
            int64_t chunk_id) -> std::pair<const void*, int64_t> {
            const auto chunk_data = vec_data->get_chunk_data(chunk_id);
            int64_t chunk_size = std::min(
                vec_size_per_chunk, row_count - chunk_id * vec_size_per_chunk);
            return {chunk_data, chunk_size};
        });
}

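// Constructs an iterator over a chunked column (e.g. a sealed segment's
// vector column loaded without an index), again one brute-force iterator per
// (query, chunk) pair.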
CachedSearchIterator::CachedSearchIterator(
    const std::shared_ptr<ChunkedColumnBase>& column,
    const dataset::SearchDataset& query_ds,
    const SearchInfo& search_info,
    const std::map<std::string, std::string>& index_info,
    const BitsetView& bitset,
    const milvus::DataType& data_type) {
    if (column == nullptr) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Column is nullptr, cannot initialize iterator");
    }

    num_chunks_ = column->num_chunks();
    nq_ = query_ds.num_queries;
    Init(search_info);

    iterators_.reserve(nq_ * num_chunks_);
    InitializeChunkedIterators(
        query_ds,
        search_info,
        index_info,
        bitset,
        data_type,
        [&column](int64_t chunk_id) {
            const char* chunk_data = column->Data(chunk_id);
            int64_t chunk_size = column->chunk_row_nums(chunk_id);
            return std::make_pair(static_cast<const void*>(chunk_data),
                                  chunk_size);
        });
}

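// Fills search_result with the next batch_size_ results for every query.
// The underlying iterators are stateful, so each call resumes where the
// previous batch stopped.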
void
CachedSearchIterator::NextBatch(const SearchInfo& search_info,
                                SearchResult& search_result) {
    if (iterators_.empty()) {
        return;
    }

    if (iterators_.size() != nq_ * num_chunks_) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Iterator size mismatch, expect %d, but got %d",
                  nq_ * num_chunks_,
                  iterators_.size());
    }

    ValidateSearchInfo(search_info);

    search_result.total_nq_ = nq_;
    search_result.unity_topK_ = batch_size_;
    search_result.seg_offsets_.resize(nq_ * batch_size_);
    search_result.distances_.resize(nq_ * batch_size_);

    for (size_t query_idx = 0; query_idx < nq_; ++query_idx) {
        auto rst = GetBatchedNextResults(query_idx, search_info);
        WriteSingleQuerySearchResult(
            search_result, query_idx, rst, search_info.round_decimal_);
    }
}

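// Checks that the per-call SearchInfo is consistent with the state captured
// at construction time (iterator v2 info present, identical batch size).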
void
CachedSearchIterator::ValidateSearchInfo(const SearchInfo& search_info) {
    if (!search_info.iterator_v2_info_.has_value()) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Iterator v2 SearchInfo is not set");
    }

    auto iterator_v2_info = search_info.iterator_v2_info_.value();
    if (iterator_v2_info.batch_size != batch_size_) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Batch size mismatch, expect %d, but got %d",
                  batch_size_,
                  iterator_v2_info.batch_size);
    }
}

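// Advances the iterator at iterator_idx until it produces a result that
// passes the last_bound / radius / range_filter checks; returns std::nullopt
// once the iterator is exhausted.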
std::optional<CachedSearchIterator::DisIdPair>
CachedSearchIterator::GetNextValidResult(
    const size_t iterator_idx,
    const std::optional<float>& last_bound,
    const std::optional<float>& radius,
    const std::optional<float>& range_filter) {
    auto& iterator = iterators_[iterator_idx];
    while (iterator->HasNext()) {
        auto result = ConvertIteratorResult(iterator->Next());
        if (IsValid(result, last_bound, radius, range_filter)) {
            return result;
        }
    }
    return std::nullopt;
}

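// Merges one query's results across all chunk iterators with a per-query
// heap that holds at most one pending candidate per chunk. The heap is kept
// in chunked_heaps_ so its state carries over between NextBatch calls.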
// TODO: Optimize this method
void
CachedSearchIterator::MergeChunksResults(
    size_t query_idx,
    const std::optional<float>& last_bound,
    const std::optional<float>& radius,
    const std::optional<float>& range_filter,
    std::vector<DisIdPair>& rst) {
    auto& heap = chunked_heaps_[query_idx];

    if (heap.empty()) {
        for (size_t chunk_id = 0; chunk_id < num_chunks_; ++chunk_id) {
            const size_t iterator_idx = query_idx + chunk_id * nq_;
            if (auto next_result = GetNextValidResult(
                    iterator_idx, last_bound, radius, range_filter);
                next_result.has_value()) {
                heap.emplace(iterator_idx, next_result.value());
            }
        }
    }

    while (!heap.empty() && rst.size() < batch_size_) {
        const auto [iterator_idx, cur_rst] = heap.top();
        heap.pop();

        // last_bound may change between NextBatch calls, discard any invalid results
        if (!IsValid(cur_rst, last_bound, radius, range_filter)) {
            continue;
        }
        rst.emplace_back(cur_rst);

        if (auto next_result = GetNextValidResult(
                iterator_idx, last_bound, radius, range_filter);
            next_result.has_value()) {
            heap.emplace(iterator_idx, next_result.value());
        }
    }
}

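// Collects up to batch_size_ results for a single query, either straight
// from its iterator (single chunk) or by merging across chunks, then sorts
// them, restores the caller-facing distance sign, and pads the tail with
// (+inf, -1) sentinel entries.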
std::vector<CachedSearchIterator::DisIdPair>
CachedSearchIterator::GetBatchedNextResults(size_t query_idx,
                                            const SearchInfo& search_info) {
    auto last_bound = ConvertIncomingDistance(
        search_info.iterator_v2_info_.value().last_bound);
    auto radius = ConvertIncomingDistance(
        index::GetValueFromConfig<float>(search_info.search_params_, RADIUS));
    auto range_filter =
        ConvertIncomingDistance(index::GetValueFromConfig<float>(
            search_info.search_params_, RANGE_FILTER));

    std::vector<DisIdPair> rst;
    rst.reserve(batch_size_);

    if (num_chunks_ == 1) {
        auto& iterator = iterators_[query_idx];
        while (iterator->HasNext() && rst.size() < batch_size_) {
            auto result = ConvertIteratorResult(iterator->Next());
            if (IsValid(result, last_bound, radius, range_filter)) {
                rst.emplace_back(result);
            }
        }
    } else {
        MergeChunksResults(query_idx, last_bound, radius, range_filter, rst);
    }
    std::sort(rst.begin(), rst.end());
    if (sign_ == -1) {
        std::for_each(rst.begin(), rst.end(), [this](DisIdPair& x) {
            x.first = x.first * sign_;
        });
    }
    while (rst.size() < batch_size_) {
        rst.emplace_back(1.0f / 0.0f, -1);
    }
    return rst;
}

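// Writes one query's (distance, offset) pairs into the flat distances_ and
// seg_offsets_ arrays of search_result, rounding distances to round_decimal
// places unless round_decimal is -1.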
void
CachedSearchIterator::WriteSingleQuerySearchResult(
    SearchResult& search_result,
    const size_t idx,
    std::vector<DisIdPair>& rst,
    const int64_t round_decimal) {
    const float multiplier = pow(10.0, round_decimal);

    std::transform(rst.begin(),
                   rst.end(),
                   search_result.distances_.begin() + idx * batch_size_,
                   [multiplier, round_decimal](DisIdPair& x) {
                       if (round_decimal != -1) {
                           x.first =
                               std::round(x.first * multiplier) / multiplier;
                       }
                       return x.first;
                   });

    std::transform(rst.begin(),
                   rst.end(),
                   search_result.seg_offsets_.begin() + idx * batch_size_,
                   [](const DisIdPair& x) { return x.second; });
}

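// Shared initialization: validates the iterator v2 SearchInfo and caches
// batch_size_ and the internal distance sign (negated for similarity metrics
// so that a smaller internal distance is always better).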
void
CachedSearchIterator::Init(const SearchInfo& search_info) {
    if (!search_info.iterator_v2_info_.has_value()) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Iterator v2 info is not set, cannot initialize iterator");
    }

    auto iterator_v2_info = search_info.iterator_v2_info_.value();
    if (iterator_v2_info.batch_size == 0) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Batch size is 0, cannot initialize iterator");
    }
    batch_size_ = iterator_v2_info.batch_size;

    if (search_info.metric_type_.empty()) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Metric type is empty, cannot initialize iterator");
    }
    if (PositivelyRelated(search_info.metric_type_)) {
        sign_ = -1;
    } else {
        sign_ = 1;
    }

    if (nq_ == 0) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Number of queries is 0, cannot initialize iterator");
    }

    // disable multi-query for now
    if (nq_ > 1) {
        PanicInfo(
            ErrorCode::UnexpectedError,
            "Number of queries is greater than 1, cannot initialize iterator");
    }
}

} // namespace milvus::query