milvus/core/src/server/delivery/request/SearchCombineRequest.cpp
groot f66b43b81b
#1827 Combine request target vectors exceed max nq (#1847)
Signed-off-by: groot <yihua.mo@zilliz.com>
2020-04-02 14:24:28 +08:00

417 lines
14 KiB
C++

// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "server/delivery/request/SearchCombineRequest.h"
#include "db/Utils.h"
#include "server/DBWrapper.h"
#include "server/context/Context.h"
#include "utils/CommonUtil.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include "utils/ValidationUtil.h"
#include <memory>
#include <set>
namespace milvus {
namespace server {
namespace {
constexpr int64_t MAX_TOPK_GAP = 200;
constexpr uint64_t MAX_NQ = 200;
void
GetUniqueList(const std::vector<std::string>& list, std::set<std::string>& unique_list) {
for (const std::string& item : list) {
unique_list.insert(item);
}
}
bool
IsSameList(const std::set<std::string>& left, const std::set<std::string>& right) {
if (left.size() != right.size()) {
return false;
}
std::set<std::string>::const_iterator iter_left;
std::set<std::string>::const_iterator iter_right;
for (iter_left = left.begin(), iter_right = right.begin(); iter_left != left.end(); iter_left++, iter_right++) {
if ((*iter_left) != (*iter_right)) {
return false;
}
}
return true;
}
void
FreeRequest(SearchRequestPtr& request, const Status& status) {
request->set_status(status);
request->Done();
}
class TracingContextList {
public:
TracingContextList() = default;
~TracingContextList() {
Finish();
}
void
CreateChild(std::vector<SearchRequestPtr>& requests, const std::string& operation_name) {
Finish();
for (auto& request : requests) {
auto parent_context = request->Context();
if (parent_context) {
auto child_context = request->Context()->Child(operation_name);
context_list_.emplace_back(child_context);
}
}
}
void
Finish() {
for (auto& context : context_list_) {
context->GetTraceContext()->GetSpan()->Finish();
}
context_list_.clear();
}
private:
std::vector<milvus::server::ContextPtr> context_list_;
};
} // namespace
SearchCombineRequest::SearchCombineRequest() : BaseRequest(nullptr, BaseRequest::kSearchCombine) {
}
Status
SearchCombineRequest::Combine(const SearchRequestPtr& request) {
if (request == nullptr) {
return Status(SERVER_NULL_POINTER, "");
}
// the request must be tested by CanCombine before invoke this function
// reset some parameters in necessary
if (request_list_.empty()) {
// validate first request input
auto status = ValidationUtil::ValidateCollectionName(request->CollectionName());
if (!status.ok()) {
return status;
}
status = ValidationUtil::ValidateSearchTopk(request->TopK());
if (!status.ok()) {
return status;
}
// assign base parameters
collection_name_ = request->CollectionName();
min_topk_ = request->TopK() - MAX_TOPK_GAP / 2;
if (min_topk_ < 0) {
min_topk_ = 0;
}
max_topk_ = request->TopK() + MAX_TOPK_GAP / 2;
if (max_topk_ > QUERY_MAX_TOPK) {
max_topk_ = QUERY_MAX_TOPK;
}
extra_params_ = request->ExtraParams();
GetUniqueList(request->PartitionList(), partition_list_);
GetUniqueList(request->FileIDList(), file_id_list_);
}
request_list_.push_back(request);
return Status::OK();
}
bool
SearchCombineRequest::CanCombine(const SearchRequestPtr& request) {
if (collection_name_ != request->CollectionName()) {
return false;
}
if (extra_params_ != request->ExtraParams()) {
return false;
}
// topk must within certain range
if (request->TopK() < min_topk_ || request->TopK() > max_topk_) {
return false;
}
// sum of nq must less-equal than MAX_NQ
if (vectors_data_.vector_count_ > MAX_NQ || request->VectorsData().vector_count_ > MAX_NQ) {
return false;
}
uint64_t total_nq = vectors_data_.vector_count_ + request->VectorsData().vector_count_;
if (total_nq > MAX_NQ) {
return false;
}
// partition list must be equal for each one
std::set<std::string> partition_list;
GetUniqueList(request->PartitionList(), partition_list);
if (!IsSameList(partition_list_, partition_list)) {
return false;
}
// file id list must be equal for each one
std::set<std::string> file_id_list;
GetUniqueList(request->FileIDList(), file_id_list);
if (!IsSameList(file_id_list_, file_id_list)) {
return false;
}
return true;
}
bool
SearchCombineRequest::CanCombine(const SearchRequestPtr& left, const SearchRequestPtr& right) {
if (left->CollectionName() != right->CollectionName()) {
return false;
}
if (left->ExtraParams() != right->ExtraParams()) {
return false;
}
// topk must within certain range
if (abs(left->TopK() - right->TopK() > MAX_TOPK_GAP)) {
return false;
}
// sum of nq must less-equal than MAX_NQ
if (left->VectorsData().vector_count_ > MAX_NQ || right->VectorsData().vector_count_ > MAX_NQ) {
return false;
}
uint64_t total_nq = left->VectorsData().vector_count_ + right->VectorsData().vector_count_;
if (total_nq > MAX_NQ) {
return false;
}
// partition list must be equal for each one
std::set<std::string> left_partition_list, right_partition_list;
GetUniqueList(left->PartitionList(), left_partition_list);
GetUniqueList(right->PartitionList(), right_partition_list);
if (!IsSameList(left_partition_list, right_partition_list)) {
return false;
}
// file id list must be equal for each one
std::set<std::string> left_file_id_list, right_file_id_list;
GetUniqueList(left->FileIDList(), left_file_id_list);
GetUniqueList(right->FileIDList(), right_file_id_list);
if (!IsSameList(left_file_id_list, right_file_id_list)) {
return false;
}
return true;
}
Status
SearchCombineRequest::FreeRequests(const Status& status) {
for (auto request : request_list_) {
FreeRequest(request, status);
}
return Status::OK();
}
Status
SearchCombineRequest::OnExecute() {
try {
size_t combined_request = request_list_.size();
SERVER_LOG_DEBUG << "SearchCombineRequest execute, request count=" << combined_request
<< ", extra_params=" << extra_params_.dump();
std::string hdr = "SearchCombineRequest(collection=" + collection_name_ + ")";
TimeRecorderAuto rc(hdr);
// step 1: check table existence
// only process root table, ignore partition table
engine::meta::CollectionSchema table_schema;
table_schema.collection_id_ = collection_name_;
auto status = DBWrapper::DB()->DescribeTable(table_schema);
if (!status.ok()) {
if (status.code() == DB_NOT_FOUND) {
status = Status(SERVER_TABLE_NOT_EXIST, TableNotExistMsg(collection_name_));
FreeRequests(status);
return status;
} else {
FreeRequests(status);
return status;
}
} else {
if (!table_schema.owner_table_.empty()) {
status = Status(SERVER_INVALID_TABLE_NAME, TableNotExistMsg(collection_name_));
FreeRequests(status);
return status;
}
}
// step 2: check input
size_t run_request = 0;
std::vector<SearchRequestPtr>::iterator iter = request_list_.begin();
for (; iter != request_list_.end();) {
SearchRequestPtr& request = *iter;
status = ValidationUtil::ValidateSearchTopk(request->TopK());
if (!status.ok()) {
// check failed, erase request and let it return error status
FreeRequest(request, status);
iter = request_list_.erase(iter);
continue;
}
status = ValidationUtil::ValidateSearchParams(extra_params_, table_schema, request->TopK());
if (!status.ok()) {
// check failed, erase request and let it return error status
FreeRequest(request, status);
iter = request_list_.erase(iter);
continue;
}
status = ValidationUtil::ValidateVectorData(request->VectorsData(), table_schema);
if (!status.ok()) {
// check failed, erase request and let it return error status
FreeRequest(request, status);
iter = request_list_.erase(iter);
continue;
}
status = ValidationUtil::ValidatePartitionTags(request->PartitionList());
if (!status.ok()) {
// check failed, erase request and let it return error status
FreeRequest(request, status);
iter = request_list_.erase(iter);
continue;
}
// reset topk
search_topk_ = request->TopK() > search_topk_ ? request->TopK() : search_topk_;
// next one
run_request++;
iter++;
}
// all requests are skipped
if (request_list_.empty()) {
SERVER_LOG_DEBUG << "all combined requests were skipped";
return Status::OK();
}
SERVER_LOG_DEBUG << (combined_request - run_request) << " requests were skipped";
SERVER_LOG_DEBUG << "reset topk to " << search_topk_;
rc.RecordSection("check validation");
// step 3: construct vectors_data
SearchRequestPtr& first_request = *request_list_.begin();
uint64_t total_count = 0;
for (auto& request : request_list_) {
total_count += request->VectorsData().vector_count_;
}
vectors_data_.vector_count_ = total_count;
uint16_t dimension = table_schema.dimension_;
bool is_float = true;
if (!first_request->VectorsData().float_data_.empty()) {
vectors_data_.float_data_.resize(total_count * dimension);
} else {
vectors_data_.binary_data_.resize(total_count * dimension / 8);
is_float = false;
}
int64_t offset = 0;
for (auto& request : request_list_) {
const engine::VectorsData& src = request->VectorsData();
if (is_float) {
size_t element_cnt = src.vector_count_ * dimension;
memcpy(vectors_data_.float_data_.data() + offset, src.float_data_.data(), element_cnt * sizeof(float));
offset += element_cnt;
} else {
size_t element_cnt = src.vector_count_ * dimension / 8;
memcpy(vectors_data_.binary_data_.data() + offset, src.binary_data_.data(), element_cnt);
offset += element_cnt;
}
}
SERVER_LOG_DEBUG << total_count << " query vectors combined";
rc.RecordSection("combined query vectors");
// step 4: search vectors
const std::vector<std::string>& partition_list = first_request->PartitionList();
const std::vector<std::string>& file_id_list = first_request->FileIDList();
engine::ResultIds result_ids;
engine::ResultDistances result_distances;
{
TracingContextList context_list;
context_list.CreateChild(request_list_, "Combine Query");
if (file_id_list_.empty()) {
status = DBWrapper::DB()->Query(nullptr, collection_name_, partition_list, (size_t)search_topk_,
extra_params_, vectors_data_, result_ids, result_distances);
} else {
status = DBWrapper::DB()->QueryByFileID(nullptr, file_id_list, (size_t)search_topk_, extra_params_,
vectors_data_, result_ids, result_distances);
}
}
rc.RecordSection("search vectors from engine");
if (!status.ok()) {
// let all request return
FreeRequests(status);
return status;
}
if (result_ids.empty()) {
status = Status(DB_ERROR, "no result returned for combined request");
// let all request return
FreeRequests(status);
return status;
}
// step 5: construct result array
offset = 0;
for (auto& request : request_list_) {
uint64_t count = request->VectorsData().vector_count_;
int64_t topk = request->TopK();
uint64_t element_cnt = count * topk;
TopKQueryResult& result = request->QueryResult();
result.row_num_ = count;
result.id_list_.resize(element_cnt);
result.distance_list_.resize(element_cnt);
memcpy(result.id_list_.data(), result_ids.data() + offset, element_cnt * sizeof(int64_t));
memcpy(result.distance_list_.data(), result_distances.data() + offset, element_cnt * sizeof(float));
offset += (count * search_topk_);
// let request return
FreeRequest(request, Status::OK());
}
rc.RecordSection("construct result and send");
} catch (std::exception& ex) {
Status status = Status(SERVER_UNEXPECTED_ERROR, ex.what());
FreeRequests(status);
return status;
}
return Status::OK();
}
} // namespace server
} // namespace milvus