mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-29 06:55:27 +08:00
89 lines
3.2 KiB
C++
89 lines
3.2 KiB
C++
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
|
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
|
// or implied. See the License for the specific language governing permissions and limitations under the License
|
|
|
|
#include <cmath>
|
|
|
|
#include "exceptions/EasyAssert.h"
|
|
#include "query/SubSearchResult.h"
|
|
|
|
namespace milvus::query {
|
|
|
|
template <bool is_desc>
|
|
void
|
|
SubSearchResult::merge_impl(const SubSearchResult& right) {
|
|
AssertInfo(num_queries_ == right.num_queries_, "[SubSearchResult]Nq check failed");
|
|
AssertInfo(topk_ == right.topk_, "[SubSearchResult]Topk check failed");
|
|
AssertInfo(metric_type_ == right.metric_type_, "[SubSearchResult]Metric type check failed");
|
|
AssertInfo(is_desc == is_descending(metric_type_), "[SubSearchResult]Metric type isn't desc");
|
|
|
|
for (int64_t qn = 0; qn < num_queries_; ++qn) {
|
|
auto offset = qn * topk_;
|
|
|
|
int64_t* __restrict__ left_ids = this->get_ids() + offset;
|
|
float* __restrict__ left_distances = this->get_distances() + offset;
|
|
|
|
auto right_ids = right.get_ids() + offset;
|
|
auto right_distances = right.get_distances() + offset;
|
|
|
|
std::vector<float> buf_distances(topk_);
|
|
std::vector<int64_t> buf_ids(topk_);
|
|
|
|
auto lit = 0; // left iter
|
|
auto rit = 0; // right iter
|
|
|
|
for (auto buf_iter = 0; buf_iter < topk_; ++buf_iter) {
|
|
auto left_v = left_distances[lit];
|
|
auto right_v = right_distances[rit];
|
|
// optimize out at compiling
|
|
if (is_desc ? (left_v >= right_v) : (left_v <= right_v)) {
|
|
buf_distances[buf_iter] = left_distances[lit];
|
|
buf_ids[buf_iter] = left_ids[lit];
|
|
++lit;
|
|
} else {
|
|
buf_distances[buf_iter] = right_distances[rit];
|
|
buf_ids[buf_iter] = right_ids[rit];
|
|
++rit;
|
|
}
|
|
}
|
|
std::copy_n(buf_distances.data(), topk_, left_distances);
|
|
std::copy_n(buf_ids.data(), topk_, left_ids);
|
|
}
|
|
}
|
|
|
|
void
|
|
SubSearchResult::merge(const SubSearchResult& sub_result) {
|
|
AssertInfo(metric_type_ == sub_result.metric_type_, "[SubSearchResult]Metric type check failed when merge");
|
|
if (is_descending(metric_type_)) {
|
|
this->merge_impl<true>(sub_result);
|
|
} else {
|
|
this->merge_impl<false>(sub_result);
|
|
}
|
|
}
|
|
|
|
/**
 * Produce a new result holding the merged top-k of `left` and `right`.
 *
 * Both arguments are taken by const reference: `left` is copied and the copy
 * is merged with `right`, so neither input is modified.
 *
 * @return the merged copy.
 */
SubSearchResult
SubSearchResult::merge(const SubSearchResult& left, const SubSearchResult& right) {
    SubSearchResult merged(left);
    merged.merge(right);
    return merged;
}
|
|
|
|
void
|
|
SubSearchResult::round_values() {
|
|
if (round_decimal_ == -1)
|
|
return;
|
|
const float multiplier = pow(10.0, round_decimal_);
|
|
for (auto it = this->distances_.begin(); it != this->distances_.end(); it++) {
|
|
*it = round(*it * multiplier) / multiplier;
|
|
}
|
|
}
|
|
|
|
} // namespace milvus::query
|