// Copyright (C) 2019-2024 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License

#pragma once

// NOTE(review): throughout this copy of the header, text that originally sat
// between angle brackets appears to have been stripped (the header name of
// the first #include below, and most template argument lists further down).
// The code tokens are left exactly as found; every affected declaration is
// marked "[args missing]" and must be restored from version control before
// this file can compile.
#include
#include "cachinglayer/CacheSlot.h"
#include "common/BitsetView.h"
#include "common/QueryInfo.h"
#include "common/QueryResult.h"
#include "query/helper.h"
#include "segcore/ConcurrentVector.h"
#include "index/VectorIndex.h"
#include "mmap/ChunkedColumn.h"

namespace milvus::query {

// This class is used to cache the search results from Knowhere
// search iterators and filter the results based on the last_bound,
// radius and range_filter.
// It provides a number of constructors to support different scenarios,
// including growing/sealed, chunked/non-chunked.
//
// It does not care about TopK in search_info.
// The topk in SearchResult will be set to the batch_size for compatibility.
//
// Instances are neither copyable nor movable (all four operations are
// deleted below).
//
// TODO: introduce the pool of results in the near future
// TODO: replace VectorIterator class
class CachedSearchIterator {
 public:
    // For sealed segment with vector index
    CachedSearchIterator(const milvus::index::VectorIndex& index,
                         const knowhere::DataSetPtr& dataset,
                         const SearchInfo& search_info,
                         const BitsetView& bitset);

    // For sealed segment, BF
    // [args missing] on the std::map type of index_info.
    CachedSearchIterator(const dataset::SearchDataset& dataset,
                         const dataset::RawDataset& raw_ds,
                         const SearchInfo& search_info,
                         const std::map& index_info,
                         const BitsetView& bitset,
                         const milvus::DataType& data_type);

    // For growing segment with chunked data, BF
    // [args missing] on the std::map type of index_info.
    CachedSearchIterator(const dataset::SearchDataset& dataset,
                         const segcore::VectorBase* vec_data,
                         const int64_t row_count,
                         const SearchInfo& search_info,
                         const std::map& index_info,
                         const BitsetView& bitset,
                         const milvus::DataType& data_type);

    // For sealed segment with chunked data, BF
    // [args missing] on the std::map type of index_info.
    CachedSearchIterator(ChunkedColumnBase* column,
                         const dataset::SearchDataset& dataset,
                         const SearchInfo& search_info,
                         const std::map& index_info,
                         const BitsetView& bitset,
                         const milvus::DataType& data_type);

    // This method fetches the next batch of search results based on the
    // provided search information and updates the search_result object with
    // the new batch of results.
    void
    NextBatch(const SearchInfo& search_info, SearchResult& search_result);

    // Disable copy and move
    CachedSearchIterator(const CachedSearchIterator&) = delete;
    CachedSearchIterator&
    operator=(const CachedSearchIterator&) = delete;
    CachedSearchIterator(CachedSearchIterator&&) = delete;
    CachedSearchIterator&
    operator=(CachedSearchIterator&&) = delete;

 private:
    // (distance, id) pair: IsValid() reads `.first` as the float distance.
    // [args missing]
    using DisIdPair = std::pair;
    // Index of an iterator (see GetNextValidResult's iterator_idx parameter).
    using IterIdx = size_t;
    // Presumably (iterator index, DisIdPair), judging by the alias name and
    // by the comparator below reading `.second.first` / `.second.second`
    // -- TODO confirm. [args missing]
    using IterIdDisIdPair = std::pair;
    // Callable taking an int64_t; what it returns cannot be determined from
    // this copy -- presumably per-chunk data plus a size, given the name and
    // the trailing int64_t; verify. [args missing]
    using GetChunkDataFunc = std::function, int64_t>(int64_t)>;

    int64_t batch_size_ = 0;
    // [args missing] -- element type not visible in this copy.
    std::vector iterators_;
    // Multiplier applied to every distance by ConvertIteratorResult() and
    // ConvertIncomingDistance(). Defaults to 1; where it is changed is not
    // visible in this header (presumably set to -1 for metrics where a larger
    // score is better, so that "smaller is closer" holds internally -- confirm
    // in the implementation file).
    int8_t sign_ = 1;
    size_t num_chunks_ = 1;
    size_t nq_ = 0;

