enhance: speed up search iterator stage 1 (#37947)

issue: #37548

Signed-off-by: Patrick Weizhi Xu <weizhi.xu@zilliz.com>
This commit is contained in:
Patrick Weizhi Xu 2024-12-26 10:32:49 +08:00 committed by GitHub
parent f49d618382
commit 85f462be1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 1970 additions and 117 deletions

2
go.mod
View File

@ -64,6 +64,7 @@ require (
github.com/cenkalti/backoff/v4 v4.2.1 github.com/cenkalti/backoff/v4 v4.2.1
github.com/cockroachdb/redact v1.1.3 github.com/cockroachdb/redact v1.1.3
github.com/goccy/go-json v0.10.3 github.com/goccy/go-json v0.10.3
github.com/google/uuid v1.6.0
github.com/greatroar/blobloom v0.0.0-00010101000000-000000000000 github.com/greatroar/blobloom v0.0.0-00010101000000-000000000000
github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/jolestar/go-commons-pool/v2 v2.1.2 github.com/jolestar/go-commons-pool/v2 v2.1.2
@ -144,7 +145,6 @@ require (
github.com/golang/snappy v0.0.4 // indirect github.com/golang/snappy v0.0.4 // indirect
github.com/google/flatbuffers v2.0.8+incompatible // indirect github.com/google/flatbuffers v2.0.8+incompatible // indirect
github.com/google/s2a-go v0.1.7 // indirect github.com/google/s2a-go v0.1.7 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.5 // indirect github.com/googleapis/gax-go/v2 v2.12.5 // indirect
github.com/gorilla/websocket v1.4.2 // indirect github.com/gorilla/websocket v1.4.2 // indirect

View File

@ -24,6 +24,12 @@
namespace milvus { namespace milvus {
struct SearchIteratorV2Info {
std::string token = "";
uint32_t batch_size = 0;
std::optional<float> last_bound = std::nullopt;
};
struct SearchInfo { struct SearchInfo {
int64_t topk_{0}; int64_t topk_{0};
int64_t group_size_{1}; int64_t group_size_{1};
@ -36,6 +42,7 @@ struct SearchInfo {
tracer::TraceContext trace_ctx_; tracer::TraceContext trace_ctx_;
bool materialized_view_involved = false; bool materialized_view_involved = false;
bool iterative_filter_execution = false; bool iterative_filter_execution = false;
std::optional<SearchIteratorV2Info> iterator_v2_info_ = std::nullopt;
}; };
using SearchInfoPtr = std::shared_ptr<SearchInfo>; using SearchInfoPtr = std::shared_ptr<SearchInfo>;

View File

@ -362,4 +362,38 @@ ReadDataFromFD(int fd, void* buf, size_t size, size_t chunk_size) {
} }
} }
// Copies Knowhere range-search parameters from `search_info.search_params_`
// into `search_config`. Returns true iff RADIUS is configured, i.e. the
// caller should run a range search instead of a plain top-k search; when it
// returns false, `search_config` is left untouched.
bool
CheckAndUpdateKnowhereRangeSearchParam(const SearchInfo& search_info,
                                       const int64_t topk,
                                       const MetricType& metric_type,
                                       knowhere::Json& search_config) {
    const auto radius =
        index::GetValueFromConfig<float>(search_info.search_params_, RADIUS);
    if (!radius.has_value()) {
        // No radius configured: not a range search.
        return false;
    }
    search_config[RADIUS] = radius.value();
    // `range_search_k` is only used as one of the conditions for iterator
    // early termination. It is not guaranteed to return exactly
    // `range_search_k` results (may be more or fewer). Setting it to -1
    // returns all results within the range.
    search_config[knowhere::meta::RANGE_SEARCH_K] = topk;
    const auto range_filter =
        GetValueFromConfig<float>(search_info.search_params_, RANGE_FILTER);
    if (range_filter.has_value()) {
        search_config[RANGE_FILTER] = range_filter.value();
        // Validates the radius/range_filter ordering for the given metric.
        CheckRangeSearchParam(
            search_config[RADIUS], search_config[RANGE_FILTER], metric_type);
    }
    const auto page_retain_order =
        GetValueFromConfig<bool>(search_info.search_params_, PAGE_RETAIN_ORDER);
    if (page_retain_order.has_value()) {
        search_config[knowhere::meta::RETAIN_ITERATOR_ORDER] =
            page_retain_order.value();
    }
    return true;
}
} // namespace milvus::index } // namespace milvus::index

View File

@ -30,6 +30,8 @@
#include "common/Types.h" #include "common/Types.h"
#include "common/FieldData.h" #include "common/FieldData.h"
#include "common/QueryInfo.h"
#include "common/RangeSearchHelper.h"
#include "index/IndexInfo.h" #include "index/IndexInfo.h"
#include "storage/Types.h" #include "storage/Types.h"
@ -147,4 +149,10 @@ AssembleIndexDatas(std::map<std::string, FieldDataChannelPtr>& index_datas,
void void
ReadDataFromFD(int fd, void* buf, size_t size, size_t chunk_size = 0x7ffff000); ReadDataFromFD(int fd, void* buf, size_t size, size_t chunk_size = 0x7ffff000);
bool
CheckAndUpdateKnowhereRangeSearchParam(const SearchInfo& search_info,
const int64_t topk,
const MetricType& metric_type,
knowhere::Json& search_config);
} // namespace milvus::index } // namespace milvus::index

View File

@ -266,32 +266,9 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
search_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix; search_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix;
auto final = [&] { auto final = [&] {
auto radius = if (CheckAndUpdateKnowhereRangeSearchParam(
GetValueFromConfig<float>(search_info.search_params_, RADIUS); search_info, topk, GetMetricType(), search_config)) {
if (radius.has_value()) {
search_config[RADIUS] = radius.value();
// `range_search_k` is only used as one of the conditions for iterator early termination.
// not gurantee to return exactly `range_search_k` results, which may be more or less.
// set it to -1 will return all results in the range.
search_config[knowhere::meta::RANGE_SEARCH_K] = topk;
auto range_filter = GetValueFromConfig<float>(
search_info.search_params_, RANGE_FILTER);
if (range_filter.has_value()) {
search_config[RANGE_FILTER] = range_filter.value();
CheckRangeSearchParam(search_config[RADIUS],
search_config[RANGE_FILTER],
GetMetricType());
}
auto page_retain_order = GetValueFromConfig<bool>(
search_info.search_params_, PAGE_RETAIN_ORDER);
if (page_retain_order.has_value()) {
search_config[knowhere::meta::RETAIN_ITERATOR_ORDER] =
page_retain_order.value();
}
auto res = index_.RangeSearch(dataset, search_config, bitset); auto res = index_.RangeSearch(dataset, search_config, bitset);
if (!res.has_value()) { if (!res.has_value()) {
PanicInfo(ErrorCode::UnexpectedError, PanicInfo(ErrorCode::UnexpectedError,
fmt::format("failed to range search: {}: {}", fmt::format("failed to range search: {}: {}",

View File

@ -380,16 +380,8 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
// TODO :: check dim of search data // TODO :: check dim of search data
auto final = [&] { auto final = [&] {
auto index_type = GetIndexType(); auto index_type = GetIndexType();
if (CheckKeyInConfig(search_conf, RADIUS)) { if (CheckAndUpdateKnowhereRangeSearchParam(
if (CheckKeyInConfig(search_conf, RANGE_FILTER)) { search_info, topk, GetMetricType(), search_conf)) {
CheckRangeSearchParam(search_conf[RADIUS],
search_conf[RANGE_FILTER],
GetMetricType());
}
// `range_search_k` is only used as one of the conditions for iterator early termination.
// not gurantee to return exactly `range_search_k` results, which may be more or less.
// set it to -1 will return all results in the range.
search_conf[knowhere::meta::RANGE_SEARCH_K] = topk;
milvus::tracer::AddEvent("start_knowhere_index_range_search"); milvus::tracer::AddEvent("start_knowhere_index_range_search");
auto res = index_.RangeSearch(dataset, search_conf, bitset); auto res = index_.RangeSearch(dataset, search_conf, bitset);
milvus::tracer::AddEvent("finish_knowhere_index_range_search"); milvus::tracer::AddEvent("finish_knowhere_index_range_search");

View File

@ -0,0 +1,362 @@
// Copyright (C) 2019-2024 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "query/CachedSearchIterator.h"

#include <algorithm>
#include <cmath>
#include <limits>

#include "query/SearchBruteForce.h"
namespace milvus::query {
// Constructs an iterator over a sealed segment backed by a vector index.
// Creates one Knowhere iterator per query row; Init() currently restricts
// nq_ to 1. Panics if the query dataset is null or iterator creation fails.
CachedSearchIterator::CachedSearchIterator(
    const milvus::index::VectorIndex& index,
    const knowhere::DataSetPtr& query_ds,
    const SearchInfo& search_info,
    const BitsetView& bitset) {
    if (query_ds == nullptr) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Query dataset is nullptr, cannot initialize iterator");
    }
    nq_ = query_ds->GetRows();
    Init(search_info);

    auto search_json = index.PrepareSearchParams(search_info);
    // Forward range-search params (radius / range_filter / retain-order) to
    // Knowhere, using batch_size_ as the early-termination hint.
    index::CheckAndUpdateKnowhereRangeSearchParam(
        search_info, batch_size_, index.GetMetricType(), search_json);

    auto expected_iterators =
        index.VectorIterators(query_ds, search_json, bitset);
    if (expected_iterators.has_value()) {
        iterators_ = std::move(expected_iterators.value());
    } else {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Failed to create iterators from index");
    }
}
// Constructs an iterator over a single contiguous raw dataset (sealed
// segment without an index): brute-force iterators, one per query.
CachedSearchIterator::CachedSearchIterator(
    const dataset::SearchDataset& query_ds,
    const dataset::RawDataset& raw_ds,
    const SearchInfo& search_info,
    const std::map<std::string, std::string>& index_info,
    const BitsetView& bitset,
    const milvus::DataType& data_type) {
    nq_ = query_ds.num_queries;
    Init(search_info);

    auto expected_iterators = GetBruteForceSearchIterators(
        query_ds, raw_ds, search_info, index_info, bitset, data_type);
    if (expected_iterators.has_value()) {
        iterators_ = std::move(expected_iterators.value());
    } else {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Failed to create iterators from index");
    }
}
// Builds brute-force iterators for every (chunk, query) pair.
// `get_chunk_data` maps a chunk id to {chunk data pointer, row count}.
// Iterators are appended chunk-major, so the iterator for (query q,
// chunk c) ends up at index c * nq_ + q — MergeChunksResults relies on
// this layout. Panics if iterator creation fails for any chunk.
void
CachedSearchIterator::InitializeChunkedIterators(
    const dataset::SearchDataset& query_ds,
    const SearchInfo& search_info,
    const std::map<std::string, std::string>& index_info,
    const BitsetView& bitset,
    const milvus::DataType& data_type,
    const GetChunkDataFunc& get_chunk_data) {
    // offset is the global row id of the current chunk's first row, so the
    // per-chunk iterators report segment-wide offsets.
    int64_t offset = 0;
    chunked_heaps_.resize(nq_);
    for (int64_t chunk_id = 0; chunk_id < num_chunks_; ++chunk_id) {
        auto [chunk_data, chunk_size] = get_chunk_data(chunk_id);
        auto sub_data = query::dataset::RawDataset{
            offset, query_ds.dim, chunk_size, chunk_data};

        auto expected_iterators = GetBruteForceSearchIterators(
            query_ds, sub_data, search_info, index_info, bitset, data_type);
        if (expected_iterators.has_value()) {
            auto& chunk_iterators = expected_iterators.value();
            iterators_.insert(iterators_.end(),
                              std::make_move_iterator(chunk_iterators.begin()),
                              std::make_move_iterator(chunk_iterators.end()));
        } else {
            PanicInfo(ErrorCode::UnexpectedError,
                      "Failed to create iterators from index");
        }
        offset += chunk_size;
    }
}
// Constructs an iterator over a growing segment: brute force over the
// insert record's fixed-size chunks. Panics on null data or row_count <= 0.
CachedSearchIterator::CachedSearchIterator(
    const dataset::SearchDataset& query_ds,
    const segcore::VectorBase* vec_data,
    const int64_t row_count,
    const SearchInfo& search_info,
    const std::map<std::string, std::string>& index_info,
    const BitsetView& bitset,
    const milvus::DataType& data_type) {
    if (vec_data == nullptr) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Vector data is nullptr, cannot initialize iterator");
    }
    if (row_count <= 0) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Number of rows is 0, cannot initialize iterator");
    }

    const int64_t vec_size_per_chunk = vec_data->get_size_per_chunk();
    num_chunks_ = upper_div(row_count, vec_size_per_chunk);
    nq_ = query_ds.num_queries;
    Init(search_info);

    iterators_.reserve(nq_ * num_chunks_);
    InitializeChunkedIterators(
        query_ds,
        search_info,
        index_info,
        bitset,
        data_type,
        [&vec_data, vec_size_per_chunk, row_count](
            int64_t chunk_id) -> std::pair<const void*, int64_t> {
            const auto chunk_data = vec_data->get_chunk_data(chunk_id);
            // the last chunk may be only partially filled
            int64_t chunk_size = std::min(
                vec_size_per_chunk, row_count - chunk_id * vec_size_per_chunk);
            return {chunk_data, chunk_size};
        });
}
// Constructs an iterator over a sealed segment stored as a chunked column
// (brute force, no index). Panics if the column is null.
CachedSearchIterator::CachedSearchIterator(
    const std::shared_ptr<ChunkedColumnBase>& column,
    const dataset::SearchDataset& query_ds,
    const SearchInfo& search_info,
    const std::map<std::string, std::string>& index_info,
    const BitsetView& bitset,
    const milvus::DataType& data_type) {
    if (column == nullptr) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Column is nullptr, cannot initialize iterator");
    }
    num_chunks_ = column->num_chunks();
    nq_ = query_ds.num_queries;
    Init(search_info);

    iterators_.reserve(nq_ * num_chunks_);
    InitializeChunkedIterators(
        query_ds,
        search_info,
        index_info,
        bitset,
        data_type,
        [&column](int64_t chunk_id) {
            const char* chunk_data = column->Data(chunk_id);
            int64_t chunk_size = column->chunk_row_nums(chunk_id);
            return std::make_pair(static_cast<const void*>(chunk_data),
                                  chunk_size);
        });
}
// Fetches the next `batch_size_` results for every query and writes them
// into `search_result`. Silently returns if no iterators were created.
// Panics if the iterator layout is inconsistent or `search_info` does not
// carry matching iterator-v2 info.
void
CachedSearchIterator::NextBatch(const SearchInfo& search_info,
                                SearchResult& search_result) {
    if (iterators_.empty()) {
        return;
    }

    if (iterators_.size() != nq_ * num_chunks_) {
        // PanicInfo formats with fmt-style "{}" placeholders; "%d" would be
        // emitted literally and hide the actual sizes.
        PanicInfo(ErrorCode::UnexpectedError,
                  "Iterator size mismatch, expect {}, but got {}",
                  nq_ * num_chunks_,
                  iterators_.size());
    }

    ValidateSearchInfo(search_info);

    // topK is reported as batch_size for compatibility with SearchResult
    // consumers (see class comment in the header).
    search_result.total_nq_ = nq_;
    search_result.unity_topK_ = batch_size_;
    search_result.seg_offsets_.resize(nq_ * batch_size_);
    search_result.distances_.resize(nq_ * batch_size_);

    for (size_t query_idx = 0; query_idx < nq_; ++query_idx) {
        auto rst = GetBatchedNextResults(query_idx, search_info);
        WriteSingleQuerySearchResult(
            search_result, query_idx, rst, search_info.round_decimal_);
    }
}
// Ensures the per-call search info still matches this iterator: iterator-v2
// info must be present and the batch size must equal the one captured at
// construction. Panics otherwise.
void
CachedSearchIterator::ValidateSearchInfo(const SearchInfo& search_info) {
    if (!search_info.iterator_v2_info_.has_value()) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Iterator v2 SearchInfo is not set");
    }

    auto iterator_v2_info = search_info.iterator_v2_info_.value();
    if (iterator_v2_info.batch_size != batch_size_) {
        // PanicInfo formats with fmt-style "{}" placeholders; "%d" would be
        // emitted literally and hide the actual sizes.
        PanicInfo(ErrorCode::UnexpectedError,
                  "Batch size mismatch, expect {}, but got {}",
                  batch_size_,
                  iterator_v2_info.batch_size);
    }
}
std::optional<CachedSearchIterator::DisIdPair>
CachedSearchIterator::GetNextValidResult(
const size_t iterator_idx,
const std::optional<float>& last_bound,
const std::optional<float>& radius,
const std::optional<float>& range_filter) {
auto& iterator = iterators_[iterator_idx];
while (iterator->HasNext()) {
auto result = ConvertIteratorResult(iterator->Next());
if (IsValid(result, last_bound, radius, range_filter)) {
return result;
}
}
return std::nullopt;
}
// TODO: Optimize this method
// K-way merges per-chunk iterator results for one query into `rst` until
// `batch_size_` results are collected or every chunk is exhausted. The heap
// persists in chunked_heaps_ across NextBatch calls so candidates fetched
// but not consumed in a previous batch are not lost.
void
CachedSearchIterator::MergeChunksResults(
    size_t query_idx,
    const std::optional<float>& last_bound,
    const std::optional<float>& radius,
    const std::optional<float>& range_filter,
    std::vector<DisIdPair>& rst) {
    auto& heap = chunked_heaps_[query_idx];
    // Seed the heap with one valid candidate per chunk on first use.
    if (heap.empty()) {
        for (size_t chunk_id = 0; chunk_id < num_chunks_; ++chunk_id) {
            const size_t iterator_idx = query_idx + chunk_id * nq_;
            if (auto next_result = GetNextValidResult(
                    iterator_idx, last_bound, radius, range_filter);
                next_result.has_value()) {
                heap.emplace(iterator_idx, next_result.value());
            }
        }
    }

    while (!heap.empty() && rst.size() < batch_size_) {
        const auto [iterator_idx, cur_rst] = heap.top();
        heap.pop();

        // last_bound may change between NextBatch calls, so a cached heap
        // entry can become stale. Discard it, but refill from the SAME
        // iterator: otherwise this chunk would be permanently dropped from
        // the merge even though it may still hold valid results.
        if (!IsValid(cur_rst, last_bound, radius, range_filter)) {
            if (auto next_result = GetNextValidResult(
                    iterator_idx, last_bound, radius, range_filter);
                next_result.has_value()) {
                heap.emplace(iterator_idx, next_result.value());
            }
            continue;
        }

        rst.emplace_back(cur_rst);
        // Keep the merge going: pull the next candidate from the chunk we
        // just consumed.
        if (auto next_result = GetNextValidResult(
                iterator_idx, last_bound, radius, range_filter);
            next_result.has_value()) {
            heap.emplace(iterator_idx, next_result.value());
        }
    }
}
// Collects up to `batch_size_` results for one query, sorted ascending by
// normalized distance, then converts distances back to the metric's native
// sign and pads the batch with sentinel entries to exactly batch_size_.
std::vector<CachedSearchIterator::DisIdPair>
CachedSearchIterator::GetBatchedNextResults(size_t query_idx,
                                            const SearchInfo& search_info) {
    // Distances are normalized by sign_ so "smaller is better" regardless
    // of the metric; incoming bounds must be normalized the same way.
    auto last_bound = ConvertIncomingDistance(
        search_info.iterator_v2_info_.value().last_bound);
    auto radius = ConvertIncomingDistance(
        index::GetValueFromConfig<float>(search_info.search_params_, RADIUS));
    auto range_filter =
        ConvertIncomingDistance(index::GetValueFromConfig<float>(
            search_info.search_params_, RANGE_FILTER));

    std::vector<DisIdPair> rst;
    rst.reserve(batch_size_);

    if (num_chunks_ == 1) {
        // Single chunk: drain this query's iterator directly, no merge.
        auto& iterator = iterators_[query_idx];
        while (iterator->HasNext() && rst.size() < batch_size_) {
            auto result = ConvertIteratorResult(iterator->Next());
            if (IsValid(result, last_bound, radius, range_filter)) {
                rst.emplace_back(result);
            }
        }
    } else {
        MergeChunksResults(query_idx, last_bound, radius, range_filter, rst);
    }
    std::sort(rst.begin(), rst.end());
    if (sign_ == -1) {
        // Undo the normalization so callers see native metric values.
        std::for_each(rst.begin(), rst.end(), [this](DisIdPair& x) {
            x.first = x.first * sign_;
        });
    }
    // Pad short batches with sentinels (+inf distance, offset -1) so the
    // output is always exactly batch_size_ long. Use the explicit infinity
    // constant instead of the 1.0f / 0.0f division idiom.
    while (rst.size() < batch_size_) {
        rst.emplace_back(std::numeric_limits<float>::infinity(), -1);
    }
    return rst;
}
// Writes one query's batch into the flat distance/offset arrays of
// `search_result` at slot [idx * batch_size_, (idx + 1) * batch_size_),
// rounding distances to `round_decimal` places when round_decimal != -1.
// `rst` distances are mutated in place by the rounding step.
void
CachedSearchIterator::WriteSingleQuerySearchResult(
    SearchResult& search_result,
    const size_t idx,
    std::vector<DisIdPair>& rst,
    const int64_t round_decimal) {
    // Only compute the scaling factor when rounding is actually requested;
    // the original unconditionally evaluated pow() even for -1.
    if (round_decimal != -1) {
        const float multiplier = std::pow(10.0, round_decimal);
        for (auto& x : rst) {
            x.first = std::round(x.first * multiplier) / multiplier;
        }
    }
    std::transform(rst.begin(),
                   rst.end(),
                   search_result.distances_.begin() + idx * batch_size_,
                   [](const DisIdPair& x) { return x.first; });
    std::transform(rst.begin(),
                   rst.end(),
                   search_result.seg_offsets_.begin() + idx * batch_size_,
                   [](const DisIdPair& x) { return x.second; });
}
// One-time setup shared by every constructor: validates the iterator-v2
// info, batch size, metric type and query count, and derives the distance
// normalization sign. nq_ must already be set by the caller.
void
CachedSearchIterator::Init(const SearchInfo& search_info) {
    if (!search_info.iterator_v2_info_.has_value()) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Iterator v2 info is not set, cannot initialize iterator");
    }
    const auto& v2_info = search_info.iterator_v2_info_.value();
    if (v2_info.batch_size == 0) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Batch size is 0, cannot initialize iterator");
    }
    batch_size_ = v2_info.batch_size;

    if (search_info.metric_type_.empty()) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Metric type is empty, cannot initialize iterator");
    }
    // Similarity metrics (higher is better) are flipped so the rest of the
    // pipeline can always treat smaller distances as better.
    sign_ = PositivelyRelated(search_info.metric_type_) ? -1 : 1;

    if (nq_ == 0) {
        PanicInfo(ErrorCode::UnexpectedError,
                  "Number of queries is 0, cannot initialize iterator");
    }
    // disable multi-query for now
    if (nq_ > 1) {
        PanicInfo(
            ErrorCode::UnexpectedError,
            "Number of queries is greater than 1, cannot initialize iterator");
    }
}
} // namespace milvus::query

View File

@ -0,0 +1,182 @@
// Copyright (C) 2019-2024 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include <utility>
#include "common/BitsetView.h"
#include "common/QueryInfo.h"
#include "common/QueryResult.h"
#include "query/helper.h"
#include "segcore/ConcurrentVector.h"
#include "index/VectorIndex.h"
namespace milvus::query {
// This class is used to cache the search results from Knowhere
// search iterators and filter the results based on the last_bound,
// radius and range_filter.
// It provides a number of constructors to support different scenarios,
// including growing/sealed, chunked/non-chunked.
//
// It does not care about TopK in search_info
// The topk in SearchResult will be set to the batch_size for compatibility
//
// TODO: introduce the pool of results in the near future
// TODO: replace VectorIterator class
class CachedSearchIterator {
 public:
    // For sealed segment with vector index
    CachedSearchIterator(const milvus::index::VectorIndex& index,
                         const knowhere::DataSetPtr& dataset,
                         const SearchInfo& search_info,
                         const BitsetView& bitset);

    // For sealed segment, BF
    CachedSearchIterator(const dataset::SearchDataset& dataset,
                         const dataset::RawDataset& raw_ds,
                         const SearchInfo& search_info,
                         const std::map<std::string, std::string>& index_info,
                         const BitsetView& bitset,
                         const milvus::DataType& data_type);

    // For growing segment with chunked data, BF
    CachedSearchIterator(const dataset::SearchDataset& dataset,
                         const segcore::VectorBase* vec_data,
                         const int64_t row_count,
                         const SearchInfo& search_info,
                         const std::map<std::string, std::string>& index_info,
                         const BitsetView& bitset,
                         const milvus::DataType& data_type);

    // For sealed segment with chunked data, BF
    CachedSearchIterator(const std::shared_ptr<ChunkedColumnBase>& column,
                         const dataset::SearchDataset& dataset,
                         const SearchInfo& search_info,
                         const std::map<std::string, std::string>& index_info,
                         const BitsetView& bitset,
                         const milvus::DataType& data_type);

    // This method fetches the next batch of search results based on the provided search information
    // and updates the search_result object with the new batch of results.
    void
    NextBatch(const SearchInfo& search_info, SearchResult& search_result);

    // Disable copy and move
    CachedSearchIterator(const CachedSearchIterator&) = delete;
    CachedSearchIterator&
    operator=(const CachedSearchIterator&) = delete;
    CachedSearchIterator(CachedSearchIterator&&) = delete;
    CachedSearchIterator&
    operator=(CachedSearchIterator&&) = delete;

 private:
    // (normalized distance, segment offset); distances are stored
    // pre-multiplied by sign_ so that smaller always means better.
    using DisIdPair = std::pair<float, int64_t>;
    using IterIdx = size_t;
    using IterIdDisIdPair = std::pair<IterIdx, DisIdPair>;
    using GetChunkDataFunc =
        std::function<std::pair<const void*, int64_t>(int64_t)>;

    int64_t batch_size_ = 0;
    // chunk-major layout: iterator for (query q, chunk c) is at c * nq_ + q
    std::vector<knowhere::IndexNode::IteratorPtr> iterators_;
    // -1 for similarity metrics (higher is better), 1 for distance metrics
    int8_t sign_ = 1;
    size_t num_chunks_ = 1;
    size_t nq_ = 0;

    // Min-heap ordering for the k-way chunk merge: smallest normalized
    // distance first, ties broken by smaller segment offset.
    // operator() is const so the comparator is usable in const contexts
    // (and satisfies clang-tidy's function-const checks).
    struct IterIdDisIdPairComparator {
        bool
        operator()(const IterIdDisIdPair& lhs,
                   const IterIdDisIdPair& rhs) const {
            if (lhs.second.first == rhs.second.first) {
                return lhs.second.second > rhs.second.second;
            }
            return lhs.second.first > rhs.second.first;
        }
    };
    // one merge heap per query; persists across NextBatch calls
    std::vector<std::priority_queue<IterIdDisIdPair,
                                    std::vector<IterIdDisIdPair>,
                                    IterIdDisIdPairComparator>>
        chunked_heaps_;

    // True if `result` passes the strict last_bound check and the optional
    // [range_filter, radius) window. All inputs are sign_-normalized.
    inline bool
    IsValid(const DisIdPair& result,
            const std::optional<float>& last_bound,
            const std::optional<float>& radius,
            const std::optional<float>& range_filter) const {
        const float dist = result.first;
        const bool is_valid =
            !last_bound.has_value() || dist > last_bound.value();
        if (!radius.has_value()) {
            return is_valid;
        }
        if (!range_filter.has_value()) {
            return is_valid && dist < radius.value();
        }
        return is_valid && dist < radius.value() &&
               dist >= range_filter.value();
    }

    // Knowhere iterators yield (offset, distance); flip to (distance,
    // offset) and normalize the distance by sign_.
    inline DisIdPair
    ConvertIteratorResult(const std::pair<int64_t, float>& iter_rst) const {
        DisIdPair rst;
        rst.first = iter_rst.second * sign_;
        rst.second = iter_rst.first;
        return rst;
    }

    // Normalizes an externally supplied distance bound by sign_.
    inline std::optional<float>
    ConvertIncomingDistance(std::optional<float> dist) const {
        if (dist.has_value()) {
            dist = dist.value() * sign_;
        }
        return dist;
    }

    std::optional<DisIdPair>
    GetNextValidResult(size_t iterator_idx,
                       const std::optional<float>& last_bound,
                       const std::optional<float>& radius,
                       const std::optional<float>& range_filter);

    void
    MergeChunksResults(size_t query_idx,
                       const std::optional<float>& last_bound,
                       const std::optional<float>& radius,
                       const std::optional<float>& range_filter,
                       std::vector<DisIdPair>& rst);

    void
    ValidateSearchInfo(const SearchInfo& search_info);

    std::vector<DisIdPair>
    GetBatchedNextResults(size_t query_idx, const SearchInfo& search_info);

    void
    WriteSingleQuerySearchResult(SearchResult& search_result,
                                 const size_t idx,
                                 std::vector<DisIdPair>& rst,
                                 const int64_t round_decimal);

    void
    Init(const SearchInfo& search_info);

    void
    InitializeChunkedIterators(
        const dataset::SearchDataset& dataset,
        const SearchInfo& search_info,
        const std::map<std::string, std::string>& index_info,
        const BitsetView& bitset,
        const milvus::DataType& data_type,
        const GetChunkDataFunc& get_chunk_data);
};
} // namespace milvus::query

View File

@ -93,6 +93,20 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
search_info.strict_group_size_ = search_info.strict_group_size_ =
query_info_proto.strict_group_size(); query_info_proto.strict_group_size();
} }
if (query_info_proto.has_search_iterator_v2_info()) {
auto& iterator_v2_info_proto =
query_info_proto.search_iterator_v2_info();
search_info.iterator_v2_info_ = SearchIteratorV2Info{
.token = iterator_v2_info_proto.token(),
.batch_size = iterator_v2_info_proto.batch_size(),
};
if (iterator_v2_info_proto.has_last_bound()) {
search_info.iterator_v2_info_->last_bound =
iterator_v2_info_proto.last_bound();
}
}
return search_info; return search_info;
}; };

View File

@ -226,45 +226,66 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
return sub_result; return sub_result;
} }
SubSearchResult knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
BruteForceSearchIterators(const dataset::SearchDataset& query_ds, DispatchBruteForceIteratorByDataType(const knowhere::DataSetPtr& base_dataset,
const dataset::RawDataset& raw_ds, const knowhere::DataSetPtr& query_dataset,
const SearchInfo& search_info, const knowhere::Json& config,
const std::map<std::string, std::string>& index_info, const BitsetView& bitset,
const BitsetView& bitset, const milvus::DataType& data_type) {
DataType data_type) {
auto nq = query_ds.num_queries;
auto [query_dataset, base_dataset] =
PrepareBFDataSet(query_ds, raw_ds, data_type);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
iterators_val;
switch (data_type) { switch (data_type) {
case DataType::VECTOR_FLOAT: case DataType::VECTOR_FLOAT:
iterators_val = knowhere::BruteForce::AnnIterator<float>( return knowhere::BruteForce::AnnIterator<float>(
base_dataset, query_dataset, search_cfg, bitset); base_dataset, query_dataset, config, bitset);
break; break;
case DataType::VECTOR_FLOAT16: case DataType::VECTOR_FLOAT16:
//todo: if knowhere support real fp16/bf16 bf, change it //todo: if knowhere support real fp16/bf16 bf, change it
iterators_val = knowhere::BruteForce::AnnIterator<float>( return knowhere::BruteForce::AnnIterator<float>(
base_dataset, query_dataset, search_cfg, bitset); base_dataset, query_dataset, config, bitset);
break; break;
case DataType::VECTOR_BFLOAT16: case DataType::VECTOR_BFLOAT16:
//todo: if knowhere support real fp16/bf16 bf, change it //todo: if knowhere support real fp16/bf16 bf, change it
iterators_val = knowhere::BruteForce::AnnIterator<float>( return knowhere::BruteForce::AnnIterator<float>(
base_dataset, query_dataset, search_cfg, bitset); base_dataset, query_dataset, config, bitset);
break; break;
case DataType::VECTOR_SPARSE_FLOAT: case DataType::VECTOR_SPARSE_FLOAT:
iterators_val = knowhere::BruteForce::AnnIterator< return knowhere::BruteForce::AnnIterator<
knowhere::sparse::SparseRow<float>>( knowhere::sparse::SparseRow<float>>(
base_dataset, query_dataset, search_cfg, bitset); base_dataset, query_dataset, config, bitset);
break; break;
default: default:
PanicInfo(ErrorCode::Unsupported, PanicInfo(ErrorCode::Unsupported,
"Unsupported dataType for chunk brute force iterator:{}", "Unsupported dataType for chunk brute force iterator:{}",
data_type); data_type);
} }
}
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
GetBruteForceSearchIterators(
const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
auto nq = query_ds.num_queries;
auto [query_dataset, base_dataset] =
PrepareBFDataSet(query_ds, raw_ds, data_type);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
return DispatchBruteForceIteratorByDataType(
base_dataset, query_dataset, search_cfg, bitset, data_type);
}
SubSearchResult
PackBruteForceSearchIteratorsIntoSubResult(
const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
auto nq = query_ds.num_queries;
auto iterators_val = GetBruteForceSearchIterators(
query_ds, raw_ds, search_info, index_info, bitset, data_type);
if (iterators_val.has_value()) { if (iterators_val.has_value()) {
AssertInfo( AssertInfo(
iterators_val.value().size() == nq, iterators_val.value().size() == nq,

View File

@ -31,12 +31,22 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
const BitsetView& bitset, const BitsetView& bitset,
DataType data_type); DataType data_type);
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
GetBruteForceSearchIterators(
const dataset::SearchDataset& query_ds,
const dataset::RawDataset& raw_ds,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type);
SubSearchResult SubSearchResult
BruteForceSearchIterators(const dataset::SearchDataset& query_ds, PackBruteForceSearchIteratorsIntoSubResult(
const dataset::RawDataset& raw_ds, const dataset::SearchDataset& query_ds,
const SearchInfo& search_info, const dataset::RawDataset& raw_ds,
const std::map<std::string, std::string>& index_info, const SearchInfo& search_info,
const BitsetView& bitset, const std::map<std::string, std::string>& index_info,
DataType data_type); const BitsetView& bitset,
DataType data_type);
} // namespace milvus::query } // namespace milvus::query

View File

@ -18,6 +18,7 @@
#include "knowhere/comp/index_param.h" #include "knowhere/comp/index_param.h"
#include "knowhere/config.h" #include "knowhere/config.h"
#include "log/Log.h" #include "log/Log.h"
#include "query/CachedSearchIterator.h"
#include "query/SearchBruteForce.h" #include "query/SearchBruteForce.h"
#include "query/SearchOnIndex.h" #include "query/SearchOnIndex.h"
#include "exec/operator/Utils.h" #include "exec/operator/Utils.h"
@ -125,6 +126,19 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
// step 3: brute force search where small indexing is unavailable // step 3: brute force search where small indexing is unavailable
auto vec_ptr = record.get_data_base(vecfield_id); auto vec_ptr = record.get_data_base(vecfield_id);
if (info.iterator_v2_info_.has_value()) {
CachedSearchIterator cached_iter(search_dataset,
vec_ptr,
active_count,
info,
index_info,
bitset,
data_type);
cached_iter.NextBatch(info, search_result);
return;
}
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk(); auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
auto max_chunk = upper_div(active_count, vec_size_per_chunk); auto max_chunk = upper_div(active_count, vec_size_per_chunk);
@ -140,12 +154,13 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
auto sub_data = query::dataset::RawDataset{ auto sub_data = query::dataset::RawDataset{
element_begin, dim, size_per_chunk, chunk_data}; element_begin, dim, size_per_chunk, chunk_data};
if (milvus::exec::UseVectorIterator(info)) { if (milvus::exec::UseVectorIterator(info)) {
auto sub_qr = BruteForceSearchIterators(search_dataset, auto sub_qr =
sub_data, PackBruteForceSearchIteratorsIntoSubResult(search_dataset,
info, sub_data,
index_info, info,
bitset, index_info,
data_type); bitset,
data_type);
final_qr.merge(sub_qr); final_qr.merge(sub_qr);
} else { } else {
auto sub_qr = BruteForceSearch(search_dataset, auto sub_qr = BruteForceSearch(search_dataset,

View File

@ -11,6 +11,7 @@
#include "SearchOnIndex.h" #include "SearchOnIndex.h"
#include "exec/operator/Utils.h" #include "exec/operator/Utils.h"
#include "CachedSearchIterator.h"
namespace milvus::query { namespace milvus::query {
void void
@ -26,14 +27,23 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset,
auto dataset = auto dataset =
knowhere::GenDataSet(num_queries, dim, search_dataset.query_data); knowhere::GenDataSet(num_queries, dim, search_dataset.query_data);
dataset->SetIsSparse(is_sparse); dataset->SetIsSparse(is_sparse);
if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_conf, if (milvus::exec::PrepareVectorIteratorsFromIndex(search_conf,
num_queries, num_queries,
dataset, dataset,
search_result, search_result,
bitset, bitset,
indexing)) { indexing)) {
indexing.Query(dataset, search_conf, bitset, search_result); return;
} }
if (search_conf.iterator_v2_info_.has_value()) {
auto iter =
CachedSearchIterator(indexing, dataset, search_conf, bitset);
iter.NextBatch(search_conf, search_result);
return;
}
indexing.Query(dataset, search_conf, bitset, search_result);
} }
} // namespace milvus::query } // namespace milvus::query

View File

@ -18,6 +18,7 @@
#include "common/QueryInfo.h" #include "common/QueryInfo.h"
#include "common/Types.h" #include "common/Types.h"
#include "mmap/Column.h" #include "mmap/Column.h"
#include "query/CachedSearchIterator.h"
#include "query/SearchBruteForce.h" #include "query/SearchBruteForce.h"
#include "query/SearchOnSealed.h" #include "query/SearchOnSealed.h"
#include "query/helper.h" #include "query/helper.h"
@ -55,13 +56,20 @@ SearchOnSealedIndex(const Schema& schema,
dataset->SetIsSparse(is_sparse); dataset->SetIsSparse(is_sparse);
auto vec_index = auto vec_index =
dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get()); dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
if (search_info.iterator_v2_info_.has_value()) {
CachedSearchIterator cached_iter(
*vec_index, dataset, search_info, bitset);
cached_iter.NextBatch(search_info, search_result);
return;
}
if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_info, if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_info,
num_queries, num_queries,
dataset, dataset,
search_result, search_result,
bitset, bitset,
*vec_index)) { *vec_index)) {
auto index_type = vec_index->GetIndexType();
vec_index->Query(dataset, search_info, bitset, search_result); vec_index->Query(dataset, search_info, bitset, search_result);
float* distances = search_result.distances_.data(); float* distances = search_result.distances_.data();
auto total_num = num_queries * topK; auto total_num = num_queries * topK;
@ -104,6 +112,14 @@ SearchOnSealed(const Schema& schema,
auto data_type = field.get_data_type(); auto data_type = field.get_data_type();
CheckBruteForceSearchParam(field, search_info); CheckBruteForceSearchParam(field, search_info);
if (search_info.iterator_v2_info_.has_value()) {
CachedSearchIterator cached_iter(
column, query_dataset, search_info, index_info, bitview, data_type);
cached_iter.NextBatch(search_info, result);
return;
}
auto num_chunk = column->num_chunks(); auto num_chunk = column->num_chunks();
SubSearchResult final_qr(num_queries, SubSearchResult final_qr(num_queries,
@ -115,17 +131,16 @@ SearchOnSealed(const Schema& schema,
for (int i = 0; i < num_chunk; ++i) { for (int i = 0; i < num_chunk; ++i) {
auto vec_data = column->Data(i); auto vec_data = column->Data(i);
auto chunk_size = column->chunk_row_nums(i); auto chunk_size = column->chunk_row_nums(i);
const uint8_t* bitset_ptr = nullptr;
auto data_id = offset;
auto raw_dataset = auto raw_dataset =
query::dataset::RawDataset{offset, dim, chunk_size, vec_data}; query::dataset::RawDataset{offset, dim, chunk_size, vec_data};
if (milvus::exec::UseVectorIterator(search_info)) { if (milvus::exec::UseVectorIterator(search_info)) {
auto sub_qr = BruteForceSearchIterators(query_dataset, auto sub_qr =
raw_dataset, PackBruteForceSearchIteratorsIntoSubResult(query_dataset,
search_info, raw_dataset,
index_info, search_info,
bitview, index_info,
data_type); bitview,
data_type);
final_qr.merge(sub_qr); final_qr.merge(sub_qr);
} else { } else {
auto sub_qr = BruteForceSearch(query_dataset, auto sub_qr = BruteForceSearch(query_dataset,
@ -136,7 +151,6 @@ SearchOnSealed(const Schema& schema,
data_type); data_type);
final_qr.merge(sub_qr); final_qr.merge(sub_qr);
} }
offset += chunk_size; offset += chunk_size;
} }
if (milvus::exec::UseVectorIterator(search_info)) { if (milvus::exec::UseVectorIterator(search_info)) {
@ -181,14 +195,23 @@ SearchOnSealed(const Schema& schema,
CheckBruteForceSearchParam(field, search_info); CheckBruteForceSearchParam(field, search_info);
auto raw_dataset = query::dataset::RawDataset{0, dim, row_count, vec_data}; auto raw_dataset = query::dataset::RawDataset{0, dim, row_count, vec_data};
if (milvus::exec::UseVectorIterator(search_info)) { if (milvus::exec::UseVectorIterator(search_info)) {
auto sub_qr = BruteForceSearchIterators(query_dataset, auto sub_qr = PackBruteForceSearchIteratorsIntoSubResult(query_dataset,
raw_dataset, raw_dataset,
search_info, search_info,
index_info, index_info,
bitset, bitset,
data_type); data_type);
result.AssembleChunkVectorIterators( result.AssembleChunkVectorIterators(
num_queries, 1, {0}, sub_qr.chunk_iterators()); num_queries, 1, {0}, sub_qr.chunk_iterators());
} else if (search_info.iterator_v2_info_.has_value()) {
CachedSearchIterator cached_iter(query_dataset,
raw_dataset,
search_info,
index_info,
bitset,
data_type);
cached_iter.NextBatch(search_info, result);
return;
} else { } else {
auto sub_qr = BruteForceSearch(query_dataset, auto sub_qr = BruteForceSearch(query_dataset,
raw_dataset, raw_dataset,

View File

@ -89,6 +89,7 @@ set(MILVUS_TEST_FILES
test_chunked_segment.cpp test_chunked_segment.cpp
test_chunked_column.cpp test_chunked_column.cpp
test_rust_result.cpp test_rust_result.cpp
test_cached_search_iterator.cpp
) )
if ( INDEX_ENGINE STREQUAL "cardinal" ) if ( INDEX_ENGINE STREQUAL "cardinal" )

View File

@ -143,12 +143,13 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
AssertMatch(ref, ans); AssertMatch(ref, ans);
} }
auto result3 = BruteForceSearchIterators(query_dataset, auto result3 = PackBruteForceSearchIteratorsIntoSubResult(
raw_dataset, query_dataset,
search_info, raw_dataset,
index_info, search_info,
bitset_view, index_info,
DataType::VECTOR_SPARSE_FLOAT); bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
auto iterators = result3.chunk_iterators(); auto iterators = result3.chunk_iterators();
for (int i = 0; i < nq; i++) { for (int i = 0; i < nq; i++) {
auto it = iterators[i]; auto it = iterators[i];

View File

@ -0,0 +1,797 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include <memory>
#include <random>
#include <unordered_set>
#include "common/BitsetView.h"
#include "common/QueryInfo.h"
#include "common/QueryResult.h"
#include "common/Utils.h"
#include "index/Index.h"
#include "knowhere/comp/index_param.h"
#include "query/CachedSearchIterator.h"
#include "index/VectorIndex.h"
#include "index/IndexFactory.h"
#include "knowhere/dataset.h"
#include "query/helper.h"
#include "segcore/ConcurrentVector.h"
#include "segcore/InsertRecord.h"
#include "mmap/ChunkedColumn.h"
#include "test_utils/DataGen.h"
using namespace milvus;
using namespace milvus::query;
using namespace milvus::segcore;
using namespace milvus::index;
namespace {
// Fixture geometry: 1000 base vectors of dimension 16, a single query vector.
constexpr int64_t kDim = 16;
constexpr int64_t kNumVectors = 1000;
constexpr int64_t kNumQueries = 1;
// Default number of rows requested per NextBatch call.
constexpr int64_t kBatchSize = 100;
// Chunk size shared by the ConcurrentVector and ChunkedColumn fixtures.
constexpr size_t kSizePerChunk = 128;
// HNSW build / search parameters.
constexpr size_t kHnswM = 24;
constexpr size_t kHnswEfConstruction = 360;
constexpr size_t kHnswEf = 128;
// Metric used for schema/dataset setup; individual tests override via GetParam().
const MetricType kMetricType = knowhere::metric::L2;
}  // namespace
// Which CachedSearchIterator constructor overload a test instance exercises.
enum class ConstructorType {
    VectorIndex = 0,
    RawData,
    VectorBase,
    ChunkedColumn
};
// Parameter axes for the suite (combined in INSTANTIATE_TEST_SUITE_P).
static const std::vector<ConstructorType> kConstructorTypes = {
    ConstructorType::VectorIndex,
    ConstructorType::RawData,
    ConstructorType::VectorBase,
    ConstructorType::ChunkedColumn,
};
static const std::vector<MetricType> kMetricTypes = {
    knowhere::metric::L2,
    knowhere::metric::IP,
    knowhere::metric::COSINE,
};
// NOTE: this test suite must not run concurrently — its fixtures are shared static state
// Test suite for CachedSearchIterator, parameterized over
// (constructor flavor, metric type). SetUpTestSuite builds one shared set of
// fixtures — the raw float vectors, one HNSW index per metric, a growing
// ConcurrentVector and a sealed ChunkedColumn — and each test picks the
// flavor under test via DispatchIterator().
// NOTE: all fixtures are shared static state, so tests must not run concurrently.
class CachedSearchIteratorTest
    : public ::testing::TestWithParam<std::tuple<ConstructorType, MetricType>> {
 protected:
    // Default SearchInfo used by most tests: topk == batch size, the HNSW ef
    // search parameter, and an iterator-v2 block requesting kBatchSize rows.
    SearchInfo
    GetDefaultNormalSearchInfo() {
        return SearchInfo{
            .topk_ = kBatchSize,
            .round_decimal_ = -1,
            .metric_type_ = std::get<1>(GetParam()),
            .search_params_ =
                {
                    {knowhere::indexparam::EF, std::to_string(kHnswEf)},
                },
            .iterator_v2_info_ =
                SearchIteratorV2Info{
                    .batch_size = kBatchSize,
                },
        };
    }

    // Shared, suite-wide fixtures (filled once in SetUpTestSuite).
    static DataType data_type_;
    static int64_t dim_;
    static int64_t nb_;
    static int64_t nq_;
    static FixedVector<float> base_dataset_;
    static FixedVector<float> query_dataset_;
    static IndexBasePtr index_hnsw_l2_;
    static IndexBasePtr index_hnsw_ip_;
    static IndexBasePtr index_hnsw_cos_;
    static knowhere::DataSetPtr knowhere_query_dataset_;
    static dataset::SearchDataset search_dataset_;
    static std::unique_ptr<ConcurrentVector<milvus::FloatVector>> vector_base_;
    static std::shared_ptr<ChunkedColumn> column_;
    // Backing buffers for column_'s chunks; must outlive column_.
    static std::vector<std::vector<char>> column_data_;

    // Per-test state, selected in SetUp() from the metric-type parameter.
    IndexBase* index_hnsw_ = nullptr;
    MetricType metric_type_ = kMetricType;

    // Builds a CachedSearchIterator through the constructor overload selected
    // by `constructor_type`; returns nullptr for an unknown type.
    std::unique_ptr<CachedSearchIterator>
    DispatchIterator(const ConstructorType& constructor_type,
                     const SearchInfo& search_info,
                     const BitsetView& bitset) {
        switch (constructor_type) {
            case ConstructorType::VectorIndex:
                return std::make_unique<CachedSearchIterator>(
                    dynamic_cast<const VectorIndex&>(*index_hnsw_),
                    knowhere_query_dataset_,
                    search_info,
                    bitset);
            case ConstructorType::RawData:
                return std::make_unique<CachedSearchIterator>(
                    search_dataset_,
                    dataset::RawDataset{0, dim_, nb_, base_dataset_.data()},
                    search_info,
                    std::map<std::string, std::string>{},
                    bitset,
                    data_type_);
            case ConstructorType::VectorBase:
                return std::make_unique<CachedSearchIterator>(
                    search_dataset_,
                    vector_base_.get(),
                    nb_,
                    search_info,
                    std::map<std::string, std::string>{},
                    bitset,
                    data_type_);
            case ConstructorType::ChunkedColumn:
                return std::make_unique<CachedSearchIterator>(
                    column_,
                    search_dataset_,
                    search_info,
                    std::map<std::string, std::string>{},
                    bitset,
                    data_type_);
            default:
                return nullptr;
        }
    }

    // use last distance of the first batch as range_filter
    // use first distance of the last batch as radius
    // Throws if the iterator never produced a valid hit for either value.
    std::pair<float, float>
    GetRadiusAndRangeFilter() {
        const size_t num_rnds = (nb_ + kBatchSize - 1) / kBatchSize;
        SearchResult search_result;
        float radius = 0.0f;
        float range_filter = 0.0f;
        bool get_radius_success = false;
        bool get_range_filter_success = false;
        SearchInfo search_info = GetDefaultNormalSearchInfo();
        auto iterator =
            DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
        for (size_t rnd = 0; rnd < num_rnds; ++rnd) {
            iterator->NextBatch(search_info, search_result);
            if (rnd == 0) {
                // Scan backwards for the last valid hit of the first batch.
                // The index must be signed: with `size_t i` the condition
                // `i >= 0` is always true and `--i` would wrap past zero,
                // reading out of bounds.
                for (int64_t i = kBatchSize - 1; i >= 0; --i) {
                    if (search_result.seg_offsets_[i] != -1) {
                        range_filter = search_result.distances_[i];
                        get_range_filter_success = true;
                        break;
                    }
                }
            } else {
                // Later rounds keep overwriting `radius`, so after the loop
                // it holds the first valid distance of the *last* batch.
                for (size_t i = 0; i < kBatchSize; ++i) {
                    if (search_result.seg_offsets_[i] != -1) {
                        radius = search_result.distances_[i];
                        get_radius_success = true;
                        break;
                    }
                }
            }
        }
        if (!get_radius_success || !get_range_filter_success) {
            throw std::runtime_error("Failed to get radius and range filter");
        }
        return {radius, range_filter};
    }

    // Builds one HNSW index per metric type over the shared base dataset.
    static void
    BuildIndex() {
        auto dataset = knowhere::GenDataSet(nb_, dim_, base_dataset_.data());
        for (const auto& metric_type : kMetricTypes) {
            milvus::index::CreateIndexInfo create_index_info;
            create_index_info.field_type = data_type_;
            create_index_info.metric_type = metric_type;
            create_index_info.index_engine_version =
                knowhere::Version::GetCurrentVersion().VersionNumber();
            // NOTE(review): build_conf always carries L2 as METRIC_TYPE even
            // for the IP/COSINE indexes, while create_index_info carries the
            // real metric — confirm this is intentional.
            auto build_conf = knowhere::Json{
                {knowhere::meta::METRIC_TYPE, knowhere::metric::L2},
                {knowhere::meta::DIM, std::to_string(dim_)},
                {knowhere::indexparam::M, std::to_string(kHnswM)},
                {knowhere::indexparam::EFCONSTRUCTION,
                 std::to_string(kHnswEfConstruction)}};
            create_index_info.index_type = knowhere::IndexEnum::INDEX_HNSW;
            if (metric_type == knowhere::metric::L2) {
                index_hnsw_l2_ =
                    milvus::index::IndexFactory::GetInstance().CreateIndex(
                        create_index_info,
                        milvus::storage::FileManagerContext());
                index_hnsw_l2_->BuildWithDataset(dataset, build_conf);
                ASSERT_EQ(index_hnsw_l2_->Count(), nb_);
            } else if (metric_type == knowhere::metric::IP) {
                index_hnsw_ip_ =
                    milvus::index::IndexFactory::GetInstance().CreateIndex(
                        create_index_info,
                        milvus::storage::FileManagerContext());
                index_hnsw_ip_->BuildWithDataset(dataset, build_conf);
                ASSERT_EQ(index_hnsw_ip_->Count(), nb_);
            } else if (metric_type == knowhere::metric::COSINE) {
                index_hnsw_cos_ =
                    milvus::index::IndexFactory::GetInstance().CreateIndex(
                        create_index_info,
                        milvus::storage::FileManagerContext());
                index_hnsw_cos_->BuildWithDataset(dataset, build_conf);
                ASSERT_EQ(index_hnsw_cos_->Count(), nb_);
            } else {
                FAIL() << "Unsupported metric type: " << metric_type;
            }
        }
    }

    // Fills the growing-segment style ConcurrentVector fixture.
    static void
    SetUpVectorBase() {
        vector_base_ = std::make_unique<ConcurrentVector<milvus::FloatVector>>(
            dim_, kSizePerChunk);
        vector_base_->set_data_raw(0, base_dataset_.data(), nb_);
        ASSERT_EQ(vector_base_->num_chunk(),
                  (nb_ + kSizePerChunk - 1) / kSizePerChunk);
    }

    // Fills the sealed-segment style ChunkedColumn fixture. Each chunk buffer
    // begins with a validity bitset followed by the row data.
    static void
    SetUpChunkedColumn() {
        // make_shared: single allocation for control block + object
        column_ = std::make_shared<ChunkedColumn>();
        const size_t num_chunks_ = (nb_ + kSizePerChunk - 1) / kSizePerChunk;
        column_data_.resize(num_chunks_);
        size_t offset = 0;
        for (size_t i = 0; i < num_chunks_; ++i) {
            const size_t rows = std::min(nb_ - offset, kSizePerChunk);
            const size_t chunk_bitset_size = (rows + 7) / 8;
            const size_t buf_size =
                chunk_bitset_size + rows * dim_ * sizeof(float);
            auto& chunk_data = column_data_[i];
            chunk_data.resize(buf_size);
            memcpy(chunk_data.data() + chunk_bitset_size,
                   base_dataset_.cbegin() + offset * dim_,
                   rows * dim_ * sizeof(float));
            column_->AddChunk(std::make_shared<FixedWidthChunk>(
                rows, dim_, chunk_data.data(), buf_size, sizeof(float), false));
            offset += rows;
        }
    }

    // One-time fixture construction: dataset generation, index builds, and
    // the growing/sealed fixtures.
    static void
    SetUpTestSuite() {
        auto schema = std::make_shared<Schema>();
        auto fakevec_id = schema->AddDebugField(
            "fakevec", DataType::VECTOR_FLOAT, dim_, kMetricType);

        // generate base dataset
        base_dataset_ =
            segcore::DataGen(schema, nb_).get_col<float>(fakevec_id);

        // generate query dataset — a prefix of the base data, so for L2 the
        // best match of query i is row i itself, at distance 0
        query_dataset_ = {base_dataset_.cbegin(),
                          base_dataset_.cbegin() + nq_ * dim_};
        knowhere_query_dataset_ =
            knowhere::GenDataSet(nq_, dim_, query_dataset_.data());
        search_dataset_ = dataset::SearchDataset{
            .metric_type = kMetricType,
            .num_queries = nq_,
            .topk = kBatchSize,
            .round_decimal = -1,
            .dim = dim_,
            .query_data = query_dataset_.data(),
        };

        BuildIndex();
        SetUpVectorBase();
        SetUpChunkedColumn();
    }

    static void
    TearDownTestSuite() {
        base_dataset_.clear();
        query_dataset_.clear();
        index_hnsw_l2_.reset();
        index_hnsw_ip_.reset();
        index_hnsw_cos_.reset();
        knowhere_query_dataset_.reset();
        vector_base_.reset();
        column_.reset();
        // release the chunk buffers only after column_ is gone — the chunks
        // reference this memory
        column_data_.clear();
    }

    // Selects the per-test index / metric from the test parameter.
    void
    SetUp() override {
        auto metric_type = std::get<1>(GetParam());
        if (metric_type == knowhere::metric::L2) {
            metric_type_ = knowhere::metric::L2;
            search_dataset_.metric_type = knowhere::metric::L2;
            index_hnsw_ = index_hnsw_l2_.get();
        } else if (metric_type == knowhere::metric::IP) {
            metric_type_ = knowhere::metric::IP;
            search_dataset_.metric_type = knowhere::metric::IP;
            index_hnsw_ = index_hnsw_ip_.get();
        } else if (metric_type == knowhere::metric::COSINE) {
            metric_type_ = knowhere::metric::COSINE;
            search_dataset_.metric_type = knowhere::metric::COSINE;
            index_hnsw_ = index_hnsw_cos_.get();
        } else {
            FAIL() << "Unsupported metric type: " << metric_type;
        }
    }

    void
    TearDown() override {
    }
};
// initialize static variables
// (shared by every test in the suite; SetUpTestSuite fills them once)
DataType CachedSearchIteratorTest::data_type_ = DataType::VECTOR_FLOAT;
int64_t CachedSearchIteratorTest::dim_ = kDim;
int64_t CachedSearchIteratorTest::nb_ = kNumVectors;
int64_t CachedSearchIteratorTest::nq_ = kNumQueries;
IndexBasePtr CachedSearchIteratorTest::index_hnsw_l2_ = nullptr;
IndexBasePtr CachedSearchIteratorTest::index_hnsw_ip_ = nullptr;
IndexBasePtr CachedSearchIteratorTest::index_hnsw_cos_ = nullptr;
knowhere::DataSetPtr CachedSearchIteratorTest::knowhere_query_dataset_ =
    nullptr;
dataset::SearchDataset CachedSearchIteratorTest::search_dataset_;
FixedVector<float> CachedSearchIteratorTest::base_dataset_;
FixedVector<float> CachedSearchIteratorTest::query_dataset_;
std::unique_ptr<ConcurrentVector<milvus::FloatVector>>
    CachedSearchIteratorTest::vector_base_ = nullptr;
std::shared_ptr<ChunkedColumn> CachedSearchIteratorTest::column_ = nullptr;
std::vector<std::vector<char>> CachedSearchIteratorTest::column_data_;
/********* Testcases Start **********/
TEST_P(CachedSearchIteratorTest, NextBatchNormal) {
    // Exercise a single NextBatch call across a spread of batch sizes and
    // verify that offsets are unique, the L2 self-query hits distance 0, and
    // the result metadata matches the requested batch size.
    SearchInfo search_info = GetDefaultNormalSearchInfo();
    const std::vector<size_t> batch_sizes = {1, 7, 43, 99, 100, 101, 1000, 1005};
    for (const size_t bs : batch_sizes) {
        std::cout << "batch_size: " << bs << std::endl;
        search_info.iterator_v2_info_->batch_size = bs;
        auto iter =
            DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
        SearchResult res;
        iter->NextBatch(search_info, res);

        for (size_t qi = 0; qi < nq_; ++qi) {
            std::unordered_set<int64_t> unique_offsets;
            size_t valid = 0;
            for (size_t k = 0; k < bs; ++k) {
                const auto offset = res.seg_offsets_[qi * bs + k];
                if (offset == -1) {
                    break;
                }
                ++valid;
                unique_offsets.insert(offset);
            }
            // no duplicate offsets within a batch
            EXPECT_EQ(unique_offsets.size(), valid);
            if (metric_type_ == knowhere::metric::L2) {
                EXPECT_EQ(res.distances_[qi * bs], 0);
            }
        }
        EXPECT_EQ(res.unity_topK_, bs);
        EXPECT_EQ(res.total_nq_, nq_);
        EXPECT_EQ(res.seg_offsets_.size(), nq_ * bs);
        EXPECT_EQ(res.distances_.size(), nq_ * bs);
    }
}
TEST_P(CachedSearchIteratorTest, NextBatchDistBound) {
    // After seeding last_bound from the first batch, every later batch must
    // only contain distances strictly beyond that bound.
    SearchInfo search_info = GetDefaultNormalSearchInfo();
    const size_t batch_size = kBatchSize;
    // Tighten the bound relative to metric direction: shrink it for
    // similarity metrics (IP/COSINE, larger-is-better), grow it for L2.
    const float dist_bound_factor = PositivelyRelated(metric_type_) ? 0.5 : 1.5;
    float dist_bound = 0;
    {
        auto iterator =
            DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
        SearchResult search_result;
        iterator->NextBatch(search_info, search_result);

        bool found_dist_bound = false;
        // use the last distance of the first query * factor as the dist bound
        // NOTE: the index must be signed — with `size_t j` the condition
        // `j >= 0` is always true and `--j` would wrap past zero, reading
        // out of bounds.
        for (int64_t j = static_cast<int64_t>(batch_size) - 1; j >= 0; --j) {
            if (search_result.seg_offsets_[j] != -1) {
                dist_bound = search_result.distances_[j] * dist_bound_factor;
                found_dist_bound = true;
                break;
            }
        }
        ASSERT_TRUE(found_dist_bound);

        search_info.iterator_v2_info_->last_bound = dist_bound;
        for (size_t rnd = 1; rnd < (nb_ + batch_size - 1) / batch_size; ++rnd) {
            iterator->NextBatch(search_info, search_result);
            for (size_t i = 0; i < nq_; ++i) {
                for (size_t j = 0; j < batch_size; ++j) {
                    if (search_result.seg_offsets_[i * batch_size + j] == -1) {
                        break;
                    }
                    if (PositivelyRelated(metric_type_)) {
                        EXPECT_LT(search_result.distances_[i * batch_size + j],
                                  dist_bound);
                    } else {
                        EXPECT_GT(search_result.distances_[i * batch_size + j],
                                  dist_bound);
                    }
                }
            }
        }
    }
}
TEST_P(CachedSearchIteratorTest, NextBatchDistBoundEmptyResults) {
    // With the tightest possible last_bound no candidate can qualify, so
    // every batch must come back empty.
    SearchInfo search_info = GetDefaultNormalSearchInfo();
    const float impossible_bound = PositivelyRelated(metric_type_)
                                       ? -std::numeric_limits<float>::max()
                                       : std::numeric_limits<float>::max();
    auto iter = DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
    search_info.iterator_v2_info_->last_bound = impossible_bound;

    const size_t rounds = (nb_ + kBatchSize - 1) / kBatchSize;
    size_t hits = 0;
    SearchResult res;
    for (size_t rnd = 0; rnd < rounds; ++rnd) {
        iter->NextBatch(search_info, res);
        for (size_t qi = 0; qi < nq_; ++qi) {
            for (size_t k = 0; k < kBatchSize; ++k) {
                if (res.seg_offsets_[qi * kBatchSize + k] == -1) {
                    break;
                }
                ++hits;
            }
        }
    }
    EXPECT_EQ(hits, 0);
}
TEST_P(CachedSearchIteratorTest, NextBatchRangeSearchRadius) {
    // Every returned distance must lie strictly on the "better" side of the
    // configured range-search radius.
    const auto [radius, range_filter] = GetRadiusAndRangeFilter();
    SearchInfo search_info = GetDefaultNormalSearchInfo();
    search_info.search_params_[knowhere::meta::RADIUS] = radius;
    auto iter = DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);

    SearchResult res;
    const size_t rounds = (nb_ + kBatchSize - 1) / kBatchSize;
    for (size_t rnd = 0; rnd < rounds; ++rnd) {
        iter->NextBatch(search_info, res);
        for (size_t qi = 0; qi < nq_; ++qi) {
            for (size_t k = 0; k < kBatchSize; ++k) {
                if (res.seg_offsets_[qi * kBatchSize + k] == -1) {
                    break;
                }
                const float dist = res.distances_[qi * kBatchSize + k];
                if (PositivelyRelated(metric_type_)) {
                    ASSERT_GT(dist, radius);
                } else {
                    ASSERT_LT(dist, radius);
                }
            }
        }
    }
}
TEST_P(CachedSearchIteratorTest, NextBatchRangeSearchRadiusAndRangeFilter) {
    // With both radius and range_filter configured, every returned distance
    // must fall inside the interval bounded by them (direction depends on
    // whether the metric is similarity- or distance-like).
    const auto [radius, range_filter] = GetRadiusAndRangeFilter();
    SearchInfo search_info = GetDefaultNormalSearchInfo();
    search_info.search_params_[knowhere::meta::RADIUS] = radius;
    search_info.search_params_[knowhere::meta::RANGE_FILTER] = range_filter;
    auto iter = DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);

    SearchResult res;
    const size_t rounds = (nb_ + kBatchSize - 1) / kBatchSize;
    for (size_t rnd = 0; rnd < rounds; ++rnd) {
        iter->NextBatch(search_info, res);
        for (size_t qi = 0; qi < nq_; ++qi) {
            for (size_t k = 0; k < kBatchSize; ++k) {
                if (res.seg_offsets_[qi * kBatchSize + k] == -1) {
                    break;
                }
                const float dist = res.distances_[qi * kBatchSize + k];
                if (PositivelyRelated(metric_type_)) {
                    ASSERT_GT(dist, radius);
                    ASSERT_LE(dist, range_filter);
                } else {
                    ASSERT_LT(dist, radius);
                    ASSERT_GE(dist, range_filter);
                }
            }
        }
    }
}
TEST_P(CachedSearchIteratorTest,
       NextBatchRangeSearchLastBoundRadiusRangeFilter) {
    // Combine last_bound with radius + range_filter: results must satisfy all
    // three constraints simultaneously, for several last_bound placements
    // around and between the two range limits.
    const auto [radius, range_filter] = GetRadiusAndRangeFilter();
    const float diff = (radius + range_filter) / 2;
    const std::vector<float> last_bounds = {radius - diff,
                                            radius,
                                            radius + diff,
                                            range_filter,
                                            range_filter + diff};
    SearchInfo search_info = GetDefaultNormalSearchInfo();
    search_info.search_params_[knowhere::meta::RADIUS] = radius;
    search_info.search_params_[knowhere::meta::RANGE_FILTER] = range_filter;

    const size_t rounds = (nb_ + kBatchSize - 1) / kBatchSize;
    SearchResult res;
    for (const float last_bound : last_bounds) {
        search_info.iterator_v2_info_->last_bound = last_bound;
        auto iter =
            DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
        for (size_t rnd = 0; rnd < rounds; ++rnd) {
            iter->NextBatch(search_info, res);
            for (size_t qi = 0; qi < nq_; ++qi) {
                for (size_t k = 0; k < kBatchSize; ++k) {
                    if (res.seg_offsets_[qi * kBatchSize + k] == -1) {
                        break;
                    }
                    const float dist = res.distances_[qi * kBatchSize + k];
                    if (PositivelyRelated(metric_type_)) {
                        ASSERT_LE(dist, last_bound);
                        ASSERT_GT(dist, radius);
                        ASSERT_LE(dist, range_filter);
                    } else {
                        ASSERT_GT(dist, last_bound);
                        ASSERT_LT(dist, radius);
                        ASSERT_GE(dist, range_filter);
                    }
                }
            }
        }
    }
}
TEST_P(CachedSearchIteratorTest, NextBatchZeroBatchSize) {
    // A batch size of zero must be rejected at NextBatch time.
    SearchInfo search_info = GetDefaultNormalSearchInfo();
    auto iter = DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
    search_info.iterator_v2_info_->batch_size = 0;
    SearchResult res;
    EXPECT_THROW(iter->NextBatch(search_info, res), SegcoreError);
}
TEST_P(CachedSearchIteratorTest, NextBatchDiffBatchSizeComparedToInit) {
    // Changing the batch size after construction must be rejected.
    SearchInfo search_info = GetDefaultNormalSearchInfo();
    auto iter = DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
    search_info.iterator_v2_info_->batch_size = kBatchSize + 1;
    SearchResult res;
    EXPECT_THROW(iter->NextBatch(search_info, res), SegcoreError);
}
TEST_P(CachedSearchIteratorTest, NextBatchEmptySearchInfo) {
    // Calling NextBatch with a default-constructed SearchInfo must throw.
    auto iter = DispatchIterator(
        std::get<0>(GetParam()), GetDefaultNormalSearchInfo(), nullptr);
    SearchResult res;
    const SearchInfo empty_info{};
    EXPECT_THROW(iter->NextBatch(empty_info, res), SegcoreError);
}
TEST_P(CachedSearchIteratorTest, NextBatchEmptyIteratorV2Info) {
    // Dropping iterator_v2_info_ after construction must be rejected.
    SearchInfo search_info = GetDefaultNormalSearchInfo();
    auto iter = DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
    search_info.iterator_v2_info_.reset();
    SearchResult res;
    EXPECT_THROW(iter->NextBatch(search_info, res), SegcoreError);
}
TEST_P(CachedSearchIteratorTest, NextBatchtAllBatchesNormal) {
    // Drain the iterator completely for several batch sizes and verify that
    // per-batch offsets are unique and that (almost) all rows are eventually
    // returned for every query.
    SearchInfo search_info = GetDefaultNormalSearchInfo();
    const std::vector<size_t> batch_sizes = {1, 7, 43, 99, 100, 101, 1000, 1005};
    for (const size_t bs : batch_sizes) {
        search_info.iterator_v2_info_->batch_size = bs;
        auto iter =
            DispatchIterator(std::get<0>(GetParam()), search_info, nullptr);
        size_t total = 0;
        const size_t rounds = (nb_ + bs - 1) / bs;
        for (size_t rnd = 0; rnd < rounds; ++rnd) {
            SearchResult res;
            iter->NextBatch(search_info, res);
            for (size_t qi = 0; qi < nq_; ++qi) {
                std::unordered_set<int64_t> seen;
                size_t valid = 0;
                for (size_t k = 0; k < bs; ++k) {
                    const auto offset = res.seg_offsets_[qi * bs + k];
                    if (offset == -1) {
                        break;
                    }
                    ++valid;
                    seen.insert(offset);
                }
                total += valid;
                // check no duplicate
                EXPECT_EQ(seen.size(), valid);
                // only check if the first distance of the first batch is 0
                if (rnd == 0 && metric_type_ == knowhere::metric::L2) {
                    EXPECT_EQ(res.distances_[qi * bs], 0);
                }
            }
            EXPECT_EQ(res.unity_topK_, bs);
            EXPECT_EQ(res.total_nq_, nq_);
            EXPECT_EQ(res.seg_offsets_.size(), nq_ * bs);
            EXPECT_EQ(res.distances_.size(), nq_ * bs);
        }
        if (std::get<0>(GetParam()) == ConstructorType::VectorIndex) {
            // HNSW is approximate — tolerate up to 10% recall loss
            EXPECT_GE(total, nb_ * nq_ * 0.9);
        } else {
            EXPECT_EQ(total, nb_ * nq_);
        }
    }
}
TEST_P(CachedSearchIteratorTest, ConstructorWithInvalidSearchInfo) {
    // Every SearchInfo below misses something the iterator requires (metric
    // type, iterator-v2 info, or a non-zero batch size) and must be rejected
    // at construction time.
    const std::vector<SearchInfo> bad_infos = {
        SearchInfo{},
        SearchInfo{.metric_type_ = ""},
        SearchInfo{.metric_type_ = metric_type_},
        SearchInfo{.metric_type_ = metric_type_, .iterator_v2_info_ = {}},
        SearchInfo{.metric_type_ = metric_type_,
                   .iterator_v2_info_ = SearchIteratorV2Info{.batch_size = 0}},
    };
    for (const auto& bad_info : bad_infos) {
        EXPECT_THROW(
            DispatchIterator(std::get<0>(GetParam()), bad_info, nullptr),
            SegcoreError);
    }
}
TEST_P(CachedSearchIteratorTest, ConstructorWithInvalidParams) {
    // Each constructor overload must reject null / empty data inputs even
    // when the SearchInfo itself is valid.
    SearchInfo search_info = GetDefaultNormalSearchInfo();
    switch (std::get<0>(GetParam())) {
        case ConstructorType::VectorIndex:
            // null dataset, then an empty knowhere dataset
            EXPECT_THROW(auto it = std::make_unique<CachedSearchIterator>(
                             dynamic_cast<const VectorIndex&>(*index_hnsw_),
                             nullptr,
                             search_info,
                             nullptr),
                         SegcoreError);
            EXPECT_THROW(auto it = std::make_unique<CachedSearchIterator>(
                             dynamic_cast<const VectorIndex&>(*index_hnsw_),
                             std::make_shared<knowhere::DataSet>(),
                             search_info,
                             nullptr),
                         SegcoreError);
            break;
        case ConstructorType::RawData:
            // empty search dataset
            EXPECT_THROW(
                auto it = std::make_unique<CachedSearchIterator>(
                    dataset::SearchDataset{},
                    dataset::RawDataset{0, dim_, nb_, base_dataset_.data()},
                    search_info,
                    std::map<std::string, std::string>{},
                    nullptr,
                    data_type_),
                SegcoreError);
            break;
        case ConstructorType::VectorBase:
            // empty search dataset, null vector base, zero row count
            EXPECT_THROW(auto it = std::make_unique<CachedSearchIterator>(
                             dataset::SearchDataset{},
                             vector_base_.get(),
                             nb_,
                             search_info,
                             std::map<std::string, std::string>{},
                             nullptr,
                             data_type_),
                         SegcoreError);
            EXPECT_THROW(auto it = std::make_unique<CachedSearchIterator>(
                             search_dataset_,
                             nullptr,
                             nb_,
                             search_info,
                             std::map<std::string, std::string>{},
                             nullptr,
                             data_type_),
                         SegcoreError);
            EXPECT_THROW(auto it = std::make_unique<CachedSearchIterator>(
                             search_dataset_,
                             vector_base_.get(),
                             0,
                             search_info,
                             std::map<std::string, std::string>{},
                             nullptr,
                             data_type_),
                         SegcoreError);
            break;
        case ConstructorType::ChunkedColumn:
            // null column, then an empty search dataset
            EXPECT_THROW(auto it = std::make_unique<CachedSearchIterator>(
                             nullptr,
                             search_dataset_,
                             search_info,
                             std::map<std::string, std::string>{},
                             nullptr,
                             data_type_),
                         SegcoreError);
            EXPECT_THROW(auto it = std::make_unique<CachedSearchIterator>(
                             column_,
                             dataset::SearchDataset{},
                             search_info,
                             std::map<std::string, std::string>{},
                             nullptr,
                             data_type_),
                         SegcoreError);
            break;
        default:
            break;
    }
}
/********* Testcases End **********/
INSTANTIATE_TEST_SUITE_P(
    CachedSearchIteratorTests,
    CachedSearchIteratorTest,
    ::testing::Combine(::testing::ValuesIn(kConstructorTypes),
                       ::testing::ValuesIn(kMetricTypes)),
    // Human-readable test names: "<ConstructorType>_<Metric>".
    [](const testing::TestParamInfo<std::tuple<ConstructorType, MetricType>>&
           info) {
        const auto& [constructor_type, metric_type] = info.param;
        std::string name;
        switch (constructor_type) {
            case ConstructorType::VectorIndex:
                name = "VectorIndex";
                break;
            case ConstructorType::RawData:
                name = "RawData";
                break;
            case ConstructorType::VectorBase:
                name = "VectorBase";
                break;
            case ConstructorType::ChunkedColumn:
                name = "ChunkedColumn";
                break;
            default:
                name = "Unknown constructor type";
        }
        if (metric_type == knowhere::metric::L2) {
            name += "_L2";
        } else if (metric_type == knowhere::metric::IP) {
            name += "_IP";
        } else if (metric_type == knowhere::metric::COSINE) {
            name += "_COSINE";
        } else {
            name += "_Unknown";
        }
        return name;
    });

View File

@ -55,6 +55,12 @@ message Array {
schema.DataType element_type = 3; schema.DataType element_type = 3;
} }
// Session state for search iterator v2, carried between successive search
// calls of one iteration.
message SearchIteratorV2Info {
  // Opaque session token; the proxy generates a UUID when the client does
  // not supply one, and validates supplied tokens as UUIDs.
  string token = 1;
  // Number of rows returned per batch; required, validated like topk.
  uint32 batch_size = 2;
  // Distance bound reached by the previous batch, used to resume iteration;
  // absent on the first call.
  optional float last_bound = 3;
}
message QueryInfo { message QueryInfo {
int64 topk = 1; int64 topk = 1;
string metric_type = 3; string metric_type = 3;
@ -67,6 +73,7 @@ message QueryInfo {
double bm25_avgdl = 10; double bm25_avgdl = 10;
int64 query_field_id =11; int64 query_field_id =11;
string hints = 12; string hints = 12;
optional SearchIteratorV2Info search_iterator_v2_info = 13;
} }
message ColumnInfo { message ColumnInfo {

View File

@ -26,6 +26,7 @@ import (
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/google/uuid"
"github.com/hashicorp/golang-lru/v2/expirable" "github.com/hashicorp/golang-lru/v2/expirable"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/atomic" "go.uber.org/atomic"
@ -304,6 +305,13 @@ func (node *Proxy) Init() error {
node.enableMaterializedView = Params.CommonCfg.EnableMaterializedView.GetAsBool() node.enableMaterializedView = Params.CommonCfg.EnableMaterializedView.GetAsBool()
// Enable internal rand pool for UUIDv4 generation
// This is NOT thread-safe and should only be called before the service starts and
// there is no possibility that New or any other UUID V4 generation function will be called concurrently
// Only proxy generates UUID for now, and one Milvus process only has one proxy
uuid.EnableRandPool()
log.Debug("enable rand pool for UUIDv4 generation")
log.Info("init proxy done", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("Address", node.address)) log.Info("init proxy done", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("Address", node.address))
return nil return nil
} }

View File

@ -8,6 +8,7 @@ import (
"strings" "strings"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/google/uuid"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -82,6 +83,81 @@ type SearchInfo struct {
isIterator bool isIterator bool
} }
// parseSearchIteratorV2Info extracts the iterator-v2 parameters from the
// search request's key/value pairs. It returns (nil, nil) when the request is
// not a v2-iterator search, and an error when the v2 parameters are invalid
// or conflict with other options (groupBy, offset). On success it also
// rewrites *queryTopK to the batch size for compatibility with topk handling.
func parseSearchIteratorV2Info(searchParamsPair []*commonpb.KeyValuePair, groupByFieldId int64, isIterator bool, offset int64, queryTopK *int64) (*planpb.SearchIteratorV2Info, error) {
	// Not a v2 iterator request: nothing to parse.
	v2FlagStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterV2Key, searchParamsPair)
	if enabled, _ := strconv.ParseBool(v2FlagStr); !enabled {
		return nil, nil
	}

	// iteratorV1 and iteratorV2 should be set together for compatibility
	if !isIterator {
		return nil, fmt.Errorf("both %s and %s must be set in the SDK", IteratorField, SearchIterV2Key)
	}

	// groupBy is disabled while iterating, matching v1 behavior
	if groupByFieldId > 0 {
		return nil, merr.WrapErrParameterInvalid("", "",
			"GroupBy is not permitted when using a search iterator")
	}

	// offsets are meaningless for a stateful iterator
	if offset > 0 {
		return nil, merr.WrapErrParameterInvalid("", "",
			"Setting an offset is not permitted when using a search iterator v2")
	}

	// reuse the caller-supplied token or mint a fresh UUID for this session
	token, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterIdKey, searchParamsPair)
	if token == "" {
		generated, err := uuid.NewRandom()
		if err != nil {
			return nil, err
		}
		token = generated.String()
	} else if _, err := uuid.Parse(token); err != nil {
		// an existing token must be a valid UUID
		return nil, fmt.Errorf("invalid token format")
	}

	// batch size is mandatory and validated with the same rules as topk
	batchSizeStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterBatchSizeKey, searchParamsPair)
	if batchSizeStr == "" {
		return nil, fmt.Errorf("batch size is required")
	}
	batchSize, err := strconv.ParseInt(batchSizeStr, 0, 64)
	if err != nil {
		return nil, fmt.Errorf("batch size is invalid, %w", err)
	}
	if err := validateLimit(batchSize); err != nil {
		return nil, fmt.Errorf("batch size is invalid, %w", err)
	}
	*queryTopK = batchSize // for compatibility

	info := &planpb.SearchIteratorV2Info{
		Token:     token,
		BatchSize: uint32(batchSize),
	}

	// forward the optional last bound when the client supplied one
	if lastBoundStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterLastBoundKey, searchParamsPair); lastBoundStr != "" {
		parsed, err := strconv.ParseFloat(lastBoundStr, 32)
		if err != nil {
			return nil, fmt.Errorf("failed to parse input last bound, %w", err)
		}
		bound := float32(parsed)
		info.LastBound = &bound // escape pointer
	}
	return info, nil
}
// parseSearchInfo returns QueryInfo and offset // parseSearchInfo returns QueryInfo and offset
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) *SearchInfo { func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) *SearchInfo {
var topK int64 var topK int64
@ -196,16 +272,22 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
"Not allowed to do range-search when doing search-group-by")} "Not allowed to do range-search when doing search-group-by")}
} }
planSearchIteratorV2Info, err := parseSearchIteratorV2Info(searchParamsPair, groupByFieldId, isIterator, offset, &queryTopK)
if err != nil {
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("parse iterator v2 info failed: %w", err)}
}
return &SearchInfo{ return &SearchInfo{
planInfo: &planpb.QueryInfo{ planInfo: &planpb.QueryInfo{
Topk: queryTopK, Topk: queryTopK,
MetricType: metricType, MetricType: metricType,
SearchParams: searchParamStr, SearchParams: searchParamStr,
RoundDecimal: roundDecimal, RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId, GroupByFieldId: groupByFieldId,
GroupSize: groupSize, GroupSize: groupSize,
StrictGroupSize: strictGroupSize, StrictGroupSize: strictGroupSize,
Hints: hints, Hints: hints,
SearchIteratorV2Info: planSearchIteratorV2Info,
}, },
offset: offset, offset: offset,
isIterator: isIterator, isIterator: isIterator,

View File

@ -69,6 +69,11 @@ const (
OffsetKey = "offset" OffsetKey = "offset"
LimitKey = "limit" LimitKey = "limit"
SearchIterV2Key = "search_iter_v2"
SearchIterBatchSizeKey = "search_iter_batch_size"
SearchIterLastBoundKey = "search_iter_last_bound"
SearchIterIdKey = "search_iter_id"
InsertTaskName = "InsertTask" InsertTaskName = "InsertTask"
CreateCollectionTaskName = "CreateCollectionTask" CreateCollectionTaskName = "CreateCollectionTask"
DropCollectionTaskName = "DropCollectionTask" DropCollectionTaskName = "DropCollectionTask"

View File

@ -28,6 +28,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/timerecord"
"github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/tsoutil"
@ -590,12 +591,15 @@ func (t *searchTask) Execute(ctx context.Context) error {
return nil return nil
} }
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) { func getMetricType(toReduceResults []*internalpb.SearchResults) string {
metricType := "" metricType := ""
if len(toReduceResults) >= 1 { if len(toReduceResults) >= 1 {
metricType = toReduceResults[0].GetMetricType() metricType = toReduceResults[0].GetMetricType()
} }
return metricType
}
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, metricType string, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults") ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "reduceResults")
defer sp.End() defer sp.End()
@ -631,6 +635,24 @@ func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*inter
return result, nil return result, nil
} }
// getLastBound finds the last bound to report back to the client for search
// iterator v2, based on the reduced results and the metric type.
// Only nq == 1 is supported (the non-trivial branch requires NumQueries == 1).
func getLastBound(result *milvuspb.SearchResults, incomingLastBound *float32, metricType string) float32 {
	// note: use a descriptive name instead of shadowing the builtin `len`
	numScores := len(result.Results.Scores)
	if numScores > 0 && result.GetResults().GetNumQueries() == 1 {
		// the last score of this page becomes the bound for the next page
		return result.Results.Scores[numScores-1]
	}
	// if no results found and incoming last bound is not nil, echo it back
	if incomingLastBound != nil {
		return *incomingLastBound
	}
	// if no results found and it is the first call, return the closest bound:
	// +MaxFloat32 for similarity metrics (higher is better), -MaxFloat32 otherwise
	if metric.PositivelyRelated(metricType) {
		return math.MaxFloat32
	}
	return -math.MaxFloat32
}
func (t *searchTask) PostExecute(ctx context.Context) error { func (t *searchTask) PostExecute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-PostExecute") ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-PostExecute")
defer sp.End() defer sp.End()
@ -670,6 +692,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
return err return err
} }
metricType := getMetricType(toReduceResults)
// reduce // reduce
if t.SearchRequest.GetIsAdvanced() { if t.SearchRequest.GetIsAdvanced() {
multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs())) multipleInternalResults := make([][]*internalpb.SearchResults, len(t.SearchRequest.GetSubReqs()))
@ -696,16 +719,12 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs())) multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
for index, internalResults := range multipleInternalResults { for index, internalResults := range multipleInternalResults {
subReq := t.SearchRequest.GetSubReqs()[index] subReq := t.SearchRequest.GetSubReqs()[index]
subMetricType := getMetricType(internalResults)
metricType := "" result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), subMetricType, t.queryInfos[index], true)
if len(internalResults) >= 1 {
metricType = internalResults[0].GetMetricType()
}
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index], true)
if err != nil { if err != nil {
return err return err
} }
t.reScorers[index].setMetricType(metricType) t.reScorers[index].setMetricType(subMetricType)
t.reScorers[index].reScore(result) t.reScorers[index].reScore(result)
multipleMilvusResults[index] = result multipleMilvusResults[index] = result
} }
@ -721,7 +740,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
return err return err
} }
} else { } else {
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0], false) t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), metricType, t.queryInfos[0], false)
if err != nil { if err != nil {
return err return err
} }
@ -751,6 +770,14 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
} }
t.result.Results.OutputFields = t.userOutputFields t.result.Results.OutputFields = t.userOutputFields
t.result.CollectionName = t.request.GetCollectionName() t.result.CollectionName = t.request.GetCollectionName()
if t.isIterator && len(t.queryInfos) == 1 && t.queryInfos[0] != nil {
if iterInfo := t.queryInfos[0].GetSearchIteratorV2Info(); iterInfo != nil {
t.result.Results.SearchIteratorV2Results = &schemapb.SearchIteratorV2Results{
Token: iterInfo.GetToken(),
LastBound: getLastBound(t.result, iterInfo.LastBound, metricType),
}
}
}
if t.isIterator && t.request.GetGuaranteeTimestamp() == 0 { if t.isIterator && t.request.GetGuaranteeTimestamp() == 0 {
// first page for iteration, need to set up sessionTs for iterator // first page for iteration, need to set up sessionTs for iterator
t.result.SessionTs = getMaxMvccTsFromChannels(t.queryChannelsTs, t.BeginTs()) t.result.SessionTs = getMaxMvccTsFromChannels(t.queryChannelsTs, t.BeginTs())

View File

@ -18,12 +18,14 @@ package proxy
import ( import (
"context" "context"
"fmt" "fmt"
"math"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -103,9 +105,124 @@ func TestSearchTask_PostExecute(t *testing.T) {
assert.Equal(t, qt.resultSizeInsufficient, true) assert.Equal(t, qt.resultSizeInsufficient, true)
assert.Equal(t, qt.isTopkReduce, false) assert.Equal(t, qt.isTopkReduce, false)
}) })
t.Run("test search iterator v2", func(t *testing.T) {
const (
kRows = 10
kToken = "test-token"
)
collName := "test_collection_search_iterator_v2" + funcutil.GenRandomStr()
collSchema := createColl(t, collName, rc)
createIteratorSearchTask := func(t *testing.T, metricType string, rows int) *searchTask {
ids := make([]int64, rows)
for i := range ids {
ids[i] = int64(i)
}
resultIDs := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
}
scores := make([]float32, rows)
// proxy needs to reverse the score for negatively related metrics
for i := range scores {
if metric.PositivelyRelated(metricType) {
scores[i] = float32(len(scores) - i)
} else {
scores[i] = -float32(i + 1)
}
}
resultData := &schemapb.SearchResultData{
Ids: resultIDs,
Scores: scores,
NumQueries: 1,
}
qt := &searchTask{
ctx: ctx,
SearchRequest: &internalpb.SearchRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Search,
SourceID: paramtable.GetNodeID(),
},
Nq: 1,
},
schema: newSchemaInfo(collSchema),
request: &milvuspb.SearchRequest{
CollectionName: collName,
},
queryInfos: []*planpb.QueryInfo{{
SearchIteratorV2Info: &planpb.SearchIteratorV2Info{
Token: kToken,
BatchSize: 1,
},
}},
result: &milvuspb.SearchResults{
Results: resultData,
},
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
tr: timerecord.NewTimeRecorder("search"),
isIterator: true,
}
bytes, err := proto.Marshal(resultData)
assert.NoError(t, err)
qt.resultBuf.Insert(&internalpb.SearchResults{
MetricType: metricType,
SlicedBlob: bytes,
})
return qt
}
t.Run("test search iterator v2", func(t *testing.T) {
metrics := []string{metric.L2, metric.IP, metric.COSINE, metric.BM25}
for _, metricType := range metrics {
qt := createIteratorSearchTask(t, metricType, kRows)
err = qt.PostExecute(ctx)
assert.NoError(t, err)
assert.Equal(t, kToken, qt.result.Results.SearchIteratorV2Results.Token)
if metric.PositivelyRelated(metricType) {
assert.Equal(t, float32(1), qt.result.Results.SearchIteratorV2Results.LastBound)
} else {
assert.Equal(t, float32(kRows), qt.result.Results.SearchIteratorV2Results.LastBound)
}
}
})
t.Run("test search iterator v2 with empty result", func(t *testing.T) {
metrics := []string{metric.L2, metric.IP, metric.COSINE, metric.BM25}
for _, metricType := range metrics {
qt := createIteratorSearchTask(t, metricType, 0)
err = qt.PostExecute(ctx)
assert.NoError(t, err)
assert.Equal(t, kToken, qt.result.Results.SearchIteratorV2Results.Token)
if metric.PositivelyRelated(metricType) {
assert.Equal(t, float32(math.MaxFloat32), qt.result.Results.SearchIteratorV2Results.LastBound)
} else {
assert.Equal(t, float32(-math.MaxFloat32), qt.result.Results.SearchIteratorV2Results.LastBound)
}
}
})
t.Run("test search iterator v2 with empty result and incoming last bound", func(t *testing.T) {
metrics := []string{metric.L2, metric.IP, metric.COSINE, metric.BM25}
kLastBound := float32(10)
for _, metricType := range metrics {
qt := createIteratorSearchTask(t, metricType, 0)
qt.queryInfos[0].SearchIteratorV2Info.LastBound = &kLastBound
err = qt.PostExecute(ctx)
assert.NoError(t, err)
assert.Equal(t, kToken, qt.result.Results.SearchIteratorV2Results.Token)
assert.Equal(t, kLastBound, qt.result.Results.SearchIteratorV2Results.LastBound)
}
})
})
} }
func createColl(t *testing.T, name string, rc types.RootCoordClient) { func createColl(t *testing.T, name string, rc types.RootCoordClient) *schemapb.CollectionSchema {
schema := constructCollectionSchema(testInt64Field, testFloatVecField, testVecDim, name) schema := constructCollectionSchema(testInt64Field, testFloatVecField, testVecDim, name)
marshaledSchema, err := proto.Marshal(schema) marshaledSchema, err := proto.Marshal(schema)
require.NoError(t, err) require.NoError(t, err)
@ -126,6 +243,8 @@ func createColl(t *testing.T, name string, rc types.RootCoordClient) {
require.NoError(t, createColT.PreExecute(ctx)) require.NoError(t, createColT.PreExecute(ctx))
require.NoError(t, createColT.Execute(ctx)) require.NoError(t, createColT.Execute(ctx))
require.NoError(t, createColT.PostExecute(ctx)) require.NoError(t, createColT.PostExecute(ctx))
return schema
} }
func getBaseSearchParams() []*commonpb.KeyValuePair { func getBaseSearchParams() []*commonpb.KeyValuePair {
@ -2599,6 +2718,157 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
assert.True(t, strings.Contains(searchInfo.parseError.Error(), "failed to parse input group size")) assert.True(t, strings.Contains(searchInfo.parseError.Error(), "failed to parse input group size"))
} }
}) })
t.Run("check search iterator v2", func(t *testing.T) {
kBatchSize := uint32(10)
generateValidParamsForSearchIteratorV2 := func() []*commonpb.KeyValuePair {
param := getValidSearchParams()
return append(param,
&commonpb.KeyValuePair{
Key: SearchIterV2Key,
Value: "True",
},
&commonpb.KeyValuePair{
Key: IteratorField,
Value: "True",
},
&commonpb.KeyValuePair{
Key: SearchIterBatchSizeKey,
Value: fmt.Sprintf("%d", kBatchSize),
},
)
}
t.Run("iteratorV2 normal", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
searchInfo := parseSearchInfo(param, nil, nil)
assert.NoError(t, searchInfo.parseError)
assert.NotNil(t, searchInfo.planInfo)
assert.NotEmpty(t, searchInfo.planInfo.SearchIteratorV2Info.Token)
assert.Equal(t, kBatchSize, searchInfo.planInfo.SearchIteratorV2Info.BatchSize)
assert.Len(t, searchInfo.planInfo.SearchIteratorV2Info.Token, 36)
assert.Equal(t, int64(kBatchSize), searchInfo.planInfo.GetTopk()) // compatibility
})
t.Run("iteratorV2 without isIterator", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
resetSearchParamsValue(param, IteratorField, "False")
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "both")
})
t.Run("iteratorV2 with groupBy", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
param = append(param, &commonpb.KeyValuePair{
Key: GroupByFieldKey,
Value: "string_field",
})
fields := make([]*schemapb.FieldSchema, 0)
fields = append(fields, &schemapb.FieldSchema{
FieldID: int64(101),
Name: "string_field",
})
schema := &schemapb.CollectionSchema{
Fields: fields,
}
searchInfo := parseSearchInfo(param, schema, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "roupBy")
})
t.Run("iteratorV2 with offset", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
param = append(param, &commonpb.KeyValuePair{
Key: OffsetKey,
Value: "10",
})
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "offset")
})
t.Run("iteratorV2 invalid token", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
param = append(param, &commonpb.KeyValuePair{
Key: SearchIterIdKey,
Value: "invalid_token",
})
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "invalid token format")
})
t.Run("iteratorV2 passed token must be same", func(t *testing.T) {
token, err := uuid.NewRandom()
assert.NoError(t, err)
param := generateValidParamsForSearchIteratorV2()
param = append(param, &commonpb.KeyValuePair{
Key: SearchIterIdKey,
Value: token.String(),
})
searchInfo := parseSearchInfo(param, nil, nil)
assert.NoError(t, searchInfo.parseError)
assert.NotEmpty(t, searchInfo.planInfo.SearchIteratorV2Info.Token)
assert.Equal(t, token.String(), searchInfo.planInfo.SearchIteratorV2Info.Token)
})
t.Run("iteratorV2 batch size", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
resetSearchParamsValue(param, SearchIterBatchSizeKey, "1.123")
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid")
})
t.Run("iteratorV2 batch size", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
resetSearchParamsValue(param, SearchIterBatchSizeKey, "")
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "batch size is required")
})
t.Run("iteratorV2 batch size negative", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
resetSearchParamsValue(param, SearchIterBatchSizeKey, "-1")
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid")
})
t.Run("iteratorV2 batch size too large", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
resetSearchParamsValue(param, SearchIterBatchSizeKey, fmt.Sprintf("%d", Params.QuotaConfig.TopKLimit.GetAsInt64()+1))
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid")
})
t.Run("iteratorV2 last bound", func(t *testing.T) {
kLastBound := float32(1.123)
param := generateValidParamsForSearchIteratorV2()
param = append(param, &commonpb.KeyValuePair{
Key: SearchIterLastBoundKey,
Value: fmt.Sprintf("%f", kLastBound),
})
searchInfo := parseSearchInfo(param, nil, nil)
assert.NoError(t, searchInfo.parseError)
assert.NotNil(t, searchInfo.planInfo)
assert.Equal(t, kLastBound, *searchInfo.planInfo.SearchIteratorV2Info.LastBound)
})
t.Run("iteratorV2 invalid last bound", func(t *testing.T) {
param := generateValidParamsForSearchIteratorV2()
param = append(param, &commonpb.KeyValuePair{
Key: SearchIterLastBoundKey,
Value: "xxx",
})
searchInfo := parseSearchInfo(param, nil, nil)
assert.Error(t, searchInfo.parseError)
assert.ErrorContains(t, searchInfo.parseError, "failed to parse input last bound")
})
})
} }
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData { func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {