milvus/cpp/unittest/db/test_search.cpp
starlord defd24b2c6 format code
Former-commit-id: 648cc3be7f2dc4cca5c1394ff83c2d383c7b70b7
2019-10-10 16:03:29 +08:00

209 lines
8.1 KiB
C++

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <gtest/gtest.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

#include "scheduler/task/SearchTask.h"
#include "utils/TimeRecorder.h"
namespace {
namespace ms = milvus::scheduler;
// Fills the two output vectors with nq * topk synthetic search hits, laid out row-major
// (all topk entries of query 0, then query 1, ...).
//
// Ids are pseudo-random integers in [0, 100000). Within each query the distance at rank j is
// j + r when `ascending`, otherwise (topk - j) + r, where r is a fresh drand48() sample.
// Any previous contents of the output vectors are discarded.
void
BuildResult(uint64_t nq,
            uint64_t topk,
            bool ascending,
            std::vector<int64_t>& output_ids,
            std::vector<float>& output_distence) {
    const uint64_t total = nq * topk;
    output_ids.assign(total, 0);
    output_distence.assign(total, 0.0f);

    uint64_t slot = 0;
    for (uint64_t query = 0; query < nq; ++query) {
        for (uint64_t rank = 0; rank < topk; ++rank, ++slot) {
            // The id is drawn before the distance; keep this order so the random stream is
            // consumed the same way on every run.
            const double scaled_id = drand48() * 100000;
            output_ids[slot] = static_cast<int64_t>(scaled_id);

            const uint64_t base = ascending ? rank : (topk - rank);
            output_distence[slot] = base + drand48();
        }
    }
}
// Checks that `result` contains, for each of the nq queries, the best distances obtainable by
// merging the two input batches.
//
// Each batch is flattened row-major and holds size() / nq entries per query. For every query the
// two distance slices are concatenated, sorted (ascending or descending as requested) and the
// first min(topk, k1 + k2) values are compared against result[i][j].second. Only distances are
// compared; the id vectors are used solely for the size consistency checks.
//
// The assertions are fatal gtest assertions, which return from this helper only. Callers that
// must stop on failure should wrap the call in ASSERT_NO_FATAL_FAILURE.
void
CheckTopkResult(const std::vector<int64_t>& input_ids_1,
                const std::vector<float>& input_distance_1,
                const std::vector<int64_t>& input_ids_2,
                const std::vector<float>& input_distance_2,
                uint64_t nq,
                uint64_t topk,
                bool ascending,
                const milvus::scheduler::ResultSet& result) {
    ASSERT_EQ(result.size(), nq);
    ASSERT_EQ(input_ids_1.size(), input_distance_1.size());
    ASSERT_EQ(input_ids_2.size(), input_distance_2.size());

    // Nothing to compare, and the per-query divisions below would divide by zero.
    if (nq == 0) {
        return;
    }

    const uint64_t input_k1 = input_ids_1.size() / nq;
    const uint64_t input_k2 = input_ids_2.size() / nq;
    const uint64_t expected_count = std::min(topk, input_k1 + input_k2);

    for (uint64_t i = 0; i < nq; ++i) {
        // Build the reference ranking for query i from both batches.
        std::vector<float> src_vec(input_distance_1.begin() + i * input_k1,
                                   input_distance_1.begin() + (i + 1) * input_k1);
        src_vec.insert(src_vec.end(),
                       input_distance_2.begin() + i * input_k2,
                       input_distance_2.begin() + (i + 1) * input_k2);
        if (ascending) {
            std::sort(src_vec.begin(), src_vec.end());
        } else {
            std::sort(src_vec.begin(), src_vec.end(), std::greater<float>());
        }

        // Fail cleanly instead of indexing past the end of a short result row.
        ASSERT_TRUE(result[i].size() >= expected_count)
            << "query " << i << " has " << result[i].size() << " results, expected at least " << expected_count;

        for (uint64_t j = 0; j < expected_count; ++j) {
            // Exact comparison is intended: the values are copied, never recomputed.
            ASSERT_EQ(src_vec[j], result[i][j].second) << "query " << i << ", rank " << j;
        }
    }
}
} // namespace
// Exercises XSearchTask::TopkResult by merging one or two synthetic batches into a running
// result set and validating the outcome with CheckTopkResult.
//
// The same four scenarios run for ascending and then descending distance order, each pass
// starting from empty inputs and an empty result set.
TEST(DBSearchTest, TOPK_TEST) {
    const uint64_t NQ = 15;
    const uint64_t TOP_K = 64;

    for (const bool ascending : {true, false}) {
        SCOPED_TRACE(ascending ? "ascending" : "descending");

        std::vector<int64_t> ids1, ids2;
        std::vector<float> dist1, dist2;
        milvus::scheduler::ResultSet result;
        milvus::Status status;

        // CheckTopkResult reports problems with fatal assertions, which only leave the helper.
        // ASSERT_NO_FATAL_FAILURE makes such a failure abort this test as well.

        /* test1, id1/dist1 valid, id2/dist2 empty */
        BuildResult(NQ, TOP_K, ascending, ids1, dist1);
        status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
        ASSERT_TRUE(status.ok());
        ASSERT_NO_FATAL_FAILURE(CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result));

        /* test2, id1/dist1 valid, id2/dist2 valid (merged into the result of test1) */
        BuildResult(NQ, TOP_K, ascending, ids2, dist2);
        status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
        ASSERT_TRUE(status.ok());
        ASSERT_NO_FATAL_FAILURE(CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result));

        /* test3, id1/dist1 small topk */
        // BuildResult discards the previous contents of ids1/dist1, so only result needs a reset.
        result.clear();
        BuildResult(NQ, TOP_K / 2, ascending, ids1, dist1);
        status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
        ASSERT_TRUE(status.ok());
        status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
        ASSERT_TRUE(status.ok());
        ASSERT_NO_FATAL_FAILURE(CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result));

        /* test4, id1/dist1 small topk, id2/dist2 small topk */
        result.clear();
        BuildResult(NQ, TOP_K / 3, ascending, ids2, dist2);
        status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
        ASSERT_TRUE(status.ok());
        status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result);
        ASSERT_TRUE(status.ok());
        ASSERT_NO_FATAL_FAILURE(CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result));
    }
}
// Times repeated XSearchTask::TopkResult merges: one synthetic batch per simulated index file
// is folded into a single running result set, and the summed merge time is printed.
TEST(DBSearchTest, REDUCE_PERF_TEST) {
    // Unsigned, matching BuildResult's parameters and the size() comparison below.
    const uint64_t nq = 100;
    const uint64_t top_k = 1000;
    const int32_t index_file_num = 478; /* sift1B dataset, index files num */
    const bool ascending = true;

    std::vector<int64_t> input_ids;
    std::vector<float> input_distance;
    milvus::scheduler::ResultSet final_result;
    double reduce_cost = 0.0;

    milvus::TimeRecorder rc("");
    for (int32_t i = 0; i < index_file_num; i++) {
        BuildResult(nq, top_k, ascending, input_ids, input_distance);
        // Presumably RecordSection measures from the previous section, so calling it here keeps
        // data generation out of the reduce timing -- confirm against TimeRecorder.
        rc.RecordSection("do search for context: " + std::to_string(i));

        // pick up topk result
        milvus::Status status = milvus::scheduler::XSearchTask::TopkResult(input_ids,
                                                                           input_distance,
                                                                           top_k,
                                                                           nq,
                                                                           top_k,
                                                                           ascending,
                                                                           final_result);
        ASSERT_TRUE(status.ok());
        ASSERT_EQ(final_result.size(), nq);

        const double span = rc.RecordSection("reduce topk for context: " + std::to_string(i));
        reduce_cost += span;
    }

    // NOTE(review): the division by 1000 implies RecordSection reports microseconds -- confirm.
    std::cout << "total reduce time: " << reduce_cost / 1000 << " ms" << std::endl;
}