mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-30 15:35:33 +08:00
212 lines
8.8 KiB
C++
212 lines
8.8 KiB
C++
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
|
|
|
|
#include <algorithm>
|
|
#include <cmath>
|
|
#include <string>
|
|
|
|
#include "bitset/detail/element_wise.h"
|
|
#include "cachinglayer/Utils.h"
|
|
#include "common/BitsetView.h"
|
|
#include "common/QueryInfo.h"
|
|
#include "common/Types.h"
|
|
#include "query/CachedSearchIterator.h"
|
|
#include "query/SearchBruteForce.h"
|
|
#include "query/SearchOnSealed.h"
|
|
#include "query/helper.h"
|
|
#include "exec/operator/Utils.h"
|
|
|
|
namespace milvus::query {
|
|
|
|
void
|
|
SearchOnSealedIndex(const Schema& schema,
|
|
const segcore::SealedIndexingRecord& record,
|
|
const SearchInfo& search_info,
|
|
const void* query_data,
|
|
const size_t* query_offsets,
|
|
int64_t num_queries,
|
|
const BitsetView& bitset,
|
|
milvus::OpContext* op_context,
|
|
SearchResult& search_result) {
|
|
auto topK = search_info.topk_;
|
|
auto round_decimal = search_info.round_decimal_;
|
|
|
|
auto field_id = search_info.field_id_;
|
|
auto& field = schema[field_id];
|
|
auto is_sparse = field.get_data_type() == DataType::VECTOR_SPARSE_U32_F32;
|
|
// TODO(SPARSE): see todo in PlanImpl.h::PlaceHolder.
|
|
auto dim = is_sparse ? 0 : field.get_dim();
|
|
|
|
AssertInfo(record.is_ready(field_id), "[SearchOnSealed]Record isn't ready");
|
|
// Keep the field_indexing smart pointer, until all reference by raw dropped.
|
|
auto field_indexing = record.get_field_indexing(field_id);
|
|
AssertInfo(field_indexing->metric_type_ == search_info.metric_type_,
|
|
"Metric type of field index isn't the same with search info,"
|
|
"field index: {}, search info: {}",
|
|
field_indexing->metric_type_,
|
|
search_info.metric_type_);
|
|
|
|
knowhere::DataSetPtr dataset;
|
|
if (query_offsets == nullptr) {
|
|
dataset = knowhere::GenDataSet(num_queries, dim, query_data);
|
|
} else {
|
|
// Rather than non-embedding list search where num_queries equals to the number of vectors,
|
|
// in embedding list search, multiple vectors form an embedding list and the last element of query_offsets
|
|
// stands for the total number of vectors.
|
|
auto num_vectors = query_offsets[num_queries];
|
|
dataset = knowhere::GenDataSet(num_vectors, dim, query_data);
|
|
dataset->Set(knowhere::meta::EMB_LIST_OFFSET, query_offsets);
|
|
}
|
|
|
|
dataset->SetIsSparse(is_sparse);
|
|
auto accessor =
|
|
SemiInlineGet(field_indexing->indexing_->PinCells(nullptr, {0}));
|
|
auto vec_index =
|
|
dynamic_cast<index::VectorIndex*>(accessor->get_cell_of(0));
|
|
|
|
if (search_info.iterator_v2_info_.has_value()) {
|
|
CachedSearchIterator cached_iter(
|
|
*vec_index, dataset, search_info, bitset);
|
|
cached_iter.NextBatch(search_info, search_result);
|
|
return;
|
|
}
|
|
|
|
if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_info,
|
|
num_queries,
|
|
dataset,
|
|
search_result,
|
|
bitset,
|
|
*vec_index)) {
|
|
vec_index->Query(
|
|
dataset, search_info, bitset, op_context, search_result);
|
|
float* distances = search_result.distances_.data();
|
|
auto total_num = num_queries * topK;
|
|
if (round_decimal != -1) {
|
|
const float multiplier = pow(10.0, round_decimal);
|
|
for (int i = 0; i < total_num; i++) {
|
|
distances[i] =
|
|
std::round(distances[i] * multiplier) / multiplier;
|
|
}
|
|
}
|
|
}
|
|
search_result.total_nq_ = num_queries;
|
|
search_result.unity_topK_ = topK;
|
|
}
|
|
|
|
void
|
|
SearchOnSealedColumn(const Schema& schema,
|
|
ChunkedColumnInterface* column,
|
|
const SearchInfo& search_info,
|
|
const std::map<std::string, std::string>& index_info,
|
|
const void* query_data,
|
|
const size_t* query_offsets,
|
|
int64_t num_queries,
|
|
int64_t row_count,
|
|
const BitsetView& bitview,
|
|
milvus::OpContext* op_context,
|
|
SearchResult& result) {
|
|
auto field_id = search_info.field_id_;
|
|
auto& field = schema[field_id];
|
|
|
|
auto data_type = field.get_data_type();
|
|
auto element_type = field.get_element_type();
|
|
// TODO(SPARSE): see todo in PlanImpl.h::PlaceHolder.
|
|
auto dim =
|
|
data_type == DataType::VECTOR_SPARSE_U32_F32 ? 0 : field.get_dim();
|
|
|
|
query::dataset::SearchDataset query_dataset{search_info.metric_type_,
|
|
num_queries,
|
|
search_info.topk_,
|
|
search_info.round_decimal_,
|
|
dim,
|
|
query_data,
|
|
query_offsets};
|
|
|
|
CheckBruteForceSearchParam(field, search_info);
|
|
|
|
if (search_info.iterator_v2_info_.has_value()) {
|
|
AssertInfo(data_type != DataType::VECTOR_ARRAY,
|
|
"vector array(embedding list) is not supported for "
|
|
"vector iterator");
|
|
|
|
CachedSearchIterator cached_iter(
|
|
column, query_dataset, search_info, index_info, bitview, data_type);
|
|
cached_iter.NextBatch(search_info, result);
|
|
return;
|
|
}
|
|
|
|
auto num_chunk = column->num_chunks();
|
|
|
|
SubSearchResult final_qr(num_queries,
|
|
search_info.topk_,
|
|
search_info.metric_type_,
|
|
search_info.round_decimal_);
|
|
|
|
auto offset = 0;
|
|
|
|
auto vector_chunks = column->GetAllChunks(op_context);
|
|
for (int i = 0; i < num_chunk; ++i) {
|
|
auto pw = vector_chunks[i];
|
|
auto vec_data = pw.get()->Data();
|
|
auto chunk_size = column->chunk_row_nums(i);
|
|
auto raw_dataset =
|
|
query::dataset::RawDataset{offset, dim, chunk_size, vec_data};
|
|
|
|
PinWrapper<const size_t*> offsets_pw;
|
|
if (data_type == DataType::VECTOR_ARRAY) {
|
|
AssertInfo(
|
|
query_offsets != nullptr,
|
|
"query_offsets is nullptr, but data_type is vector array");
|
|
|
|
offsets_pw = column->VectorArrayOffsets(op_context, i);
|
|
raw_dataset.raw_data_offsets = offsets_pw.get();
|
|
}
|
|
|
|
if (milvus::exec::UseVectorIterator(search_info)) {
|
|
AssertInfo(data_type != DataType::VECTOR_ARRAY,
|
|
"vector array(embedding list) is not supported for "
|
|
"vector iterator");
|
|
auto sub_qr =
|
|
PackBruteForceSearchIteratorsIntoSubResult(query_dataset,
|
|
raw_dataset,
|
|
search_info,
|
|
index_info,
|
|
bitview,
|
|
data_type);
|
|
final_qr.merge(sub_qr);
|
|
} else {
|
|
auto sub_qr = BruteForceSearch(query_dataset,
|
|
raw_dataset,
|
|
search_info,
|
|
index_info,
|
|
bitview,
|
|
data_type,
|
|
element_type,
|
|
op_context);
|
|
final_qr.merge(sub_qr);
|
|
}
|
|
offset += chunk_size;
|
|
}
|
|
if (milvus::exec::UseVectorIterator(search_info)) {
|
|
result.AssembleChunkVectorIterators(num_queries,
|
|
num_chunk,
|
|
column->GetNumRowsUntilChunk(),
|
|
final_qr.chunk_iterators());
|
|
} else {
|
|
result.distances_ = std::move(final_qr.mutable_distances());
|
|
result.seg_offsets_ = std::move(final_qr.mutable_seg_offsets());
|
|
}
|
|
result.unity_topK_ = query_dataset.topk;
|
|
result.total_nq_ = query_dataset.num_queries;
|
|
}
|
|
|
|
} // namespace milvus::query
|