mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-30 07:25:37 +08:00
Merge branch 'caiyd_reduce_parallel_0.5.0' into 'branch-0.5.0'
MS-606 optimize reduce, update unittest See merge request megasearch/milvus!694 Former-commit-id: 3a99483a963ebda56d26c6137d70be3511122361
This commit is contained in:
commit
0d1d85f4bd
@ -34,8 +34,6 @@ namespace scheduler {
|
||||
static constexpr size_t PARALLEL_REDUCE_THRESHOLD = 10000;
|
||||
static constexpr size_t PARALLEL_REDUCE_BATCH = 1000;
|
||||
|
||||
std::mutex XSearchTask::merge_mutex_;
|
||||
|
||||
// TODO(wxyu): remove unused code
|
||||
// bool
|
||||
// NeedParallelReduce(uint64_t nq, uint64_t topk) {
|
||||
@ -162,8 +160,8 @@ XSearchTask::Load(LoadType type, uint8_t device_id) {
|
||||
|
||||
size_t file_size = index_engine_->PhysicalSize();
|
||||
|
||||
std::string info = "Load file id:" + std::to_string(file_->id_) +
|
||||
" file type:" + std::to_string(file_->file_type_) + " size:" + std::to_string(file_size) +
|
||||
std::string info = "Load file id:" + std::to_string(file_->id_) + " file type:" +
|
||||
std::to_string(file_->file_type_) + " size:" + std::to_string(file_size) +
|
||||
" bytes from location: " + file_->location_ + " totally cost";
|
||||
double span = rc.ElapseFromBegin(info);
|
||||
// for (auto &context : search_contexts_) {
|
||||
@ -221,7 +219,8 @@ XSearchTask::Execute() {
|
||||
|
||||
// step 3: pick up topk result
|
||||
auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk;
|
||||
XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult());
|
||||
XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, metric_l2,
|
||||
search_job->GetResult());
|
||||
|
||||
span = rc.RecordSection(hdr + ", reduce topk");
|
||||
// search_job->AccumReduceCost(span);
|
||||
@ -230,7 +229,7 @@ XSearchTask::Execute() {
|
||||
// search_job->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
|
||||
}
|
||||
|
||||
// step 5: notify to send result to client
|
||||
// step 4: notify to send result to client
|
||||
search_job->SearchDone(index_id_);
|
||||
}
|
||||
|
||||
@ -240,36 +239,37 @@ XSearchTask::Execute() {
|
||||
index_engine_ = nullptr;
|
||||
}
|
||||
|
||||
Status
|
||||
XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
|
||||
uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result) {
|
||||
scheduler::ResultSet result_buf;
|
||||
|
||||
void
|
||||
XSearchTask::MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
|
||||
uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending,
|
||||
scheduler::ResultSet& result) {
|
||||
if (result.empty()) {
|
||||
result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0)));
|
||||
for (auto i = 0; i < nq; ++i) {
|
||||
auto& result_buf_i = result_buf[i];
|
||||
result.resize(nq);
|
||||
}
|
||||
|
||||
for (uint64_t i = 0; i < nq; i++) {
|
||||
scheduler::Id2DistVec result_buf;
|
||||
auto& result_i = result[i];
|
||||
|
||||
if (result[i].empty()) {
|
||||
result_buf.resize(input_k, scheduler::IdDistPair(-1, 0.0));
|
||||
uint64_t input_k_multi_i = input_k * i;
|
||||
for (auto k = 0; k < input_k; ++k) {
|
||||
uint64_t idx = input_k_multi_i + k;
|
||||
auto& result_buf_item = result_buf_i[k];
|
||||
auto& result_buf_item = result_buf[k];
|
||||
result_buf_item.first = input_ids[idx];
|
||||
result_buf_item.second = input_distance[idx];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
size_t tar_size = result[0].size();
|
||||
uint64_t output_k = std::min(topk, input_k + tar_size);
|
||||
result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0)));
|
||||
for (auto i = 0; i < nq; ++i) {
|
||||
} else {
|
||||
size_t tar_size = result_i.size();
|
||||
uint64_t output_k = std::min(topk, input_k + tar_size);
|
||||
result_buf.resize(output_k, scheduler::IdDistPair(-1, 0.0));
|
||||
size_t buf_k = 0, src_k = 0, tar_k = 0;
|
||||
uint64_t src_idx;
|
||||
auto& result_i = result[i];
|
||||
auto& result_buf_i = result_buf[i];
|
||||
uint64_t input_k_multi_i = input_k * i;
|
||||
while (buf_k < output_k && src_k < input_k && tar_k < tar_size) {
|
||||
src_idx = input_k_multi_i + src_k;
|
||||
auto& result_buf_item = result_buf_i[buf_k];
|
||||
auto& result_buf_item = result_buf[buf_k];
|
||||
auto& result_item = result_i[tar_k];
|
||||
if ((ascending && input_distance[src_idx] < result_item.second) ||
|
||||
(!ascending && input_distance[src_idx] > result_item.second)) {
|
||||
@ -283,11 +283,11 @@ XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector
|
||||
buf_k++;
|
||||
}
|
||||
|
||||
if (buf_k < topk) {
|
||||
if (buf_k < output_k) {
|
||||
if (src_k < input_k) {
|
||||
while (buf_k < output_k && src_k < input_k) {
|
||||
src_idx = input_k_multi_i + src_k;
|
||||
auto& result_buf_item = result_buf_i[buf_k];
|
||||
auto& result_buf_item = result_buf[buf_k];
|
||||
result_buf_item.first = input_ids[src_idx];
|
||||
result_buf_item.second = input_distance[src_idx];
|
||||
src_k++;
|
||||
@ -295,18 +295,79 @@ XSearchTask::TopkResult(const std::vector<int64_t>& input_ids, const std::vector
|
||||
}
|
||||
} else {
|
||||
while (buf_k < output_k && tar_k < tar_size) {
|
||||
result_buf_i[buf_k] = result_i[tar_k];
|
||||
result_buf[buf_k] = result_i[tar_k];
|
||||
tar_k++;
|
||||
buf_k++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result_i.swap(result_buf);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
XSearchTask::MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
|
||||
const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance,
|
||||
uint64_t src_input_k, uint64_t nq, uint64_t topk, bool ascending) {
|
||||
if (src_ids.empty() || src_distance.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
result.swap(result_buf);
|
||||
std::vector<int64_t> id_buf(nq * topk, -1);
|
||||
std::vector<float> dist_buf(nq * topk, 0.0);
|
||||
|
||||
return Status::OK();
|
||||
uint64_t output_k = std::min(topk, tar_input_k + src_input_k);
|
||||
uint64_t buf_k, src_k, tar_k;
|
||||
uint64_t src_idx, tar_idx, buf_idx;
|
||||
uint64_t src_input_k_multi_i, tar_input_k_multi_i, buf_k_multi_i;
|
||||
|
||||
for (uint64_t i = 0; i < nq; i++) {
|
||||
src_input_k_multi_i = src_input_k * i;
|
||||
tar_input_k_multi_i = tar_input_k * i;
|
||||
buf_k_multi_i = output_k * i;
|
||||
buf_k = src_k = tar_k = 0;
|
||||
while (buf_k < output_k && src_k < src_input_k && tar_k < tar_input_k) {
|
||||
src_idx = src_input_k_multi_i + src_k;
|
||||
tar_idx = tar_input_k_multi_i + tar_k;
|
||||
buf_idx = buf_k_multi_i + buf_k;
|
||||
if ((ascending && src_distance[src_idx] < tar_distance[tar_idx]) ||
|
||||
(!ascending && src_distance[src_idx] > tar_distance[tar_idx])) {
|
||||
id_buf[buf_idx] = src_ids[src_idx];
|
||||
dist_buf[buf_idx] = src_distance[src_idx];
|
||||
src_k++;
|
||||
} else {
|
||||
id_buf[buf_idx] = tar_ids[tar_idx];
|
||||
dist_buf[buf_idx] = tar_distance[tar_idx];
|
||||
tar_k++;
|
||||
}
|
||||
buf_k++;
|
||||
}
|
||||
|
||||
if (buf_k < output_k) {
|
||||
if (src_k < src_input_k) {
|
||||
while (buf_k < output_k && src_k < src_input_k) {
|
||||
src_idx = src_input_k_multi_i + src_k;
|
||||
id_buf[buf_idx] = src_ids[src_idx];
|
||||
dist_buf[buf_idx] = src_distance[src_idx];
|
||||
src_k++;
|
||||
buf_k++;
|
||||
}
|
||||
} else {
|
||||
while (buf_k < output_k && tar_k < tar_input_k) {
|
||||
id_buf[buf_idx] = tar_ids[tar_idx];
|
||||
dist_buf[buf_idx] = tar_distance[tar_idx];
|
||||
tar_k++;
|
||||
buf_k++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tar_ids.swap(id_buf);
|
||||
tar_distance.swap(dist_buf);
|
||||
tar_input_k = output_k;
|
||||
}
|
||||
|
||||
} // namespace scheduler
|
||||
|
||||
@ -38,9 +38,14 @@ class XSearchTask : public Task {
|
||||
Execute() override;
|
||||
|
||||
public:
|
||||
static Status
|
||||
TopkResult(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance, uint64_t input_k,
|
||||
uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result);
|
||||
static void
|
||||
MergeTopkToResultSet(const std::vector<int64_t>& input_ids, const std::vector<float>& input_distance,
|
||||
uint64_t input_k, uint64_t nq, uint64_t topk, bool ascending, scheduler::ResultSet& result);
|
||||
|
||||
static void
|
||||
MergeTopkArray(std::vector<int64_t>& tar_ids, std::vector<float>& tar_distance, uint64_t& tar_input_k,
|
||||
const std::vector<int64_t>& src_ids, const std::vector<float>& src_distance, uint64_t src_input_k,
|
||||
uint64_t nq, uint64_t topk, bool ascending);
|
||||
|
||||
public:
|
||||
TableFileSchemaPtr file_;
|
||||
@ -49,8 +54,6 @@ class XSearchTask : public Task {
|
||||
int index_type_ = 0;
|
||||
ExecutionEnginePtr index_engine_ = nullptr;
|
||||
bool metric_l2 = true;
|
||||
|
||||
static std::mutex merge_mutex_;
|
||||
};
|
||||
|
||||
} // namespace scheduler
|
||||
|
||||
@ -21,26 +21,51 @@
|
||||
|
||||
#include "scheduler/task/SearchTask.h"
|
||||
#include "utils/TimeRecorder.h"
|
||||
#include "utils/ThreadPool.h"
|
||||
|
||||
namespace {
|
||||
|
||||
namespace ms = milvus::scheduler;
|
||||
|
||||
void
|
||||
BuildResult(uint64_t nq,
|
||||
BuildResult(std::vector<int64_t>& output_ids,
|
||||
std::vector<float>& output_distance,
|
||||
uint64_t topk,
|
||||
bool ascending,
|
||||
std::vector<int64_t>& output_ids,
|
||||
std::vector<float>& output_distence) {
|
||||
uint64_t nq,
|
||||
bool ascending) {
|
||||
output_ids.clear();
|
||||
output_ids.resize(nq * topk);
|
||||
output_distence.clear();
|
||||
output_distence.resize(nq * topk);
|
||||
output_distance.clear();
|
||||
output_distance.resize(nq * topk);
|
||||
|
||||
for (uint64_t i = 0; i < nq; i++) {
|
||||
for (uint64_t j = 0; j < topk; j++) {
|
||||
output_ids[i * topk + j] = (int64_t)(drand48() * 100000);
|
||||
output_distence[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48());
|
||||
output_distance[i * topk + j] = ascending ? (j + drand48()) : ((topk - j) + drand48());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
CopyResult(std::vector<int64_t>& output_ids,
|
||||
std::vector<float>& output_distance,
|
||||
uint64_t output_topk,
|
||||
std::vector<int64_t>& input_ids,
|
||||
std::vector<float>& input_distance,
|
||||
uint64_t input_topk,
|
||||
uint64_t nq) {
|
||||
ASSERT_TRUE(input_ids.size() >= nq * input_topk);
|
||||
ASSERT_TRUE(input_distance.size() >= nq * input_topk);
|
||||
ASSERT_TRUE(output_topk <= input_topk);
|
||||
output_ids.clear();
|
||||
output_ids.resize(nq * output_topk);
|
||||
output_distance.clear();
|
||||
output_distance.resize(nq * output_topk);
|
||||
|
||||
for (uint64_t i = 0; i < nq; i++) {
|
||||
for (uint64_t j = 0; j < output_topk; j++) {
|
||||
output_ids[i * output_topk + j] = input_ids[i * input_topk + j];
|
||||
output_distance[i * output_topk + j] = input_distance[i * input_topk + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -50,8 +75,8 @@ CheckTopkResult(const std::vector<int64_t>& input_ids_1,
|
||||
const std::vector<float>& input_distance_1,
|
||||
const std::vector<int64_t>& input_ids_2,
|
||||
const std::vector<float>& input_distance_2,
|
||||
uint64_t nq,
|
||||
uint64_t topk,
|
||||
uint64_t nq,
|
||||
bool ascending,
|
||||
const milvus::scheduler::ResultSet& result) {
|
||||
ASSERT_EQ(result.size(), nq);
|
||||
@ -91,43 +116,36 @@ TEST(DBSearchTest, TOPK_TEST) {
|
||||
bool ascending;
|
||||
std::vector<int64_t> ids1, ids2;
|
||||
std::vector<float> dist1, dist2;
|
||||
milvus::scheduler::ResultSet result;
|
||||
milvus::Status status;
|
||||
ms::ResultSet result;
|
||||
|
||||
/* test1, id1/dist1 valid, id2/dist2 empty */
|
||||
ascending = true;
|
||||
BuildResult(NQ, TOP_K, ascending, ids1, dist1);
|
||||
status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
|
||||
BuildResult(ids1, dist1, TOP_K, NQ, ascending);
|
||||
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
|
||||
|
||||
/* test2, id1/dist1 valid, id2/dist2 valid */
|
||||
BuildResult(NQ, TOP_K, ascending, ids2, dist2);
|
||||
status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
|
||||
BuildResult(ids2, dist2, TOP_K, NQ, ascending);
|
||||
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
|
||||
|
||||
/* test3, id1/dist1 small topk */
|
||||
ids1.clear();
|
||||
dist1.clear();
|
||||
result.clear();
|
||||
BuildResult(NQ, TOP_K / 2, ascending, ids1, dist1);
|
||||
status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
|
||||
BuildResult(ids1, dist1, TOP_K/2, NQ, ascending);
|
||||
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
|
||||
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
|
||||
|
||||
/* test4, id1/dist1 small topk, id2/dist2 small topk */
|
||||
ids2.clear();
|
||||
dist2.clear();
|
||||
result.clear();
|
||||
BuildResult(NQ, TOP_K / 3, ascending, ids2, dist2);
|
||||
status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
|
||||
BuildResult(ids2, dist2, TOP_K/3, NQ, ascending);
|
||||
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
|
||||
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////
|
||||
ascending = false;
|
||||
@ -138,71 +156,199 @@ TEST(DBSearchTest, TOPK_TEST) {
|
||||
result.clear();
|
||||
|
||||
/* test1, id1/dist1 valid, id2/dist2 empty */
|
||||
BuildResult(NQ, TOP_K, ascending, ids1, dist1);
|
||||
status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
|
||||
BuildResult(ids1, dist1, TOP_K, NQ, ascending);
|
||||
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
|
||||
|
||||
/* test2, id1/dist1 valid, id2/dist2 valid */
|
||||
BuildResult(NQ, TOP_K, ascending, ids2, dist2);
|
||||
status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
|
||||
BuildResult(ids2, dist2, TOP_K, NQ, ascending);
|
||||
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
|
||||
|
||||
/* test3, id1/dist1 small topk */
|
||||
ids1.clear();
|
||||
dist1.clear();
|
||||
result.clear();
|
||||
BuildResult(NQ, TOP_K / 2, ascending, ids1, dist1);
|
||||
status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
|
||||
BuildResult(ids1, dist1, TOP_K/2, NQ, ascending);
|
||||
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
|
||||
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
|
||||
|
||||
/* test4, id1/dist1 small topk, id2/dist2 small topk */
|
||||
ids2.clear();
|
||||
dist2.clear();
|
||||
result.clear();
|
||||
BuildResult(NQ, TOP_K / 3, ascending, ids2, dist2);
|
||||
status = milvus::scheduler::XSearchTask::TopkResult(ids1, dist1, TOP_K / 2, NQ, TOP_K, ascending, result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
status = milvus::scheduler::XSearchTask::TopkResult(ids2, dist2, TOP_K / 3, NQ, TOP_K, ascending, result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
|
||||
BuildResult(ids2, dist2, TOP_K/3, NQ, ascending);
|
||||
ms::XSearchTask::MergeTopkToResultSet(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
|
||||
ms::XSearchTask::MergeTopkToResultSet(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
|
||||
CheckTopkResult(ids1, dist1, ids2, dist2, TOP_K, NQ, ascending, result);
|
||||
}
|
||||
|
||||
TEST(DBSearchTest, REDUCE_PERF_TEST) {
|
||||
int32_t nq = 100;
|
||||
int32_t top_k = 1000;
|
||||
int32_t index_file_num = 478; /* sift1B dataset, index files num */
|
||||
bool ascending = true;
|
||||
|
||||
std::vector<int32_t> thread_vec = {4, 8, 11};
|
||||
std::vector<int32_t> nq_vec = {1, 10, 100, 1000};
|
||||
std::vector<int32_t> topk_vec = {1, 4, 16, 64, 256, 1024};
|
||||
int32_t NQ = nq_vec[nq_vec.size()-1];
|
||||
int32_t TOPK = topk_vec[topk_vec.size()-1];
|
||||
|
||||
std::vector<std::vector<int64_t>> id_vec;
|
||||
std::vector<std::vector<float>> dist_vec;
|
||||
std::vector<int64_t> input_ids;
|
||||
std::vector<float> input_distance;
|
||||
milvus::scheduler::ResultSet final_result;
|
||||
milvus::Status status;
|
||||
int32_t i, k, step;
|
||||
|
||||
double span, reduce_cost = 0.0;
|
||||
milvus::TimeRecorder rc("");
|
||||
|
||||
for (int32_t i = 0; i < index_file_num; i++) {
|
||||
BuildResult(nq, top_k, ascending, input_ids, input_distance);
|
||||
|
||||
rc.RecordSection("do search for context: " + std::to_string(i));
|
||||
|
||||
// pick up topk result
|
||||
status = milvus::scheduler::XSearchTask::TopkResult(input_ids,
|
||||
input_distance,
|
||||
top_k,
|
||||
nq,
|
||||
top_k,
|
||||
ascending,
|
||||
final_result);
|
||||
ASSERT_TRUE(status.ok());
|
||||
ASSERT_EQ(final_result.size(), nq);
|
||||
|
||||
span = rc.RecordSection("reduce topk for context: " + std::to_string(i));
|
||||
reduce_cost += span;
|
||||
/* generate testing data */
|
||||
for (i = 0; i < index_file_num; i++) {
|
||||
BuildResult(input_ids, input_distance, TOPK, NQ, ascending);
|
||||
id_vec.push_back(input_ids);
|
||||
dist_vec.push_back(input_distance);
|
||||
}
|
||||
|
||||
for (int32_t max_thread_num : thread_vec) {
|
||||
milvus::ThreadPool threadPool(max_thread_num);
|
||||
std::list<std::future<void>> threads_list;
|
||||
|
||||
for (int32_t nq : nq_vec) {
|
||||
for (int32_t top_k : topk_vec) {
|
||||
ms::ResultSet final_result, final_result_2, final_result_3;
|
||||
|
||||
std::vector<std::vector<int64_t>> id_vec_1(index_file_num);
|
||||
std::vector<std::vector<float>> dist_vec_1(index_file_num);
|
||||
for (i = 0; i < index_file_num; i++) {
|
||||
CopyResult(id_vec_1[i], dist_vec_1[i], top_k, id_vec[i], dist_vec[i], TOPK, nq);
|
||||
}
|
||||
|
||||
std::string str1 = "Method-1 " + std::to_string(max_thread_num) + " " +
|
||||
std::to_string(nq) + " " + std::to_string(top_k);
|
||||
milvus::TimeRecorder rc1(str1);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////
|
||||
/* method-1 */
|
||||
for (i = 0; i < index_file_num; i++) {
|
||||
ms::XSearchTask::MergeTopkToResultSet(id_vec_1[i],
|
||||
dist_vec_1[i],
|
||||
top_k,
|
||||
nq,
|
||||
top_k,
|
||||
ascending,
|
||||
final_result);
|
||||
ASSERT_EQ(final_result.size(), nq);
|
||||
}
|
||||
|
||||
rc1.RecordSection("reduce done");
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////
|
||||
/* method-2 */
|
||||
std::vector<std::vector<int64_t>> id_vec_2(index_file_num);
|
||||
std::vector<std::vector<float>> dist_vec_2(index_file_num);
|
||||
std::vector<uint64_t> k_vec_2(index_file_num);
|
||||
for (i = 0; i < index_file_num; i++) {
|
||||
CopyResult(id_vec_2[i], dist_vec_2[i], top_k, id_vec[i], dist_vec[i], TOPK, nq);
|
||||
k_vec_2[i] = top_k;
|
||||
}
|
||||
|
||||
std::string str2 = "Method-2 " + std::to_string(max_thread_num) + " " +
|
||||
std::to_string(nq) + " " + std::to_string(top_k);
|
||||
milvus::TimeRecorder rc2(str2);
|
||||
|
||||
for (step = 1; step < index_file_num; step *= 2) {
|
||||
for (i = 0; i + step < index_file_num; i += step * 2) {
|
||||
ms::XSearchTask::MergeTopkArray(id_vec_2[i], dist_vec_2[i], k_vec_2[i],
|
||||
id_vec_2[i + step], dist_vec_2[i + step], k_vec_2[i + step],
|
||||
nq, top_k, ascending);
|
||||
}
|
||||
}
|
||||
ms::XSearchTask::MergeTopkToResultSet(id_vec_2[0],
|
||||
dist_vec_2[0],
|
||||
k_vec_2[0],
|
||||
nq,
|
||||
top_k,
|
||||
ascending,
|
||||
final_result_2);
|
||||
ASSERT_EQ(final_result_2.size(), nq);
|
||||
|
||||
rc2.RecordSection("reduce done");
|
||||
|
||||
for (i = 0; i < nq; i++) {
|
||||
ASSERT_EQ(final_result[i].size(), final_result_2[i].size());
|
||||
for (k = 0; k < final_result[i].size(); k++) {
|
||||
if (final_result[i][k].first != final_result_2[i][k].first) {
|
||||
std::cout << i << " " << k << std::endl;
|
||||
}
|
||||
ASSERT_EQ(final_result[i][k].first, final_result_2[i][k].first);
|
||||
ASSERT_EQ(final_result[i][k].second, final_result_2[i][k].second);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////
|
||||
/* method-3 parallel */
|
||||
std::vector<std::vector<int64_t>> id_vec_3(index_file_num);
|
||||
std::vector<std::vector<float>> dist_vec_3(index_file_num);
|
||||
std::vector<uint64_t> k_vec_3(index_file_num);
|
||||
for (i = 0; i < index_file_num; i++) {
|
||||
CopyResult(id_vec_3[i], dist_vec_3[i], top_k, id_vec[i], dist_vec[i], TOPK, nq);
|
||||
k_vec_3[i] = top_k;
|
||||
}
|
||||
|
||||
std::string str3 = "Method-3 " + std::to_string(max_thread_num) + " " +
|
||||
std::to_string(nq) + " " + std::to_string(top_k);
|
||||
milvus::TimeRecorder rc3(str3);
|
||||
|
||||
for (step = 1; step < index_file_num; step *= 2) {
|
||||
for (i = 0; i + step < index_file_num; i += step * 2) {
|
||||
threads_list.push_back(
|
||||
threadPool.enqueue(ms::XSearchTask::MergeTopkArray,
|
||||
std::ref(id_vec_3[i]),
|
||||
std::ref(dist_vec_3[i]),
|
||||
std::ref(k_vec_3[i]),
|
||||
std::ref(id_vec_3[i + step]),
|
||||
std::ref(dist_vec_3[i + step]),
|
||||
std::ref(k_vec_3[i + step]),
|
||||
nq,
|
||||
top_k,
|
||||
ascending));
|
||||
}
|
||||
|
||||
while (threads_list.size() > 0) {
|
||||
int nready = 0;
|
||||
for (auto it = threads_list.begin(); it != threads_list.end(); it = it) {
|
||||
auto &p = *it;
|
||||
std::chrono::milliseconds span(0);
|
||||
if (p.wait_for(span) == std::future_status::ready) {
|
||||
threads_list.erase(it++);
|
||||
++nready;
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
if (nready == 0) {
|
||||
std::this_thread::yield();
|
||||
}
|
||||
}
|
||||
}
|
||||
ms::XSearchTask::MergeTopkToResultSet(id_vec_3[0],
|
||||
dist_vec_3[0],
|
||||
k_vec_3[0],
|
||||
nq,
|
||||
top_k,
|
||||
ascending,
|
||||
final_result_3);
|
||||
ASSERT_EQ(final_result_3.size(), nq);
|
||||
|
||||
rc3.RecordSection("reduce done");
|
||||
|
||||
for (i = 0; i < nq; i++) {
|
||||
ASSERT_EQ(final_result[i].size(), final_result_3[i].size());
|
||||
for (k = 0; k < final_result[i].size(); k++) {
|
||||
ASSERT_EQ(final_result[i][k].first, final_result_3[i][k].first);
|
||||
ASSERT_EQ(final_result[i][k].second, final_result_3[i][k].second);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::cout << "total reduce time: " << reduce_cost / 1000 << " ms" << std::endl;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user