mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-01 08:28:10 +08:00
103 lines
3.9 KiB
C++
103 lines
3.9 KiB
C++
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
|
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
|
// or implied. See the License for the specific language governing permissions and limitations under the License
|
|
|
|
#include <cmath>
|
|
|
|
#include "common/QueryInfo.h"
|
|
#include "query/SearchBruteForce.h"
|
|
#include "query/SearchOnSealed.h"
|
|
#include "query/helper.h"
|
|
|
|
namespace milvus::query {
|
|
|
|
void
|
|
SearchOnSealedIndex(const Schema& schema,
|
|
const segcore::SealedIndexingRecord& record,
|
|
const SearchInfo& search_info,
|
|
const void* query_data,
|
|
int64_t num_queries,
|
|
const BitsetView& bitset,
|
|
SearchResult& result) {
|
|
auto topk = search_info.topk_;
|
|
auto round_decimal = search_info.round_decimal_;
|
|
|
|
auto field_id = search_info.field_id_;
|
|
auto& field = schema[field_id];
|
|
// Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
|
|
auto dim = field.get_dim();
|
|
|
|
AssertInfo(record.is_ready(field_id), "[SearchOnSealed]Record isn't ready");
|
|
auto field_indexing = record.get_field_indexing(field_id);
|
|
AssertInfo(field_indexing->metric_type_ == search_info.metric_type_,
|
|
"Metric type of field index isn't the same with search info");
|
|
|
|
auto final = [&] {
|
|
auto ds = knowhere::GenDataSet(num_queries, dim, query_data);
|
|
|
|
auto conf = search_info.search_params_;
|
|
conf[knowhere::meta::TOPK] = search_info.topk_;
|
|
conf[knowhere::meta::METRIC_TYPE] = field_indexing->metric_type_;
|
|
auto vec_index =
|
|
dynamic_cast<index::VectorIndex*>(field_indexing->indexing_.get());
|
|
auto index_type = vec_index->GetIndexType();
|
|
return vec_index->Query(ds, search_info, bitset);
|
|
}();
|
|
|
|
auto ids = final->seg_offsets_.data();
|
|
float* distances = final->distances_.data();
|
|
|
|
auto total_num = num_queries * topk;
|
|
if (round_decimal != -1) {
|
|
const float multiplier = pow(10.0, round_decimal);
|
|
for (int i = 0; i < total_num; i++) {
|
|
distances[i] = std::round(distances[i] * multiplier) / multiplier;
|
|
}
|
|
}
|
|
result.seg_offsets_.resize(total_num);
|
|
result.distances_.resize(total_num);
|
|
result.total_nq_ = num_queries;
|
|
result.unity_topK_ = topk;
|
|
|
|
std::copy_n(ids, total_num, result.seg_offsets_.data());
|
|
std::copy_n(distances, total_num, result.distances_.data());
|
|
}
|
|
|
|
void
|
|
SearchOnSealed(const Schema& schema,
|
|
const void* vec_data,
|
|
const SearchInfo& search_info,
|
|
const void* query_data,
|
|
int64_t num_queries,
|
|
int64_t row_count,
|
|
const BitsetView& bitset,
|
|
SearchResult& result) {
|
|
auto field_id = search_info.field_id_;
|
|
auto& field = schema[field_id];
|
|
|
|
query::dataset::SearchDataset dataset{search_info.metric_type_,
|
|
num_queries,
|
|
search_info.topk_,
|
|
search_info.round_decimal_,
|
|
field.get_dim(),
|
|
query_data};
|
|
|
|
CheckBruteForceSearchParam(field, search_info);
|
|
auto sub_qr = BruteForceSearch(
|
|
dataset, vec_data, row_count, search_info.search_params_, bitset);
|
|
|
|
result.distances_ = std::move(sub_qr.mutable_distances());
|
|
result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
|
|
result.unity_topK_ = dataset.topk;
|
|
result.total_nq_ = dataset.num_queries;
|
|
}
|
|
|
|
} // namespace milvus::query
|