milvus/internal/core/src/query/SubQueryResult.cpp
neza2017 4015d7245d Merge operation
Signed-off-by: neza2017 <yefu.chen@zilliz.com>
2021-01-06 14:45:50 +08:00

78 lines
2.6 KiB
C++

// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "utils/EasyAssert.h"
#include "query/SubQueryResult.h"
#include "segcore/Reduce.h"
namespace milvus::query {
// Merges `right` into *this in place, query by query. For each of the
// num_queries_ rows, the topk_ entries of this result and of `right` are
// combined with a two-pointer merge, and the first topk_ merged entries
// (values and their labels) overwrite this row.
//
// is_desc selects the ordering:
// - true: larger values come first.
// - false: smaller values come first.
// On ties, the entry from *this (the left side) is taken first.
//
// NOTE(review): the merge only yields the best topk_ entries if each row of
// both inputs is already sorted in the same order; confirm that callers
// guarantee this.
// NOTE(review): if `right` is *this, the right pointers alias the
// __restrict__-qualified left pointers, which violates the restrict promise.
template <bool is_desc>
void
SubQueryResult::merge_impl(const SubQueryResult& right) {
    Assert(num_queries_ == right.num_queries_);
    Assert(topk_ == right.topk_);
    Assert(metric_type_ == right.metric_type_);
    Assert(is_desc == is_descending(metric_type_));

    // Scratch rows are reused for every query instead of being reallocated
    // on each iteration; every slot is overwritten before it is copied back.
    std::vector<float> buf_values(topk_);
    std::vector<int64_t> buf_labels(topk_);

    for (int64_t qn = 0; qn < num_queries_; ++qn) {
        const auto offset = qn * topk_;
        int64_t* __restrict__ left_labels = this->get_labels() + offset;
        float* __restrict__ left_values = this->get_values() + offset;
        auto right_labels = right.get_labels() + offset;
        auto right_values = right.get_values() + offset;

        int64_t lit = 0;  // left iter
        int64_t rit = 0;  // right iter
        for (int64_t buf_iter = 0; buf_iter < topk_; ++buf_iter) {
            // Invariant: lit + rit == buf_iter < topk_, so both reads below
            // stay within this row of each input.
            const auto left_v = left_values[lit];
            const auto right_v = right_values[rit];
            // is_desc is a template parameter, so this condition is resolved
            // at compile time.
            if (is_desc ? (left_v >= right_v) : (left_v <= right_v)) {
                buf_values[buf_iter] = left_v;
                buf_labels[buf_iter] = left_labels[lit];
                ++lit;
            } else {
                buf_values[buf_iter] = right_v;
                buf_labels[buf_iter] = right_labels[rit];
                ++rit;
            }
        }
        // Writes to the left row happen only after the whole row has been
        // read, so no left entry is clobbered before it is compared.
        std::copy_n(buf_values.data(), topk_, left_values);
        std::copy_n(buf_labels.data(), topk_, left_labels);
    }
}
void
SubQueryResult::merge(const SubQueryResult& sub_result) {
Assert(metric_type_ == sub_result.metric_type_);
if (is_descending(metric_type_)) {
this->merge_impl<true>(sub_result);
} else {
this->merge_impl<false>(sub_result);
}
}
// Non-mutating merge: returns a new result holding `left` merged with
// `right`. Neither argument is modified; the work is done on a copy of
// `left` via the in-place member merge.
SubQueryResult
SubQueryResult::merge(const SubQueryResult& left, const SubQueryResult& right) {
    SubQueryResult merged(left);
    merged.merge(right);
    return merged;
}
} // namespace milvus::query