cai.zhang a362bb1457
Support array datatype (#26369)
Signed-off-by: cai.zhang <cai.zhang@zilliz.com>
2023-09-19 14:23:23 +08:00

418 lines
16 KiB
C++

// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "Reduce.h"
#include <log/Log.h>
#include <algorithm>
#include <cstdint>
#include <vector>
#include "SegmentInterface.h"
#include "Utils.h"
#include "common/EasyAssert.h"
#include "pkVisitor.h"
namespace milvus::segcore {
void
ReduceHelper::Initialize() {
AssertInfo(search_results_.size() > 0, "empty search result");
AssertInfo(slice_nqs_.size() > 0, "empty slice_nqs");
AssertInfo(slice_nqs_.size() == slice_topKs_.size(),
"unaligned slice_nqs and slice_topKs");
total_nq_ = search_results_[0]->total_nq_;
num_segments_ = search_results_.size();
num_slices_ = slice_nqs_.size();
// prefix sum, get slices offsets
AssertInfo(num_slices_ > 0, "empty slice_nqs is not allowed");
slice_nqs_prefix_sum_.resize(num_slices_ + 1);
std::partial_sum(slice_nqs_.begin(),
slice_nqs_.end(),
slice_nqs_prefix_sum_.begin() + 1);
AssertInfo(slice_nqs_prefix_sum_[num_slices_] == total_nq_,
"illegal req sizes, slice_nqs_prefix_sum_[last] = " +
std::to_string(slice_nqs_prefix_sum_[num_slices_]) +
", total_nq = " + std::to_string(total_nq_));
// init final_search_records and final_read_topKs
final_search_records_.resize(num_segments_);
for (auto& search_record : final_search_records_) {
search_record.resize(total_nq_);
}
}
void
ReduceHelper::Reduce() {
FillPrimaryKey();
ReduceResultData();
RefreshSearchResult();
FillEntryData();
}
void
ReduceHelper::Marshal() {
// get search result data blobs of slices
search_result_data_blobs_ =
std::make_unique<milvus::segcore::SearchResultDataBlobs>();
search_result_data_blobs_->blobs.resize(num_slices_);
for (int i = 0; i < num_slices_; i++) {
auto proto = GetSearchResultDataSlice(i);
search_result_data_blobs_->blobs[i] = proto;
}
}
void
ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
auto nq = search_result->total_nq_;
auto topK = search_result->unity_topK_;
AssertInfo(search_result->seg_offsets_.size() == nq * topK,
"wrong seg offsets size, size = " +
std::to_string(search_result->seg_offsets_.size()) +
", expected size = " + std::to_string(nq * topK));
AssertInfo(search_result->distances_.size() == nq * topK,
"wrong distances size, size = " +
std::to_string(search_result->distances_.size()) +
", expected size = " + std::to_string(nq * topK));
std::vector<int64_t> real_topks(nq, 0);
uint32_t valid_index = 0;
auto segment = static_cast<SegmentInterface*>(search_result->segment_);
auto& offsets = search_result->seg_offsets_;
auto& distances = search_result->distances_;
for (auto i = 0; i < nq; ++i) {
for (auto j = 0; j < topK; ++j) {
auto index = i * topK + j;
if (offsets[index] != INVALID_SEG_OFFSET) {
AssertInfo(0 <= offsets[index] &&
offsets[index] < segment->get_row_count(),
fmt::format("invalid offset {}, segment {} with "
"rows num {}, data or index corruption",
offsets[index],
segment->get_segment_id(),
segment->get_row_count()));
real_topks[i]++;
offsets[valid_index] = offsets[index];
distances[valid_index] = distances[index];
valid_index++;
}
}
}
offsets.resize(valid_index);
distances.resize(valid_index);
search_result->topk_per_nq_prefix_sum_.resize(nq + 1);
std::partial_sum(real_topks.begin(),
real_topks.end(),
search_result->topk_per_nq_prefix_sum_.begin() + 1);
}
void
ReduceHelper::FillPrimaryKey() {
std::vector<SearchResult*> valid_search_results;
// get primary keys for duplicates removal
uint32_t valid_index = 0;
for (auto& search_result : search_results_) {
FilterInvalidSearchResult(search_result);
if (search_result->get_total_result_count() > 0) {
auto segment =
static_cast<SegmentInterface*>(search_result->segment_);
segment->FillPrimaryKeys(plan_, *search_result);
search_results_[valid_index++] = search_result;
}
}
search_results_.resize(valid_index);
num_segments_ = search_results_.size();
}
void
ReduceHelper::RefreshSearchResult() {
for (int i = 0; i < num_segments_; i++) {
std::vector<int64_t> real_topks(total_nq_, 0);
auto search_result = search_results_[i];
if (search_result->result_offsets_.size() != 0) {
uint32_t size = 0;
for (int j = 0; j < total_nq_; j++) {
size += final_search_records_[i][j].size();
}
std::vector<milvus::PkType> primary_keys(size);
std::vector<float> distances(size);
std::vector<int64_t> seg_offsets(size);
uint32_t index = 0;
for (int j = 0; j < total_nq_; j++) {
for (auto offset : final_search_records_[i][j]) {
primary_keys[index] = search_result->primary_keys_[offset];
distances[index] = search_result->distances_[offset];
seg_offsets[index] = search_result->seg_offsets_[offset];
index++;
real_topks[j]++;
}
}
search_result->primary_keys_.swap(primary_keys);
search_result->distances_.swap(distances);
search_result->seg_offsets_.swap(seg_offsets);
}
std::partial_sum(real_topks.begin(),
real_topks.end(),
search_result->topk_per_nq_prefix_sum_.begin() + 1);
}
}
void
ReduceHelper::FillEntryData() {
for (auto search_result : search_results_) {
auto segment = static_cast<milvus::segcore::SegmentInterface*>(
search_result->segment_);
segment->FillTargetEntry(plan_, *search_result);
}
}
int64_t
ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi,
int64_t topk,
int64_t& offset) {
while (!heap_.empty()) {
heap_.pop();
}
pk_set_.clear();
pairs_.clear();
pairs_.reserve(num_segments_);
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_[i];
auto offset_beg = search_result->topk_per_nq_prefix_sum_[qi];
auto offset_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
if (offset_beg == offset_end) {
continue;
}
auto primary_key = search_result->primary_keys_[offset_beg];
auto distance = search_result->distances_[offset_beg];
pairs_.emplace_back(
primary_key, distance, search_result, i, offset_beg, offset_end);
heap_.push(&pairs_.back());
}
// nq has no results for all segments
if (heap_.size() == 0) {
return 0;
}
int64_t dup_cnt = 0;
auto start = offset;
while (offset - start < topk && !heap_.empty()) {
auto pilot = heap_.top();
heap_.pop();
auto index = pilot->segment_index_;
auto pk = pilot->primary_key_;
// no valid search result for this nq, break to next
if (pk == INVALID_PK) {
break;
}
// remove duplicates
if (pk_set_.count(pk) == 0) {
pilot->search_result_->result_offsets_.push_back(offset++);
final_search_records_[index][qi].push_back(pilot->offset_);
pk_set_.insert(pk);
} else {
// skip entity with same primary key
dup_cnt++;
}
pilot->advance();
if (pilot->primary_key_ != INVALID_PK) {
heap_.push(pilot);
}
}
return dup_cnt;
}
void
ReduceHelper::ReduceResultData() {
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_[i];
auto result_count = search_result->get_total_result_count();
AssertInfo(search_result != nullptr,
"search result must not equal to nullptr");
AssertInfo(search_result->distances_.size() == result_count,
"incorrect search result distance size");
AssertInfo(search_result->seg_offsets_.size() == result_count,
"incorrect search result seg offset size");
AssertInfo(search_result->primary_keys_.size() == result_count,
"incorrect search result primary key size");
}
int64_t skip_dup_cnt = 0;
for (int64_t slice_index = 0; slice_index < num_slices_; slice_index++) {
auto nq_begin = slice_nqs_prefix_sum_[slice_index];
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
// reduce search results
int64_t offset = 0;
for (int64_t qi = nq_begin; qi < nq_end; qi++) {
skip_dup_cnt += ReduceSearchResultForOneNQ(
qi, slice_topKs_[slice_index], offset);
}
}
if (skip_dup_cnt > 0) {
LOG_SEGCORE_DEBUG_ << "skip duplicated search result, count = "
<< skip_dup_cnt;
}
}
std::vector<char>
ReduceHelper::GetSearchResultDataSlice(int slice_index) {
auto nq_begin = slice_nqs_prefix_sum_[slice_index];
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
int64_t result_count = 0;
for (auto search_result : search_results_) {
AssertInfo(search_result->topk_per_nq_prefix_sum_.size() ==
search_result->total_nq_ + 1,
"incorrect topk_per_nq_prefix_sum_ size in search result");
result_count += search_result->topk_per_nq_prefix_sum_[nq_end] -
search_result->topk_per_nq_prefix_sum_[nq_begin];
}
auto search_result_data =
std::make_unique<milvus::proto::schema::SearchResultData>();
// set unify_topK and total_nq
search_result_data->set_top_k(slice_topKs_[slice_index]);
search_result_data->set_num_queries(nq_end - nq_begin);
search_result_data->mutable_topks()->Resize(nq_end - nq_begin, 0);
// `result_pairs` contains the SearchResult and result_offset info, used for filling output fields
std::vector<std::pair<SearchResult*, int64_t>> result_pairs(result_count);
// reserve space for pks
auto primary_field_id =
plan_->schema_.get_primary_field_id().value_or(milvus::FieldId(-1));
AssertInfo(primary_field_id.get() != INVALID_FIELD_ID, "Primary key is -1");
auto pk_type = plan_->schema_[primary_field_id].get_data_type();
switch (pk_type) {
case milvus::DataType::INT64: {
auto ids = std::make_unique<milvus::proto::schema::LongArray>();
ids->mutable_data()->Resize(result_count, 0);
search_result_data->mutable_ids()->set_allocated_int_id(
ids.release());
break;
}
case milvus::DataType::VARCHAR: {
auto ids = std::make_unique<milvus::proto::schema::StringArray>();
std::vector<std::string> string_pks(result_count);
// TODO: prevent mem copy
*ids->mutable_data() = {string_pks.begin(), string_pks.end()};
search_result_data->mutable_ids()->set_allocated_str_id(
ids.release());
break;
}
default: {
PanicCodeInfo(DataTypeInvalid,
fmt::format("unsupported primary key type {}",
fmt::underlying(pk_type)));
}
}
// reserve space for distances
search_result_data->mutable_scores()->Resize(result_count, 0);
// fill pks and distances
for (auto qi = nq_begin; qi < nq_end; qi++) {
int64_t topk_count = 0;
for (auto search_result : search_results_) {
AssertInfo(search_result != nullptr,
"null search result when reorganize");
if (search_result->result_offsets_.size() == 0) {
continue;
}
auto topk_start = search_result->topk_per_nq_prefix_sum_[qi];
auto topk_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
topk_count += topk_end - topk_start;
for (auto ki = topk_start; ki < topk_end; ki++) {
auto loc = search_result->result_offsets_[ki];
AssertInfo(loc < result_count && loc >= 0,
"invalid loc when GetSearchResultDataSlice, loc = " +
std::to_string(loc) + ", result_count = " +
std::to_string(result_count));
// set result pks
switch (pk_type) {
case milvus::DataType::INT64: {
search_result_data->mutable_ids()
->mutable_int_id()
->mutable_data()
->Set(loc,
std::visit(Int64PKVisitor{},
search_result->primary_keys_[ki]));
break;
}
case milvus::DataType::VARCHAR: {
*search_result_data->mutable_ids()
->mutable_str_id()
->mutable_data()
->Mutable(loc) = std::visit(
StrPKVisitor{}, search_result->primary_keys_[ki]);
break;
}
default: {
PanicCodeInfo(
DataTypeInvalid,
fmt::format("unsupported primary key type {}",
fmt::underlying(pk_type)));
}
}
// set result distances
search_result_data->mutable_scores()->Set(
loc, search_result->distances_[ki]);
// set result offset to fill output fields data
result_pairs[loc] = std::make_pair(search_result, ki);
}
}
// update result topKs
search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count);
}
AssertInfo(search_result_data->scores_size() == result_count,
"wrong scores size, size = " +
std::to_string(search_result_data->scores_size()) +
", expected size = " + std::to_string(result_count));
// set output fields
for (auto field_id : plan_->target_entries_) {
auto& field_meta = plan_->schema_[field_id];
auto field_data =
milvus::segcore::MergeDataArray(result_pairs, field_meta);
if (field_meta.get_data_type() == DataType::ARRAY) {
field_data->mutable_scalars()
->mutable_array_data()
->set_element_type(
proto::schema::DataType(field_meta.get_element_type()));
}
search_result_data->mutable_fields_data()->AddAllocated(
field_data.release());
}
// SearchResultData to blob
auto size = search_result_data->ByteSizeLong();
auto buffer = std::vector<char>(size);
search_result_data->SerializePartialToArray(buffer.data(), size);
return buffer;
}
} // namespace milvus::segcore