mirror of https://gitee.com/milvus-io/milvus.git (synced 2026-01-05 02:12:48 +08:00)

commit 0a77b4547f - fix conflict
Former-commit-id: b51828bbff45e2c3eeb2eaeb72b1e4adebda0550
@@ -14,15 +14,19 @@ Please mark all change in change log and use the ticket from JIRA.

 ## Improvement

 - MS-552 - Add and change the easylogging library
 - MS-553 - Refine cache code
 - MS-557 - Merge Log.h
 - MS-555 - Remove old scheduler
 - MS-556 - Add Job Definition in Scheduler
 - MS-558 - Refine status code
 - MS-562 - Add JobMgr and TaskCreator in Scheduler
 - MS-566 - Refactor cmake
 - MS-574 - Milvus configuration refactor
 - MS-578 - Make sure milvus 0.5.0 doesn't break 0.3.1 data
 - MS-585 - Update namespace in scheduler
+- MS-606 - Speed up result reduce
+- MS-608 - Update TODO names
+- MS-609 - Update task construct function
+- MS-611 - Add resources validity check in ResourceMgr

 ## New Feature
@@ -36,6 +40,7 @@ Please mark all change in change log and use the ticket from JIRA.

 - MS-590 - Refine cmake code to support cpplint
 - MS-600 - Reconstruct unittest code
 - MS-602 - Remove zilliz namespace
+- MS-610 - Change error code base value from hex to decimal

 # Milvus 0.4.0 (2019-09-12)
@@ -24,6 +24,12 @@ namespace scheduler {

 void
 ResourceMgr::Start() {
+    if (not check_resource_valid()) {
+        ENGINE_LOG_ERROR << "Resources invalid, cannot start ResourceMgr.";
+        ENGINE_LOG_ERROR << Dump();
+        return;
+    }
+
     std::lock_guard<std::mutex> lck(resources_mutex_);
     for (auto& resource : resources_) {
         resource->Start();
@@ -60,8 +66,22 @@ ResourceMgr::Add(ResourcePtr&& resource) {

     resource->RegisterSubscriber(std::bind(&ResourceMgr::post_event, this, std::placeholders::_1));

-    if (resource->type() == ResourceType::DISK) {
-        disk_resources_.emplace_back(ResourceWPtr(resource));
+    switch (resource->type()) {
+        case ResourceType::DISK: {
+            disk_resources_.emplace_back(ResourceWPtr(resource));
+            break;
+        }
+        case ResourceType::CPU: {
+            cpu_resources_.emplace_back(ResourceWPtr(resource));
+            break;
+        }
+        case ResourceType::GPU: {
+            gpu_resources_.emplace_back(ResourceWPtr(resource));
+            break;
+        }
+        default: {
+            break;
+        }
     }
     resources_.emplace_back(resource);
@@ -74,7 +94,7 @@ ResourceMgr::Connect(const std::string& name1, const std::string& name2, Connect
     auto res2 = GetResource(name2);
     if (res1 && res2) {
         res1->AddNeighbour(std::static_pointer_cast<Node>(res2), connection);
-        // TODO(wxy): enable when task balance supported
+        // TODO(wxyu): enable when task balance supported
         //        res2->AddNeighbour(std::static_pointer_cast<Node>(res1), connection);
         return true;
     }
@@ -85,6 +105,8 @@ void
 ResourceMgr::Clear() {
     std::lock_guard<std::mutex> lck(resources_mutex_);
     disk_resources_.clear();
+    cpu_resources_.clear();
+    gpu_resources_.clear();
     resources_.clear();
 }
@@ -148,14 +170,14 @@ ResourceMgr::GetNumGpuResource() const {

 std::string
 ResourceMgr::Dump() {
-    std::string str = "ResourceMgr contains " + std::to_string(resources_.size()) + " resources.\n";
+    std::stringstream ss;
+    ss << "ResourceMgr contains " << resources_.size() << " resources." << std::endl;

-    for (uint64_t i = 0; i < resources_.size(); ++i) {
-        str += "Resource No." + std::to_string(i) + ":\n";
-        //        str += resources_[i]->Dump();
+    for (auto& res : resources_) {
+        ss << res->Dump();
     }

-    return str;
+    return ss.str();
 }

 std::string
@@ -170,6 +192,34 @@ ResourceMgr::DumpTaskTables() {
     return ss.str();
 }

+bool
+ResourceMgr::check_resource_valid() {
+    {
+        // TODO: check one disk-resource, one cpu-resource, zero or more gpu-resource;
+        if (GetDiskResources().size() != 1) return false;
+        if (GetCpuResources().size() != 1) return false;
+    }
+
+    {
+        // TODO: one compute-resource at least;
+        if (GetNumOfComputeResource() < 1) return false;
+    }
+
+    {
+        // TODO: check disk only connect with cpu
+    }
+
+    {
+        // TODO: check gpu only connect with cpu
+    }
+
+    {
+        // TODO: check if exists isolated node
+    }
+
+    return true;
+}
+
 void
 ResourceMgr::post_event(const EventPtr& event) {
     {
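The new check_resource_valid() gates Start(): the manager now refuses to run unless the topology has exactly one disk resource, exactly one CPU resource, and at least one compute resource. A minimal standalone sketch of that invariant (illustrative names, not the actual Milvus API):

```cpp
// Sketch of the topology invariant check_resource_valid() enforces.
#include <cstddef>
#include <iostream>

struct Topology {
    std::size_t disk_count;
    std::size_t cpu_count;
    std::size_t compute_count;  // resources able to execute tasks (CPU + GPU)
};

bool ValidTopology(const Topology& t) {
    if (t.disk_count != 1) return false;    // exactly one disk resource
    if (t.cpu_count != 1) return false;     // exactly one cpu resource
    if (t.compute_count < 1) return false;  // at least one compute resource
    return true;
}

int main() {
    std::cout << ValidTopology({1, 1, 2}) << "\n";  // 1: one disk, one cpu, cpu+gpu
    std::cout << ValidTopology({0, 1, 1}) << "\n";  // 0: missing the disk resource
}
```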
@@ -183,7 +233,9 @@ void
 ResourceMgr::event_process() {
     while (running_) {
         std::unique_lock<std::mutex> lock(event_mutex_);
-        event_cv_.wait(lock, [this] { return !queue_.empty(); });
+        event_cv_.wait(lock, [this] {
+            return !queue_.empty();
+        });

         auto event = queue_.front();
         queue_.pop();
@@ -64,7 +64,17 @@ class ResourceMgr {
         return disk_resources_;
     }

-    // TODO(wxy): why return shared pointer
+    inline std::vector<ResourceWPtr>&
+    GetCpuResources() {
+        return cpu_resources_;
+    }
+
+    inline std::vector<ResourceWPtr>&
+    GetGpuResources() {
+        return gpu_resources_;
+    }
+
+    // TODO(wxyu): why return shared pointer
     inline std::vector<ResourcePtr>
     GetAllResources() {
         return resources_;
@@ -89,7 +99,7 @@ class ResourceMgr {
     GetNumGpuResource() const;

 public:
-    // TODO(wxy): add stats interface(low)
+    // TODO(wxyu): add stats interface(low)

 public:
     /******** Utility Functions ********/
@@ -100,6 +110,9 @@ class ResourceMgr {
     DumpTaskTables();

 private:
+    bool
+    check_resource_valid();
+
     void
     post_event(const EventPtr& event);
@@ -110,6 +123,8 @@ class ResourceMgr {
     bool running_ = false;

     std::vector<ResourceWPtr> disk_resources_;
+    std::vector<ResourceWPtr> cpu_resources_;
+    std::vector<ResourceWPtr> gpu_resources_;
     std::vector<ResourcePtr> resources_;
     mutable std::mutex resources_mutex_;
@@ -146,7 +146,7 @@ load_advance_config() {
     //    }
     //    } catch (const char *msg) {
     //        SERVER_LOG_ERROR << msg;
-    //        // TODO(wxy): throw exception instead
+    //        // TODO(wxyu): throw exception instead
     //        exit(-1);
     ////        throw std::exception();
     //    }
@@ -92,7 +92,7 @@ Scheduler::Process(const EventPtr& event) {
     process_event(event);
 }

-// TODO(wxy): refactor the function
+// TODO(wxyu): refactor the function
 void
 Scheduler::OnLoadCompleted(const EventPtr& event) {
     auto load_completed_event = std::static_pointer_cast<LoadCompletedEvent>(event);
@@ -31,7 +31,7 @@
 namespace milvus {
 namespace scheduler {

-// TODO(wxy): refactor, not friendly to unittest, logical in framework code
+// TODO(wxyu): refactor, not friendly to unittest, logical in framework code
 class Scheduler {
  public:
     explicit Scheduler(ResourceMgrWPtr res_mgr);
@@ -38,7 +38,7 @@ TaskCreator::Create(const JobPtr &job) {
             return Create(std::static_pointer_cast<BuildIndexJob>(job));
         }
         default: {
-            // TODO(wxy): error
+            // TODO(wxyu): error
             return std::vector<TaskPtr>();
         }
     }
@@ -47,9 +47,9 @@ TaskCreator::Create(const JobPtr &job) {
 std::vector<TaskPtr>
 TaskCreator::Create(const SearchJobPtr &job) {
     std::vector<TaskPtr> tasks;
-    for (auto &index_file : job->index_files()) {
-        auto task = std::make_shared<XSearchTask>(index_file.second);
-        task->label() = std::make_shared<DefaultLabel>();
+    for (auto& index_file : job->index_files()) {
+        auto label = std::make_shared<DefaultLabel>();
+        auto task = std::make_shared<XSearchTask>(index_file.second, label);
         task->job_ = job;
         tasks.emplace_back(task);
     }
@@ -60,8 +60,8 @@ TaskCreator::Create(const SearchJobPtr &job) {
 std::vector<TaskPtr>
 TaskCreator::Create(const DeleteJobPtr &job) {
     std::vector<TaskPtr> tasks;
-    auto task = std::make_shared<XDeleteTask>(job);
-    task->label() = std::make_shared<BroadcastLabel>();
+    auto label = std::make_shared<BroadcastLabel>();
+    auto task = std::make_shared<XDeleteTask>(job, label);
     task->job_ = job;
     tasks.emplace_back(task);
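Both TaskCreator overloads now follow the same label-first pattern: build the label, then hand it to the task constructor, so a task never exists in an unlabeled state. A sketch of the pattern with simplified stand-in types (the real code uses XSearchTask/XDeleteTask, DefaultLabel/BroadcastLabel and TaskLabelPtr):

```cpp
// Label-in-constructor pattern, simplified.
#include <memory>
#include <utility>
#include <vector>

struct TaskLabel {};
using TaskLabelPtr = std::shared_ptr<TaskLabel>;

struct Task {
    explicit Task(TaskLabelPtr label) : label_(std::move(label)) {}
    TaskLabelPtr label_;
};

int main() {
    std::vector<std::shared_ptr<Task>> tasks;
    // Before: construct the task, then assign task->label() = ...;
    // After: the label exists first and is passed to the constructor.
    auto label = std::make_shared<TaskLabel>();
    tasks.emplace_back(std::make_shared<Task>(label));
}
```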
@@ -125,7 +125,7 @@ class TaskTable {
     Get(uint64_t index);

     /*
-     * TODO(wxy): BIG GC
+     * TODO(wxyu): BIG GC
     * Remove sequence task which is DONE or MOVED from front;
     * Called by ?
     */
@@ -173,7 +173,7 @@ class TaskTable {
 public:
     /******** Action ********/

-    // TODO(wxy): bool to Status
+    // TODO(wxyu): bool to Status
    /*
     * Load a task;
     * Set state loading;
@@ -84,7 +84,7 @@ Action::PushTaskToNeighbourRandomly(const TaskPtr &task, const ResourcePtr &self
         }

     } else {
-        // TODO(wxy): process
+        // TODO(wxyu): process
     }
 }
@@ -37,8 +37,9 @@ namespace scheduler {
 using engine::meta::TableFileSchemaPtr;

 using Id2IndexMap = std::unordered_map<size_t, TableFileSchemaPtr>;
-using Id2DistanceMap = std::vector<std::pair<int64_t, double>>;
-using ResultSet = std::vector<Id2DistanceMap>;
+using IdDistPair = std::pair<int64_t, double>;
+using Id2DistVec = std::vector<IdDistPair>;
+using ResultSet = std::vector<Id2DistVec>;

 class SearchJob : public Job {
  public:
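The renamed typedefs make the result layout explicit: a ResultSet is one vector of (id, distance) pairs per query, each sorted by distance. A tiny sketch of the shape (the typedefs are verbatim from the diff; the values are illustrative):

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using IdDistPair = std::pair<int64_t, double>;
using Id2DistVec = std::vector<IdDistPair>;
using ResultSet = std::vector<Id2DistVec>;

int main() {
    ResultSet result(2);                 // nq = 2 queries
    result[0] = {{42, 0.1}, {7, 0.3}};   // top-2 hits for query 0, ascending distance
    result[1] = {{13, 0.2}, {99, 0.8}};  // top-2 hits for query 1
    std::cout << result[0][0].first << "\n";  // best match id for query 0 -> 42
}
```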
@@ -38,7 +38,7 @@
 namespace milvus {
 namespace scheduler {

-// TODO(wxy): Storage, Route, Executor
+// TODO(wxyu): Storage, Route, Executor
 enum class ResourceType {
     DISK = 0,
     CPU = 1,
@@ -114,11 +114,11 @@ class Resource : public Node, public std::enable_shared_from_this<Resource> {
         return enable_executor_;
     }

-    // TODO(wxy): const
+    // TODO(wxyu): const
     uint64_t
     NumOfTaskToExec();

-    // TODO(wxy): need double ?
+    // TODO(wxyu): need double ?
     inline uint64_t
     TaskAvgCost() const {
         return total_cost_ / total_task_;
@@ -17,11 +17,13 @@

 #include "scheduler/task/DeleteTask.h"

+#include <utility>
+
 namespace milvus {
 namespace scheduler {

-XDeleteTask::XDeleteTask(const scheduler::DeleteJobPtr& delete_job)
-    : Task(TaskType::DeleteTask), delete_job_(delete_job) {
+XDeleteTask::XDeleteTask(const scheduler::DeleteJobPtr& delete_job, TaskLabelPtr label)
+    : Task(TaskType::DeleteTask, std::move(label)), delete_job_(delete_job) {
 }

 void
@@ -25,7 +25,7 @@ namespace scheduler {

 class XDeleteTask : public Task {
  public:
-    explicit XDeleteTask(const scheduler::DeleteJobPtr& delete_job);
+    explicit XDeleteTask(const scheduler::DeleteJobPtr& delete_job, TaskLabelPtr label);

     void
     Load(LoadType type, uint8_t device_id) override;
@@ -78,24 +78,26 @@ std::mutex XSearchTask::merge_mutex_;

 void
 CollectFileMetrics(int file_type, size_t file_size) {
+    server::MetricsBase& inst = server::Metrics::GetInstance();
     switch (file_type) {
         case TableFileSchema::RAW:
         case TableFileSchema::TO_INDEX: {
-            server::Metrics::GetInstance().RawFileSizeHistogramObserve(file_size);
-            server::Metrics::GetInstance().RawFileSizeTotalIncrement(file_size);
-            server::Metrics::GetInstance().RawFileSizeGaugeSet(file_size);
+            inst.RawFileSizeHistogramObserve(file_size);
+            inst.RawFileSizeTotalIncrement(file_size);
+            inst.RawFileSizeGaugeSet(file_size);
             break;
         }
         default: {
-            server::Metrics::GetInstance().IndexFileSizeHistogramObserve(file_size);
-            server::Metrics::GetInstance().IndexFileSizeTotalIncrement(file_size);
-            server::Metrics::GetInstance().IndexFileSizeGaugeSet(file_size);
+            inst.IndexFileSizeHistogramObserve(file_size);
+            inst.IndexFileSizeTotalIncrement(file_size);
+            inst.IndexFileSizeGaugeSet(file_size);
             break;
         }
     }
 }

-XSearchTask::XSearchTask(TableFileSchemaPtr file) : Task(TaskType::SearchTask), file_(file) {
+XSearchTask::XSearchTask(TableFileSchemaPtr file, TaskLabelPtr label)
+    : Task(TaskType::SearchTask, std::move(label)), file_(file) {
     if (file_) {
         if (file_->metric_type_ != static_cast<int>(MetricType::L2)) {
             metric_l2 = false;
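The CollectFileMetrics refactor binds the metrics singleton to a local reference once instead of calling GetInstance() on every statement; behavior is unchanged, the repeated lookup just disappears. A sketch of the idiom with a stand-in class (not the Milvus Metrics API):

```cpp
// Singleton-caching idiom: one lookup, many uses.
#include <cstddef>

struct MetricsBase {
    void Observe(std::size_t /*value*/) {}
    static MetricsBase& GetInstance() {
        static MetricsBase inst;  // Meyers singleton
        return inst;
    }
};

void Collect(std::size_t file_size) {
    MetricsBase& inst = MetricsBase::GetInstance();  // looked up once
    inst.Observe(file_size);                         // reused below
    inst.Observe(file_size);
}
```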
@@ -205,16 +207,9 @@ XSearchTask::Execute() {
             double span = rc.RecordSection(hdr + ", do search");
             //            search_job->AccumSearchCost(span);

-            // step 3: cluster result
-            scheduler::ResultSet result_set;
+            // step 3: pick up topk result
             auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk;
-            XSearchTask::ClusterResult(output_ids, output_distance, nq, spec_k, result_set);
-
-            span = rc.RecordSection(hdr + ", cluster result");
-            //            search_job->AccumReduceCost(span);
-
-            // step 4: pick up topk result
-            XSearchTask::TopkResult(result_set, topk, metric_l2, search_job->GetResult());
+            XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult());

             span = rc.RecordSection(hdr + ", reduce topk");
             //            search_job->AccumReduceCost(span);
@@ -234,142 +229,75 @@ XSearchTask::Execute() {
 }

 Status
-XSearchTask::ClusterResult(const std::vector<int64_t>& output_ids, const std::vector<float>& output_distance,
-                           uint64_t nq, uint64_t topk, scheduler::ResultSet& result_set) {
-    if (output_ids.size() < nq * topk || output_distance.size() < nq * topk) {
-        std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) + " distance array size: " +
-                          std::to_string(output_distance.size());
-        ENGINE_LOG_ERROR << msg;
-        return Status(DB_ERROR, msg);
-    }
-
-    result_set.clear();
-    result_set.resize(nq);
-
-    std::function<void(size_t, size_t)> reduce_worker = [&](size_t from_index, size_t to_index) {
-        for (auto i = from_index; i < to_index; i++) {
-            scheduler::Id2DistanceMap id_distance;
-            id_distance.reserve(topk);
-            for (auto k = 0; k < topk; k++) {
-                uint64_t index = i * topk + k;
-                if (output_ids[index] < 0) {
-                    continue;
-                }
-                id_distance.push_back(std::make_pair(output_ids[index], output_distance[index]));
-            }
-            result_set[i] = id_distance;
-        }
-    };
-
-    //    if (NeedParallelReduce(nq, topk)) {
-    //        ParallelReduce(reduce_worker, nq);
-    //    } else {
-    reduce_worker(0, nq);
-    //    }
-
-    return Status::OK();
-}
-
-Status
-XSearchTask::MergeResult(scheduler::Id2DistanceMap& distance_src, scheduler::Id2DistanceMap& distance_target,
-                         uint64_t topk, bool ascending) {
-    // Note: the score_src and score_target are already arranged by score in ascending order
-    if (distance_src.empty()) {
-        ENGINE_LOG_WARNING << "Empty distance source array";
-        return Status::OK();
-    }
-
-    std::unique_lock<std::mutex> lock(merge_mutex_);
-    if (distance_target.empty()) {
-        distance_target.swap(distance_src);
-        return Status::OK();
-    }
-
-    size_t src_count = distance_src.size();
-    size_t target_count = distance_target.size();
-    scheduler::Id2DistanceMap distance_merged;
-    distance_merged.reserve(topk);
-    size_t src_index = 0, target_index = 0;
-    while (true) {
-        // all score_src items are merged, if score_merged.size() still less than topk
-        // move items from score_target to score_merged until score_merged.size() equal topk
-        if (src_index >= src_count) {
-            for (size_t i = target_index; i < target_count && distance_merged.size() < topk; ++i) {
-                distance_merged.push_back(distance_target[i]);
-            }
-            break;
-        }
-
-        // all score_target items are merged, if score_merged.size() still less than topk
-        // move items from score_src to score_merged until score_merged.size() equal topk
-        if (target_index >= target_count) {
-            for (size_t i = src_index; i < src_count && distance_merged.size() < topk; ++i) {
-                distance_merged.push_back(distance_src[i]);
-            }
-            break;
-        }
-
-        // compare score,
-        // if ascending = true, put smallest score to score_merged one by one
-        // else, put largest score to score_merged one by one
-        auto& src_pair = distance_src[src_index];
-        auto& target_pair = distance_target[target_index];
-        if (ascending) {
-            if (src_pair.second > target_pair.second) {
-                distance_merged.push_back(target_pair);
-                target_index++;
-            } else {
-                distance_merged.push_back(src_pair);
-                src_index++;
-            }
-        } else {
-            if (src_pair.second < target_pair.second) {
-                distance_merged.push_back(target_pair);
-                target_index++;
-            } else {
-                distance_merged.push_back(src_pair);
-                src_index++;
-            }
-        }
-
-        // score_merged.size() already equal topk
-        if (distance_merged.size() >= topk) {
-            break;
-        }
-    }
-
-    distance_target.swap(distance_merged);
-
-    return Status::OK();
-}
-
-Status
-XSearchTask::TopkResult(scheduler::ResultSet& result_src, uint64_t topk, bool ascending,
-                        scheduler::ResultSet& result_target) {
-    if (result_target.empty()) {
-        result_target.swap(result_src);
-        return Status::OK();
-    }
-
-    if (result_src.size() != result_target.size()) {
-        std::string msg = "Invalid result set size";
-        ENGINE_LOG_ERROR << msg;
-        return Status(DB_ERROR, msg);
-    }
-
-    std::function<void(size_t, size_t)> ReduceWorker = [&](size_t from_index, size_t to_index) {
-        for (size_t i = from_index; i < to_index; i++) {
-            scheduler::Id2DistanceMap& score_src = result_src[i];
-            scheduler::Id2DistanceMap& score_target = result_target[i];
-            XSearchTask::MergeResult(score_src, score_target, topk, ascending);
-        }
-    };
-
-    //    if (NeedParallelReduce(result_src.size(), topk)) {
-    //        ParallelReduce(ReduceWorker, result_src.size());
-    //    } else {
-    ReduceWorker(0, result_src.size());
-    //    }
+XSearchTask::TopkResult(const std::vector<long> &input_ids,
+                        const std::vector<float> &input_distance,
+                        uint64_t input_k,
+                        uint64_t nq,
+                        uint64_t topk,
+                        bool ascending,
+                        scheduler::ResultSet &result) {
+    scheduler::ResultSet result_buf;
+
+    if (result.empty()) {
+        result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0)));
+        for (auto i = 0; i < nq; ++i) {
+            auto& result_buf_i = result_buf[i];
+            uint64_t input_k_multi_i = input_k * i;
+            for (auto k = 0; k < input_k; ++k) {
+                uint64_t idx = input_k_multi_i + k;
+                auto& result_buf_item = result_buf_i[k];
+                result_buf_item.first = input_ids[idx];
+                result_buf_item.second = input_distance[idx];
+            }
+        }
+    } else {
+        size_t tar_size = result[0].size();
+        uint64_t output_k = std::min(topk, input_k + tar_size);
+        result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0)));
+        for (auto i = 0; i < nq; ++i) {
+            size_t buf_k = 0, src_k = 0, tar_k = 0;
+            uint64_t src_idx;
+            auto& result_i = result[i];
+            auto& result_buf_i = result_buf[i];
+            uint64_t input_k_multi_i = input_k * i;
+            while (buf_k < output_k && src_k < input_k && tar_k < tar_size) {
+                src_idx = input_k_multi_i + src_k;
+                auto& result_buf_item = result_buf_i[buf_k];
+                auto& result_item = result_i[tar_k];
+                if ((ascending && input_distance[src_idx] < result_item.second) ||
+                    (!ascending && input_distance[src_idx] > result_item.second)) {
+                    result_buf_item.first = input_ids[src_idx];
+                    result_buf_item.second = input_distance[src_idx];
+                    src_k++;
+                } else {
+                    result_buf_item = result_item;
+                    tar_k++;
+                }
+                buf_k++;
+            }
+            if (buf_k < topk) {
+                if (src_k < input_k) {
+                    while (buf_k < output_k && src_k < input_k) {
+                        src_idx = input_k_multi_i + src_k;
+                        auto& result_buf_item = result_buf_i[buf_k];
+                        result_buf_item.first = input_ids[src_idx];
+                        result_buf_item.second = input_distance[src_idx];
+                        src_k++;
+                        buf_k++;
+                    }
+                } else {
+                    while (buf_k < output_k && tar_k < tar_size) {
+                        result_buf_i[buf_k] = result_i[tar_k];
+                        tar_k++;
+                        buf_k++;
+                    }
+                }
+            }
+        }
+    }
+
+    result.swap(result_buf);

     return Status::OK();
 }
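This is the heart of MS-606: the old pipeline built an intermediate ResultSet per index file (ClusterResult) and then pairwise-merged it into the accumulated set (MergeResult); the new TopkResult folds each file's raw id/distance arrays directly into the accumulated result with one two-way merge per query. A self-contained sketch of that per-query merge, assuming both inputs are sorted ascending by distance (my own simplification, not the Milvus implementation):

```cpp
// Per-query merge behind the new TopkResult: fold a sorted (id, distance)
// run into the sorted accumulated list, keeping at most topk entries.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using IdDistPair = std::pair<int64_t, double>;
using Id2DistVec = std::vector<IdDistPair>;

Id2DistVec MergeTopk(const Id2DistVec& src, const Id2DistVec& tar, std::size_t topk) {
    Id2DistVec out;
    out.reserve(std::min(topk, src.size() + tar.size()));
    std::size_t s = 0, t = 0;
    while (out.size() < topk && (s < src.size() || t < tar.size())) {
        // take the smaller head; an exhausted side loses automatically
        bool take_src = t >= tar.size() ||
                        (s < src.size() && src[s].second < tar[t].second);
        out.push_back(take_src ? src[s++] : tar[t++]);
    }
    return out;
}

int main() {
    Id2DistVec a = {{1, 0.1}, {2, 0.4}};
    Id2DistVec b = {{3, 0.2}, {4, 0.3}, {5, 0.9}};
    for (auto& [id, d] : MergeTopk(a, b, 4))
        std::cout << id << ":" << d << "\n";  // 1:0.1  3:0.2  4:0.3  2:0.4
}
```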
@@ -26,10 +26,10 @@
 namespace milvus {
 namespace scheduler {

-// TODO(wxy): rewrite
+// TODO(wxyu): rewrite
 class XSearchTask : public Task {
  public:
-    explicit XSearchTask(TableFileSchemaPtr file);
+    explicit XSearchTask(TableFileSchemaPtr file, TaskLabelPtr label);

     void
     Load(LoadType type, uint8_t device_id) override;
@@ -39,15 +39,13 @@ class XSearchTask : public Task {

 public:
     static Status
-    ClusterResult(const std::vector<int64_t>& output_ids, const std::vector<float>& output_distence, uint64_t nq,
-                  uint64_t topk, scheduler::ResultSet& result_set);
-
-    static Status
-    MergeResult(scheduler::Id2DistanceMap& distance_src, scheduler::Id2DistanceMap& distance_target, uint64_t topk,
-                bool ascending);
-
-    static Status
-    TopkResult(scheduler::ResultSet& result_src, uint64_t topk, bool ascending, scheduler::ResultSet& result_target);
+    TopkResult(const std::vector<long> &input_ids,
+               const std::vector<float> &input_distance,
+               uint64_t input_k,
+               uint64_t nq,
+               uint64_t topk,
+               bool ascending,
+               scheduler::ResultSet &result);

 public:
     TableFileSchemaPtr file_;
@@ -24,6 +24,7 @@

 #include <memory>
 #include <string>
+#include <utility>

 namespace milvus {
 namespace scheduler {
@@ -49,7 +50,7 @@ using TaskPtr = std::shared_ptr<Task>;
 // TODO: re-design
 class Task {
  public:
-    explicit Task(TaskType type) : type_(type) {
+    explicit Task(TaskType type, TaskLabelPtr label) : type_(type), label_(std::move(label)) {
     }

    /*
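A note on the constructor shape the diff uses throughout (Task, XSearchTask, XDeleteTask, TestTask): the label is taken by value and moved into the member. A sketch of why, with stand-in types:

```cpp
// Pass-by-value + move for shared_ptr parameters: the parameter absorbs
// either a copy or a move from the caller, and the move into the member
// avoids a second refcount increment.
#include <memory>
#include <utility>

struct TaskLabel {};
using TaskLabelPtr = std::shared_ptr<TaskLabel>;

struct Task {
    explicit Task(TaskLabelPtr label) : label_(std::move(label)) {}  // no extra copy here
    TaskLabelPtr label_;
};

int main() {
    auto label = std::make_shared<TaskLabel>();
    Task keeps_caller_copy(label);   // copies: caller still owns label
    Task steals(std::move(label));   // moves: no refcount traffic at all
}
```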
@@ -18,10 +18,12 @@
 #include "scheduler/task/TestTask.h"
 #include "cache/GpuCacheMgr.h"

+#include <utility>
+
 namespace milvus {
 namespace scheduler {

-TestTask::TestTask(TableFileSchemaPtr& file) : XSearchTask(file) {
+TestTask::TestTask(TableFileSchemaPtr& file, TaskLabelPtr label) : XSearchTask(file, std::move(label)) {
 }

 void
@@ -24,7 +24,7 @@ namespace scheduler {

 class TestTask : public XSearchTask {
  public:
-    explicit TestTask(TableFileSchemaPtr& file);
+    explicit TestTask(TableFileSchemaPtr& file, TaskLabelPtr label);

 public:
    void
@@ -26,7 +26,7 @@ namespace milvus {
 using ErrorCode = int32_t;

 constexpr ErrorCode SERVER_SUCCESS = 0;
-constexpr ErrorCode SERVER_ERROR_CODE_BASE = 0x30000;
+constexpr ErrorCode SERVER_ERROR_CODE_BASE = 30000;

 constexpr ErrorCode
 ToServerErrorCode(const ErrorCode error_code) {
@@ -34,7 +34,7 @@ ToServerErrorCode(const ErrorCode error_code) {
 }

 constexpr ErrorCode DB_SUCCESS = 0;
-constexpr ErrorCode DB_ERROR_CODE_BASE = 0x40000;
+constexpr ErrorCode DB_ERROR_CODE_BASE = 40000;

 constexpr ErrorCode
 ToDbErrorCode(const ErrorCode error_code) {
@@ -42,7 +42,7 @@ ToDbErrorCode(const ErrorCode error_code) {
 }

 constexpr ErrorCode KNOWHERE_SUCCESS = 0;
-constexpr ErrorCode KNOWHERE_ERROR_CODE_BASE = 0x50000;
+constexpr ErrorCode KNOWHERE_ERROR_CODE_BASE = 50000;

 constexpr ErrorCode
 ToKnowhereErrorCode(const ErrorCode error_code) {
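The point of MS-610: error codes are printed in decimal in logs, so a hex base like 0x30000 surfaces as the opaque 196608. With a decimal base, a code such as 30005 reads directly as "server error, offset 5". A quick sketch of the difference (this assumes codes are formed as base + offset, as the To*ErrorCode helpers suggest):

```cpp
#include <cstdint>
#include <iostream>

using ErrorCode = int32_t;

constexpr ErrorCode OLD_BASE = 0x30000;  // decimal 196608
constexpr ErrorCode NEW_BASE = 30000;

int main() {
    std::cout << OLD_BASE + 5 << "\n";  // 196613: base unrecognizable in a log line
    std::cout << NEW_BASE + 5 << "\n";  // 30005: base and offset readable at a glance
}
```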
@@ -20,6 +20,7 @@
 #include "ConfAdapter.h"
 #include "VecIndex.h"

+#include <functional>
 #include <map>
 #include <memory>
@@ -22,13 +22,10 @@
 #include "scheduler/task/SearchTask.h"
 #include "utils/TimeRecorder.h"

+using namespace milvus::scheduler;
+
 namespace {

-namespace ms = milvus;
-
-static constexpr uint64_t NQ = 15;
-static constexpr uint64_t TOP_K = 64;
-
 void
 BuildResult(uint64_t nq,
             uint64_t topk,
@@ -48,76 +45,36 @@ BuildResult(uint64_t nq,
     }
 }

-void
-CheckResult(const ms::scheduler::Id2DistanceMap &src_1,
-            const ms::scheduler::Id2DistanceMap &src_2,
-            const ms::scheduler::Id2DistanceMap &target,
-            bool ascending) {
-    for (uint64_t i = 0; i < target.size() - 1; i++) {
-        if (ascending) {
-            ASSERT_LE(target[i].second, target[i + 1].second);
-        } else {
-            ASSERT_GE(target[i].second, target[i + 1].second);
-        }
-    }
-
-    using ID2DistMap = std::map<int64_t, float>;
-    ID2DistMap src_map_1, src_map_2;
-    for (const auto &pair : src_1) {
-        src_map_1.insert(pair);
-    }
-    for (const auto &pair : src_2) {
-        src_map_2.insert(pair);
-    }
-
-    for (const auto &pair : target) {
-        ASSERT_TRUE(src_map_1.find(pair.first) != src_map_1.end() || src_map_2.find(pair.first) != src_map_2.end());
-
-        float dist = src_map_1.find(pair.first) != src_map_1.end() ? src_map_1[pair.first] : src_map_2[pair.first];
-        ASSERT_LT(fabs(pair.second - dist), std::numeric_limits<float>::epsilon());
-    }
-}
-
-void
-CheckCluster(const std::vector<int64_t> &target_ids,
-             const std::vector<float> &target_distence,
-             const ms::scheduler::ResultSet &src_result,
-             int64_t nq,
-             int64_t topk) {
-    ASSERT_EQ(src_result.size(), nq);
-    for (int64_t i = 0; i < nq; i++) {
-        auto &res = src_result[i];
-        ASSERT_EQ(res.size(), topk);
-
-        if (res.empty()) {
-            continue;
-        }
-
-        ASSERT_EQ(res[0].first, target_ids[i * topk]);
-        ASSERT_EQ(res[topk - 1].first, target_ids[i * topk + topk - 1]);
-    }
-}
-
-void
-CheckTopkResult(const ms::scheduler::ResultSet &src_result,
-                bool ascending,
-                int64_t nq,
-                int64_t topk) {
-    ASSERT_EQ(src_result.size(), nq);
-    for (int64_t i = 0; i < nq; i++) {
-        auto &res = src_result[i];
-        ASSERT_EQ(res.size(), topk);
-
-        if (res.empty()) {
-            continue;
-        }
-
-        for (int64_t k = 0; k < topk - 1; k++) {
-            if (ascending) {
-                ASSERT_LE(res[k].second, res[k + 1].second);
-            } else {
-                ASSERT_GE(res[k].second, res[k + 1].second);
-            }
-        }
-    }
-}
+void CheckTopkResult(const std::vector<long> &input_ids_1,
+                     const std::vector<float> &input_distance_1,
+                     const std::vector<long> &input_ids_2,
+                     const std::vector<float> &input_distance_2,
+                     uint64_t nq,
+                     uint64_t topk,
+                     bool ascending,
+                     const ResultSet& result) {
+    ASSERT_EQ(result.size(), nq);
+    ASSERT_EQ(input_ids_1.size(), input_distance_1.size());
+    ASSERT_EQ(input_ids_2.size(), input_distance_2.size());
+
+    uint64_t input_k1 = input_ids_1.size() / nq;
+    uint64_t input_k2 = input_ids_2.size() / nq;
+
+    for (int64_t i = 0; i < nq; i++) {
+        std::vector<float> src_vec(input_distance_1.begin()+i*input_k1, input_distance_1.begin()+(i+1)*input_k1);
+        src_vec.insert(src_vec.end(), input_distance_2.begin()+i*input_k2, input_distance_2.begin()+(i+1)*input_k2);
+        if (ascending) {
+            std::sort(src_vec.begin(), src_vec.end());
+        } else {
+            std::sort(src_vec.begin(), src_vec.end(), std::greater<float>());
+        }
+
+        uint64_t n = std::min(topk, input_k1+input_k2);
+        for (uint64_t j = 0; j < n; j++) {
+            if (src_vec[j] != result[i][j].second) {
+                std::cout << src_vec[j] << " " << result[i][j].second << std::endl;
+            }
+            ASSERT_TRUE(src_vec[j] == result[i][j].second);
+        }
+    }
+}
@@ -125,179 +82,117 @@ CheckTopkResult(const ms::scheduler::ResultSet &src_result,
 } // namespace

 TEST(DBSearchTest, TOPK_TEST) {
-    std::vector<int64_t> target_ids;
-    std::vector<float> target_distence;
-    ms::scheduler::ResultSet src_result;
-    auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
-    ASSERT_FALSE(status.ok());
-    ASSERT_TRUE(src_result.empty());
-
-    BuildResult(NQ, TOP_K, ascending, target_ids, target_distence);
-    status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result);
-    ASSERT_TRUE(status.ok());
-    ASSERT_EQ(src_result.size(), NQ);
-
-    ms::scheduler::ResultSet target_result;
-    status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, target_result);
-    ASSERT_TRUE(status.ok());
-
-    status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, src_result);
-    ASSERT_FALSE(status.ok());
-
-    status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
-    ASSERT_TRUE(status.ok());
-    ASSERT_TRUE(src_result.empty());
-    ASSERT_EQ(target_result.size(), NQ);
-
-    std::vector<int64_t> src_ids;
-    std::vector<float> src_distence;
-    uint64_t wrong_topk = TOP_K - 10;
-    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
-
-    status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result);
-    ASSERT_TRUE(status.ok());
-
-    status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
-    ASSERT_TRUE(status.ok());
-    for (uint64_t i = 0; i < NQ; i++) {
-        ASSERT_EQ(target_result[i].size(), TOP_K);
-    }
-
-    wrong_topk = TOP_K + 10;
-    BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence);
-
-    status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result);
-    ASSERT_TRUE(status.ok());
-    for (uint64_t i = 0; i < NQ; i++) {
-        ASSERT_EQ(target_result[i].size(), TOP_K);
-    }
+    uint64_t NQ = 15;
+    uint64_t TOP_K = 64;
+    bool ascending;
+    std::vector<long> ids1, ids2;
+    std::vector<float> dist1, dist2;
+    ResultSet result;
+    milvus::Status status;
+
+    /* test1, id1/dist1 valid, id2/dist2 empty */
+    ascending = true;
+    BuildResult(NQ, TOP_K, ascending, ids1, dist1);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test2, id1/dist1 valid, id2/dist2 valid */
+    BuildResult(NQ, TOP_K, ascending, ids2, dist2);
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test3, id1/dist1 small topk */
+    ids1.clear();
+    dist1.clear();
+    result.clear();
+    BuildResult(NQ, TOP_K/2, ascending, ids1, dist1);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test4, id1/dist1 small topk, id2/dist2 small topk */
+    ids2.clear();
+    dist2.clear();
+    result.clear();
+    BuildResult(NQ, TOP_K/3, ascending, ids2, dist2);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /////////////////////////////////////////////////////////////////////////////////////////
+    ascending = false;
+    ids1.clear();
+    dist1.clear();
+    ids2.clear();
+    dist2.clear();
+    result.clear();
+
+    /* test1, id1/dist1 valid, id2/dist2 empty */
+    BuildResult(NQ, TOP_K, ascending, ids1, dist1);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test2, id1/dist1 valid, id2/dist2 valid */
+    BuildResult(NQ, TOP_K, ascending, ids2, dist2);
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test3, id1/dist1 small topk */
+    ids1.clear();
+    dist1.clear();
+    result.clear();
+    BuildResult(NQ, TOP_K/2, ascending, ids1, dist1);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
+
+    /* test4, id1/dist1 small topk, id2/dist2 small topk */
+    ids2.clear();
+    dist2.clear();
+    result.clear();
+    BuildResult(NQ, TOP_K/3, ascending, ids2, dist2);
+    status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    status = XSearchTask::TopkResult(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result);
+    ASSERT_TRUE(status.ok());
+    CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result);
 }

+TEST(DBSearchTest, REDUCE_PERF_TEST) {
+    int32_t nq = 100;
+    int32_t top_k = 1000;
+    int32_t index_file_num = 478;   /* sift1B dataset, index files num */
+    bool ascending = true;
+    std::vector<long> input_ids;
+    std::vector<float> input_distance;
+    ResultSet final_result;
+    milvus::Status status;
+
+    double span, reduce_cost = 0.0;
+    milvus::TimeRecorder rc("");
+
+    for (int32_t i = 0; i < index_file_num; i++) {
+        BuildResult(nq, top_k, ascending, input_ids, input_distance);
+
+        rc.RecordSection("do search for context: " + std::to_string(i));
+
+        // pick up topk result
+        status = XSearchTask::TopkResult(input_ids, input_distance, top_k, nq, top_k, ascending, final_result);
+        ASSERT_TRUE(status.ok());
+        ASSERT_EQ(final_result.size(), nq);
+
+        span = rc.RecordSection("reduce topk for context: " + std::to_string(i));
+        reduce_cost += span;
+    }
+
+    std::cout << "total reduce time: " << reduce_cost/1000 << " ms" << std::endl;
+}
-
-TEST(DBSearchTest, MERGE_TEST) {
-    bool ascending = true;
-    std::vector<int64_t> target_ids;
-    std::vector<float> target_distence;
-    std::vector<int64_t> src_ids;
-    std::vector<float> src_distence;
-    ms::scheduler::ResultSet src_result, target_result;
-
-    uint64_t src_count = 5, target_count = 8;
-    BuildResult(1, src_count, ascending, src_ids, src_distence);
-    BuildResult(1, target_count, ascending, target_ids, target_distence);
-    auto status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result);
-    ASSERT_TRUE(status.ok());
-    status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result);
-    ASSERT_TRUE(status.ok());
-
-    {
-        ms::scheduler::Id2DistanceMap src = src_result[0];
-        ms::scheduler::Id2DistanceMap target = target_result[0];
-        status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(target.size(), 10);
-        CheckResult(src_result[0], target_result[0], target, ascending);
-    }
-
-    {
-        ms::scheduler::Id2DistanceMap src = src_result[0];
-        ms::scheduler::Id2DistanceMap target;
-        status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(target.size(), src_count);
-        ASSERT_TRUE(src.empty());
-        CheckResult(src_result[0], target_result[0], target, ascending);
-    }
-
-    {
-        ms::scheduler::Id2DistanceMap src = src_result[0];
-        ms::scheduler::Id2DistanceMap target = target_result[0];
-        status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(target.size(), src_count + target_count);
-        CheckResult(src_result[0], target_result[0], target, ascending);
-    }
-
-    {
-        ms::scheduler::Id2DistanceMap target = src_result[0];
-        ms::scheduler::Id2DistanceMap src = target_result[0];
-        status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(target.size(), src_count + target_count);
-        CheckResult(src_result[0], target_result[0], target, ascending);
-    }
-}
-
-TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) {
-    bool ascending = true;
-    std::vector<int64_t> target_ids;
-    std::vector<float> target_distence;
-    ms::scheduler::ResultSet src_result;
-
-    auto DoCluster = [&](int64_t nq, int64_t topk) {
-        ms::TimeRecorder rc("DoCluster");
-        src_result.clear();
-        BuildResult(nq, topk, ascending, target_ids, target_distence);
-        rc.RecordSection("build id/dietance map");
-
-        auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
-        ASSERT_TRUE(status.ok());
-        ASSERT_EQ(src_result.size(), nq);
-
-        rc.RecordSection("cluster result");
-
-        CheckCluster(target_ids, target_distence, src_result, nq, topk);
-        rc.RecordSection("check result");
-    };
-
-    DoCluster(10000, 1000);
-    DoCluster(333, 999);
-    DoCluster(1, 1000);
-    DoCluster(1, 1);
-    DoCluster(7, 0);
-    DoCluster(9999, 1);
-    DoCluster(10001, 1);
-    DoCluster(58273, 1234);
-}
-
-TEST(DBSearchTest, PARALLEL_TOPK_TEST) {
-    std::vector<int64_t> target_ids;
-    std::vector<float> target_distence;
-    ms::scheduler::ResultSet src_result;
-
-    std::vector<int64_t> insufficient_ids;
-    std::vector<float> insufficient_distence;
-    ms::scheduler::ResultSet insufficient_result;
-
-    auto DoTopk = [&](int64_t nq, int64_t topk, int64_t insufficient_topk, bool ascending) {
-        src_result.clear();
-        insufficient_result.clear();
-
-        ms::TimeRecorder rc("DoCluster");
-
-        BuildResult(nq, topk, ascending, target_ids, target_distence);
-        auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result);
-        rc.RecordSection("cluster result");
-
-        BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence);
-        status = ms::scheduler::XSearchTask::ClusterResult(target_ids,
-                                                           target_distence,
-                                                           nq,
-                                                           insufficient_topk,
-                                                           insufficient_result);
-        rc.RecordSection("cluster result");
-
-        ms::scheduler::XSearchTask::TopkResult(insufficient_result, topk, ascending, src_result);
-        ASSERT_TRUE(status.ok());
-        rc.RecordSection("topk");
-
-        CheckTopkResult(src_result, ascending, nq, topk);
-        rc.RecordSection("check result");
-    };
-
-    DoTopk(5, 10, 4, false);
-    DoTopk(20005, 998, 123, true);
-    //    DoTopk(9987, 12, 10, false);
-    //    DoTopk(77777, 1000, 1, false);
-    //    DoTopk(5432, 8899, 8899, true);
-}
@@ -123,6 +123,7 @@ DBTest::TearDown() {
     ms::scheduler::JobMgrInst::GetInstance()->Stop();
     ms::scheduler::SchedInst::GetInstance()->Stop();
     ms::scheduler::ResMgrInst::GetInstance()->Stop();
+    ms::scheduler::ResMgrInst::GetInstance()->Clear();

     BaseTest::TearDown();
@@ -24,7 +24,7 @@ namespace milvus {
 namespace scheduler {

 TEST(TaskTest, INVALID_INDEX) {
-    auto search_task = std::make_shared<XSearchTask>(nullptr);
+    auto search_task = std::make_shared<XSearchTask>(nullptr, nullptr);
     search_task->Load(LoadType::TEST, 10);
 }
@@ -54,7 +54,8 @@ TEST(NormalTest, INST_TEST) {
     ASSERT_FALSE(disks.empty());
     if (auto observe = disks[0].lock()) {
         for (uint64_t i = 0; i < NUM_TASK; ++i) {
-            auto task = std::make_shared<ms::TestTask>(dummy);
+            auto label = std::make_shared<ms::DefaultLabel>();
+            auto task = std::make_shared<ms::TestTask>(dummy, label);
             task->label() = std::make_shared<ms::DefaultLabel>();
             tasks.push_back(task);
             observe->task_table().Put(task);
@@ -23,6 +23,7 @@
 #include "scheduler/resource/TestResource.h"
 #include "scheduler/task/Task.h"
 #include "scheduler/task/TestTask.h"
+#include "scheduler/tasklabel/DefaultLabel.h"
 #include "scheduler/ResourceFactory.h"
 #include <gtest/gtest.h>
@@ -185,7 +186,8 @@ TEST_F(ResourceAdvanceTest, DISK_RESOURCE_TEST) {
     std::vector<std::shared_ptr<TestTask>> tasks;
     TableFileSchemaPtr dummy = nullptr;
     for (uint64_t i = 0; i < NUM; ++i) {
-        auto task = std::make_shared<TestTask>(dummy);
+        auto label = std::make_shared<DefaultLabel>();
+        auto task = std::make_shared<TestTask>(dummy, label);
         tasks.push_back(task);
         disk_resource_->task_table().Put(task);
     }
@@ -210,7 +212,8 @@ TEST_F(ResourceAdvanceTest, CPU_RESOURCE_TEST) {
     std::vector<std::shared_ptr<TestTask>> tasks;
     TableFileSchemaPtr dummy = nullptr;
    for (uint64_t i = 0; i < NUM; ++i) {
-        auto task = std::make_shared<TestTask>(dummy);
+        auto label = std::make_shared<DefaultLabel>();
+        auto task = std::make_shared<TestTask>(dummy, label);
         tasks.push_back(task);
         cpu_resource_->task_table().Put(task);
     }
@@ -235,7 +238,8 @@ TEST_F(ResourceAdvanceTest, GPU_RESOURCE_TEST) {
     std::vector<std::shared_ptr<TestTask>> tasks;
     TableFileSchemaPtr dummy = nullptr;
     for (uint64_t i = 0; i < NUM; ++i) {
-        auto task = std::make_shared<TestTask>(dummy);
+        auto label = std::make_shared<DefaultLabel>();
+        auto task = std::make_shared<TestTask>(dummy, label);
         tasks.push_back(task);
         gpu_resource_->task_table().Put(task);
     }
@@ -260,7 +264,8 @@ TEST_F(ResourceAdvanceTest, TEST_RESOURCE_TEST) {
     std::vector<std::shared_ptr<TestTask>> tasks;
     TableFileSchemaPtr dummy = nullptr;
     for (uint64_t i = 0; i < NUM; ++i) {
-        auto task = std::make_shared<TestTask>(dummy);
+        auto label = std::make_shared<DefaultLabel>();
+        auto task = std::make_shared<TestTask>(dummy, label);
         tasks.push_back(task);
         test_resource_->task_table().Put(task);
     }
@@ -21,6 +21,7 @@
 #include "scheduler/resource/DiskResource.h"
 #include "scheduler/resource/TestResource.h"
 #include "scheduler/task/TestTask.h"
+#include "scheduler/tasklabel/DefaultLabel.h"
 #include "scheduler/ResourceMgr.h"
 #include <gtest/gtest.h>
@@ -184,7 +185,8 @@ TEST_F(ResourceMgrAdvanceTest, REGISTER_SUBSCRIBER) {
     };
     mgr1_->RegisterSubscriber(callback);
     TableFileSchemaPtr dummy = nullptr;
-    disk_res->task_table().Put(std::make_shared<TestTask>(dummy));
+    auto label = std::make_shared<DefaultLabel>();
+    disk_res->task_table().Put(std::make_shared<TestTask>(dummy, label));
     sleep(1);
     ASSERT_TRUE(flag);
 }
@@ -155,7 +155,8 @@ TEST_F(SchedulerTest, ON_LOAD_COMPLETED) {
     insert_dummy_index_into_gpu_cache(1);

     for (uint64_t i = 0; i < NUM; ++i) {
-        auto task = std::make_shared<TestTask>(dummy);
+        auto label = std::make_shared<DefaultLabel>();
+        auto task = std::make_shared<TestTask>(dummy, label);
         task->label() = std::make_shared<DefaultLabel>();
         tasks.push_back(task);
         cpu_resource_.lock()->task_table().Put(task);
@@ -174,7 +175,8 @@ TEST_F(SchedulerTest, PUSH_TASK_TO_NEIGHBOUR_RANDOMLY_TEST) {
     tasks.clear();

     for (uint64_t i = 0; i < NUM; ++i) {
-        auto task = std::make_shared<TestTask>(dummy1);
+        auto label = std::make_shared<DefaultLabel>();
+        auto task = std::make_shared<TestTask>(dummy1, label);
         task->label() = std::make_shared<DefaultLabel>();
         tasks.push_back(task);
         cpu_resource_.lock()->task_table().Put(task);
@@ -242,7 +244,8 @@ TEST_F(SchedulerTest2, SPECIFIED_RESOURCE_TEST) {
     dummy->location_ = "location";

     for (uint64_t i = 0; i < NUM; ++i) {
-        std::shared_ptr<TestTask> task = std::make_shared<TestTask>(dummy);
+        auto label = std::make_shared<DefaultLabel>();
+        std::shared_ptr<TestTask> task = std::make_shared<TestTask>(dummy, label);
         task->label() = std::make_shared<SpecResLabel>(disk_);
         tasks.push_back(task);
         disk_.lock()->task_table().Put(task);
@@ -18,6 +18,7 @@

 #include "scheduler/TaskTable.h"
 #include "scheduler/task/TestTask.h"
+#include "scheduler/tasklabel/DefaultLabel.h"
 #include <gtest/gtest.h>

 namespace {
@@ -172,8 +173,9 @@ class TaskTableBaseTest : public ::testing::Test {
     SetUp() override {
         ms::TableFileSchemaPtr dummy = nullptr;
         invalid_task_ = nullptr;
-        task1_ = std::make_shared<ms::TestTask>(dummy);
-        task2_ = std::make_shared<ms::TestTask>(dummy);
+        auto label = std::make_shared<ms::DefaultLabel>();
+        task1_ = std::make_shared<ms::TestTask>(dummy, label);
+        task2_ = std::make_shared<ms::TestTask>(dummy, label);
     }

     ms::TaskPtr invalid_task_;
@@ -340,7 +342,8 @@ class TaskTableAdvanceTest : public ::testing::Test {
     SetUp() override {
         ms::TableFileSchemaPtr dummy = nullptr;
         for (uint64_t i = 0; i < 8; ++i) {
-            auto task = std::make_shared<ms::TestTask>(dummy);
+            auto label = std::make_shared<ms::DefaultLabel>();
+            auto task = std::make_shared<ms::TestTask>(dummy, label);
             table1_.Put(task);
         }