    // Ordering predicate over IterIdDisIdPair: returns true when lhs has the
    // larger distance (`.second.first`); on equal distances, when lhs has the
    // larger `.second.second`.
    // Presumably the comparator of the heaps in chunked_heaps_; if those are
    // std::priority_queue, this "greater-than" ordering puts the smallest
    // distance on top -- TODO confirm, the container type is truncated below.
    // NOTE(review): operator() is not const-qualified; comparators are
    // conventionally const so they can be invoked through const objects.
    struct IterIdDisIdPairComparator {
        bool
        operator()(const IterIdDisIdPair& lhs, const IterIdDisIdPair& rhs) {
            // Tie on distance: fall back to comparing the second component.
            if (lhs.second.first == rhs.second.first) {
                return lhs.second.second > rhs.second.second;
            }
            return lhs.second.first > rhs.second.first;
        }
    };
    // [args missing] -- only the trailing comparator argument survives.
    std::vector, IterIdDisIdPairComparator>> chunked_heaps_;

    // Returns whether `result` passes the distance filters:
    //   - last_bound, if set:   dist >  last_bound  (strict)
    //   - radius, if set:       dist <  radius      (strict)
    //   - range_filter, if set: dist >= range_filter
    // range_filter is only consulted when radius is set; with no radius the
    // function returns after the last_bound check alone.
    // NOTE(review): does not modify members but is not declared const.
    // [args missing] on the std::optional parameter types.
    inline bool
    IsValid(const DisIdPair& result,
            const std::optional& last_bound,
            const std::optional& radius,
            const std::optional& range_filter) {
        const float dist = result.first;
        const bool is_valid =
            !last_bound.has_value() || dist > last_bound.value();
        if (!radius.has_value()) {
            return is_valid;
        }
        if (!range_filter.has_value()) {
            return is_valid && dist < radius.value();
        }
        return is_valid && dist < radius.value() &&
               dist >= range_filter.value();
    }

    // Converts an iterator result into a DisIdPair: the two components are
    // swapped (input `.second` becomes output `.first`, input `.first` becomes
    // output `.second`) and the value placed in `.first` is multiplied by
    // sign_.
    // [args missing] on the std::pair parameter type.
    inline DisIdPair
    ConvertIteratorResult(const std::pair& iter_rst) {
        DisIdPair rst;
        rst.first = iter_rst.second * sign_;
        rst.second = iter_rst.first;
        return rst;
    }

    // Multiplies a caller-supplied distance by sign_ when it holds a value;
    // an empty optional is returned unchanged.
    // [args missing] on both std::optional types.
    inline std::optional
    ConvertIncomingDistance(std::optional dist) {
        if (dist.has_value()) {
            dist = dist.value() * sign_;
        }
        return dist;
    }

    // The remaining members are declared here and defined elsewhere; their
    // behaviour is not visible in this header.

    // [args missing] on the return type and the std::optional parameters.
    std::optional
    GetNextValidResult(size_t iterator_idx,
                       const std::optional& last_bound,
                       const std::optional& radius,
                       const std::optional& range_filter);

    // `rst` is a non-const reference, so presumably an output parameter --
    // TODO confirm. [args missing] on std::optional / std::vector.
    void
    MergeChunksResults(size_t query_idx,
                       const std::optional& last_bound,
                       const std::optional& radius,
                       const std::optional& range_filter,
                       std::vector& rst);

    void
    ValidateSearchInfo(const SearchInfo& search_info);

    // [args missing] on the std::vector return type.
    std::vector
    GetBatchedNextResults(size_t query_idx, const SearchInfo& search_info);

    // [args missing] on the std::vector type of rst.
    void
    WriteSingleQuerySearchResult(SearchResult& search_result,
                                 const size_t idx,
                                 std::vector& rst,
                                 const int64_t round_decimal);

    void
    Init(const SearchInfo& search_info);

    // Presumably the shared set-up path for the chunked brute-force
    // constructors, which differ in how chunk data is obtained
    // (get_chunk_data) -- verify against the implementation file.
    // [args missing] on the std::map type of index_info.
    void
    InitializeChunkedIterators(
        const dataset::SearchDataset& dataset,
        const SearchInfo& search_info,
        const std::map& index_info,
        const BitsetView& bitset,
        const milvus::DataType& data_type,
        const GetChunkDataFunc& get_chunk_data);
};
}  // namespace milvus::query