diff --git a/cpp/CHANGELOG.md b/cpp/CHANGELOG.md index a2fb205e93..b642e12661 100644 --- a/cpp/CHANGELOG.md +++ b/cpp/CHANGELOG.md @@ -14,15 +14,19 @@ Please mark all change in change log and use the ticket from JIRA. ## Improvement - MS-552 - Add and change the easylogging library - MS-553 - Refine cache code -- MS-557 - Merge Log.h +- MS-555 - Remove old scheduler - MS-556 - Add Job Definition in Scheduler +- MS-557 - Merge Log.h - MS-558 - Refine status code - MS-562 - Add JobMgr and TaskCreator in Scheduler - MS-566 - Refactor cmake -- MS-555 - Remove old scheduler - MS-574 - Milvus configuration refactor - MS-578 - Make sure milvus5.0 don't crack 0.3.1 data - MS-585 - Update namespace in scheduler +- MS-606 - Speed up result reduce +- MS-608 - Update TODO names +- MS-609 - Update task construct function +- MS-611 - Add resources validity check in ResourceMgr ## New Feature @@ -36,6 +40,7 @@ Please mark all change in change log and use the ticket from JIRA. - MS-590 - Refine cmake code to support cpplint - MS-600 - Reconstruct unittest code - MS-602 - Remove zilliz namespace +- MS-610 - Change error code base value from hex to decimal # Milvus 0.4.0 (2019-09-12) diff --git a/cpp/src/scheduler/ResourceMgr.cpp b/cpp/src/scheduler/ResourceMgr.cpp index 9906c71bbe..3ea8a56ef8 100644 --- a/cpp/src/scheduler/ResourceMgr.cpp +++ b/cpp/src/scheduler/ResourceMgr.cpp @@ -24,6 +24,12 @@ namespace scheduler { void ResourceMgr::Start() { + if (not check_resource_valid()) { + ENGINE_LOG_ERROR << "Resources invalid, cannot start ResourceMgr."; + ENGINE_LOG_ERROR << Dump(); + return; + } + std::lock_guard lck(resources_mutex_); for (auto& resource : resources_) { resource->Start(); @@ -60,8 +66,22 @@ ResourceMgr::Add(ResourcePtr&& resource) { resource->RegisterSubscriber(std::bind(&ResourceMgr::post_event, this, std::placeholders::_1)); - if (resource->type() == ResourceType::DISK) { - disk_resources_.emplace_back(ResourceWPtr(resource)); + switch (resource->type()) { + case ResourceType::DISK: { + disk_resources_.emplace_back(ResourceWPtr(resource)); + break; + } + case ResourceType::CPU: { + cpu_resources_.emplace_back(ResourceWPtr(resource)); + break; + } + case ResourceType::GPU: { + gpu_resources_.emplace_back(ResourceWPtr(resource)); + break; + } + default: { + break; + } } resources_.emplace_back(resource); @@ -74,7 +94,7 @@ ResourceMgr::Connect(const std::string& name1, const std::string& name2, Connect auto res2 = GetResource(name2); if (res1 && res2) { res1->AddNeighbour(std::static_pointer_cast(res2), connection); - // TODO(wxy): enable when task balance supported + // TODO(wxyu): enable when task balance supported // res2->AddNeighbour(std::static_pointer_cast(res1), connection); return true; } @@ -85,6 +105,8 @@ void ResourceMgr::Clear() { std::lock_guard lck(resources_mutex_); disk_resources_.clear(); + cpu_resources_.clear(); + gpu_resources_.clear(); resources_.clear(); } @@ -148,14 +170,14 @@ ResourceMgr::GetNumGpuResource() const { std::string ResourceMgr::Dump() { - std::string str = "ResourceMgr contains " + std::to_string(resources_.size()) + " resources.\n"; + std::stringstream ss; + ss << "ResourceMgr contains " << resources_.size() << " resources." << std::endl; - for (uint64_t i = 0; i < resources_.size(); ++i) { - str += "Resource No." + std::to_string(i) + ":\n"; - // str += resources_[i]->Dump(); + for (auto& res : resources_) { + ss << res->Dump(); } - return str; + return ss.str(); } std::string @@ -170,6 +192,34 @@ ResourceMgr::DumpTaskTables() { return ss.str(); } +bool +ResourceMgr::check_resource_valid() { + { + // TODO: check one disk-resource, one cpu-resource, zero or more gpu-resource; + if (GetDiskResources().size() != 1) return false; + if (GetCpuResources().size() != 1) return false; + } + + { + // TODO: one compute-resource at least; + if (GetNumOfComputeResource() < 1) return false; + } + + { + // TODO: check disk only connect with cpu + } + + { + // TODO: check gpu only connect with cpu + } + + { + // TODO: check if exists isolated node + } + + return true; +} + void ResourceMgr::post_event(const EventPtr& event) { { @@ -183,7 +233,9 @@ void ResourceMgr::event_process() { while (running_) { std::unique_lock lock(event_mutex_); - event_cv_.wait(lock, [this] { return !queue_.empty(); }); + event_cv_.wait(lock, [this] { + return !queue_.empty(); + }); auto event = queue_.front(); queue_.pop(); diff --git a/cpp/src/scheduler/ResourceMgr.h b/cpp/src/scheduler/ResourceMgr.h index a81e6c239f..7a8e1ca4ca 100644 --- a/cpp/src/scheduler/ResourceMgr.h +++ b/cpp/src/scheduler/ResourceMgr.h @@ -64,7 +64,17 @@ class ResourceMgr { return disk_resources_; } - // TODO(wxy): why return shared pointer + inline std::vector& + GetCpuResources() { + return cpu_resources_; + } + + inline std::vector& + GetGpuResources() { + return gpu_resources_; + } + + // TODO(wxyu): why return shared pointer inline std::vector GetAllResources() { return resources_; @@ -89,7 +99,7 @@ class ResourceMgr { GetNumGpuResource() const; public: - // TODO(wxy): add stats interface(low) + // TODO(wxyu): add stats interface(low) public: /******** Utility Functions ********/ @@ -100,6 +110,9 @@ class ResourceMgr { DumpTaskTables(); private: + bool + check_resource_valid(); + void post_event(const EventPtr& event); @@ -110,6 +123,8 @@ class ResourceMgr { bool running_ = false; std::vector disk_resources_; + std::vector cpu_resources_; + std::vector gpu_resources_; std::vector resources_; mutable std::mutex resources_mutex_; diff --git a/cpp/src/scheduler/SchedInst.cpp b/cpp/src/scheduler/SchedInst.cpp index f17aceb596..b9edbca001 100644 --- a/cpp/src/scheduler/SchedInst.cpp +++ b/cpp/src/scheduler/SchedInst.cpp @@ -146,7 +146,7 @@ load_advance_config() { // } // } catch (const char *msg) { // SERVER_LOG_ERROR << msg; - // // TODO(wxy): throw exception instead + // // TODO(wxyu): throw exception instead // exit(-1); //// throw std::exception(); // } diff --git a/cpp/src/scheduler/Scheduler.cpp b/cpp/src/scheduler/Scheduler.cpp index 6963ed3c5f..3a82a1b361 100644 --- a/cpp/src/scheduler/Scheduler.cpp +++ b/cpp/src/scheduler/Scheduler.cpp @@ -92,7 +92,7 @@ Scheduler::Process(const EventPtr& event) { process_event(event); } -// TODO(wxy): refactor the function +// TODO(wxyu): refactor the function void Scheduler::OnLoadCompleted(const EventPtr& event) { auto load_completed_event = std::static_pointer_cast(event); diff --git a/cpp/src/scheduler/Scheduler.h b/cpp/src/scheduler/Scheduler.h index 1d8af9f4d4..5b222cc41a 100644 --- a/cpp/src/scheduler/Scheduler.h +++ b/cpp/src/scheduler/Scheduler.h @@ -31,7 +31,7 @@ namespace milvus { namespace scheduler { -// TODO(wxy): refactor, not friendly to unittest, logical in framework code +// TODO(wxyu): refactor, not friendly to unittest, logical in framework code class Scheduler { public: explicit Scheduler(ResourceMgrWPtr res_mgr); diff --git a/cpp/src/scheduler/TaskCreator.cpp b/cpp/src/scheduler/TaskCreator.cpp index 0a7b3f9cbb..83d112918c 100644 --- a/cpp/src/scheduler/TaskCreator.cpp +++ b/cpp/src/scheduler/TaskCreator.cpp @@ -38,7 +38,7 @@ TaskCreator::Create(const JobPtr &job) { return Create(std::static_pointer_cast(job)); } default: { - // TODO(wxy): error + // TODO(wxyu): error return std::vector(); } } @@ -47,9 +47,9 @@ TaskCreator::Create(const JobPtr &job) { std::vector TaskCreator::Create(const SearchJobPtr &job) { std::vector tasks; - for (auto &index_file : job->index_files()) { - auto task = std::make_shared(index_file.second); - task->label() = std::make_shared(); + for (auto& index_file : job->index_files()) { + auto label = std::make_shared(); + auto task = std::make_shared(index_file.second, label); task->job_ = job; tasks.emplace_back(task); } @@ -60,8 +60,8 @@ TaskCreator::Create(const SearchJobPtr &job) { std::vector TaskCreator::Create(const DeleteJobPtr &job) { std::vector tasks; - auto task = std::make_shared(job); - task->label() = std::make_shared(); + auto label = std::make_shared(); + auto task = std::make_shared(job, label); task->job_ = job; tasks.emplace_back(task); diff --git a/cpp/src/scheduler/TaskTable.h b/cpp/src/scheduler/TaskTable.h index 35becbe5f8..ad81b5d439 100644 --- a/cpp/src/scheduler/TaskTable.h +++ b/cpp/src/scheduler/TaskTable.h @@ -125,7 +125,7 @@ class TaskTable { Get(uint64_t index); /* - * TODO(wxy): BIG GC + * TODO(wxyu): BIG GC * Remove sequence task which is DONE or MOVED from front; * Called by ? */ @@ -173,7 +173,7 @@ class TaskTable { public: /******** Action ********/ - // TODO(wxy): bool to Status + // TODO(wxyu): bool to Status /* * Load a task; * Set state loading; diff --git a/cpp/src/scheduler/action/PushTaskToNeighbour.cpp b/cpp/src/scheduler/action/PushTaskToNeighbour.cpp index 828f0c71c6..f5184d4750 100644 --- a/cpp/src/scheduler/action/PushTaskToNeighbour.cpp +++ b/cpp/src/scheduler/action/PushTaskToNeighbour.cpp @@ -84,7 +84,7 @@ Action::PushTaskToNeighbourRandomly(const TaskPtr &task, const ResourcePtr &self } } else { - // TODO(wxy): process + // TODO(wxyu): process } } diff --git a/cpp/src/scheduler/job/SearchJob.h b/cpp/src/scheduler/job/SearchJob.h index aed40cd942..fb2d87d876 100644 --- a/cpp/src/scheduler/job/SearchJob.h +++ b/cpp/src/scheduler/job/SearchJob.h @@ -37,8 +37,9 @@ namespace scheduler { using engine::meta::TableFileSchemaPtr; using Id2IndexMap = std::unordered_map; -using Id2DistanceMap = std::vector>; -using ResultSet = std::vector; +using IdDistPair = std::pair; +using Id2DistVec = std::vector; +using ResultSet = std::vector; class SearchJob : public Job { public: diff --git a/cpp/src/scheduler/resource/Resource.h b/cpp/src/scheduler/resource/Resource.h index e8b50fe9af..c9026f13b6 100644 --- a/cpp/src/scheduler/resource/Resource.h +++ b/cpp/src/scheduler/resource/Resource.h @@ -38,7 +38,7 @@ namespace milvus { namespace scheduler { -// TODO(wxy): Storage, Route, Executor +// TODO(wxyu): Storage, Route, Executor enum class ResourceType { DISK = 0, CPU = 1, @@ -114,11 +114,11 @@ class Resource : public Node, public std::enable_shared_from_this { return enable_executor_; } - // TODO(wxy): const + // TODO(wxyu): const uint64_t NumOfTaskToExec(); - // TODO(wxy): need double ? + // TODO(wxyu): need double ? inline uint64_t TaskAvgCost() const { return total_cost_ / total_task_; diff --git a/cpp/src/scheduler/task/DeleteTask.cpp b/cpp/src/scheduler/task/DeleteTask.cpp index 480fb86056..bffe78cf8f 100644 --- a/cpp/src/scheduler/task/DeleteTask.cpp +++ b/cpp/src/scheduler/task/DeleteTask.cpp @@ -17,11 +17,13 @@ #include "scheduler/task/DeleteTask.h" +#include + namespace milvus { namespace scheduler { -XDeleteTask::XDeleteTask(const scheduler::DeleteJobPtr& delete_job) - : Task(TaskType::DeleteTask), delete_job_(delete_job) { +XDeleteTask::XDeleteTask(const scheduler::DeleteJobPtr& delete_job, TaskLabelPtr label) + : Task(TaskType::DeleteTask, std::move(label)), delete_job_(delete_job) { } void diff --git a/cpp/src/scheduler/task/DeleteTask.h b/cpp/src/scheduler/task/DeleteTask.h index 75f0969bff..fd5222ba4e 100644 --- a/cpp/src/scheduler/task/DeleteTask.h +++ b/cpp/src/scheduler/task/DeleteTask.h @@ -25,7 +25,7 @@ namespace scheduler { class XDeleteTask : public Task { public: - explicit XDeleteTask(const scheduler::DeleteJobPtr& delete_job); + explicit XDeleteTask(const scheduler::DeleteJobPtr& delete_job, TaskLabelPtr label); void Load(LoadType type, uint8_t device_id) override; diff --git a/cpp/src/scheduler/task/SearchTask.cpp b/cpp/src/scheduler/task/SearchTask.cpp index 2c3c6f2288..20962d8a10 100644 --- a/cpp/src/scheduler/task/SearchTask.cpp +++ b/cpp/src/scheduler/task/SearchTask.cpp @@ -78,24 +78,26 @@ std::mutex XSearchTask::merge_mutex_; void CollectFileMetrics(int file_type, size_t file_size) { + server::MetricsBase& inst = server::Metrics::GetInstance(); switch (file_type) { case TableFileSchema::RAW: case TableFileSchema::TO_INDEX: { - server::Metrics::GetInstance().RawFileSizeHistogramObserve(file_size); - server::Metrics::GetInstance().RawFileSizeTotalIncrement(file_size); - server::Metrics::GetInstance().RawFileSizeGaugeSet(file_size); + inst.RawFileSizeHistogramObserve(file_size); + inst.RawFileSizeTotalIncrement(file_size); + inst.RawFileSizeGaugeSet(file_size); break; } default: { - server::Metrics::GetInstance().IndexFileSizeHistogramObserve(file_size); - server::Metrics::GetInstance().IndexFileSizeTotalIncrement(file_size); - server::Metrics::GetInstance().IndexFileSizeGaugeSet(file_size); + inst.IndexFileSizeHistogramObserve(file_size); + inst.IndexFileSizeTotalIncrement(file_size); + inst.IndexFileSizeGaugeSet(file_size); break; } } } -XSearchTask::XSearchTask(TableFileSchemaPtr file) : Task(TaskType::SearchTask), file_(file) { +XSearchTask::XSearchTask(TableFileSchemaPtr file, TaskLabelPtr label) + : Task(TaskType::SearchTask, std::move(label)), file_(file) { if (file_) { if (file_->metric_type_ != static_cast(MetricType::L2)) { metric_l2 = false; @@ -205,16 +207,9 @@ XSearchTask::Execute() { double span = rc.RecordSection(hdr + ", do search"); // search_job->AccumSearchCost(span); - // step 3: cluster result - scheduler::ResultSet result_set; + // step 3: pick up topk result auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk; - XSearchTask::ClusterResult(output_ids, output_distance, nq, spec_k, result_set); - - span = rc.RecordSection(hdr + ", cluster result"); - // search_job->AccumReduceCost(span); - - // step 4: pick up topk result - XSearchTask::TopkResult(result_set, topk, metric_l2, search_job->GetResult()); + XSearchTask::TopkResult(output_ids, output_distance, spec_k, nq, topk, metric_l2, search_job->GetResult()); span = rc.RecordSection(hdr + ", reduce topk"); // search_job->AccumReduceCost(span); @@ -234,142 +229,75 @@ XSearchTask::Execute() { } Status -XSearchTask::ClusterResult(const std::vector& output_ids, const std::vector& output_distance, - uint64_t nq, uint64_t topk, scheduler::ResultSet& result_set) { - if (output_ids.size() < nq * topk || output_distance.size() < nq * topk) { - std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) + " distance array size: " + - std::to_string(output_distance.size()); - ENGINE_LOG_ERROR << msg; - return Status(DB_ERROR, msg); - } +XSearchTask::TopkResult(const std::vector &input_ids, + const std::vector &input_distance, + uint64_t input_k, + uint64_t nq, + uint64_t topk, + bool ascending, + scheduler::ResultSet &result) { + scheduler::ResultSet result_buf; - result_set.clear(); - result_set.resize(nq); - - std::function reduce_worker = [&](size_t from_index, size_t to_index) { - for (auto i = from_index; i < to_index; i++) { - scheduler::Id2DistanceMap id_distance; - id_distance.reserve(topk); - for (auto k = 0; k < topk; k++) { - uint64_t index = i * topk + k; - if (output_ids[index] < 0) { - continue; + if (result.empty()) { + result_buf.resize(nq, scheduler::Id2DistVec(input_k, scheduler::IdDistPair(-1, 0.0))); + for (auto i = 0; i < nq; ++i) { + auto& result_buf_i = result_buf[i]; + uint64_t input_k_multi_i = input_k * i; + for (auto k = 0; k < input_k; ++k) { + uint64_t idx = input_k_multi_i + k; + auto& result_buf_item = result_buf_i[k]; + result_buf_item.first = input_ids[idx]; + result_buf_item.second = input_distance[idx]; + } + } + } else { + size_t tar_size = result[0].size(); + uint64_t output_k = std::min(topk, input_k + tar_size); + result_buf.resize(nq, scheduler::Id2DistVec(output_k, scheduler::IdDistPair(-1, 0.0))); + for (auto i = 0; i < nq; ++i) { + size_t buf_k = 0, src_k = 0, tar_k = 0; + uint64_t src_idx; + auto& result_i = result[i]; + auto& result_buf_i = result_buf[i]; + uint64_t input_k_multi_i = input_k * i; + while (buf_k < output_k && src_k < input_k && tar_k < tar_size) { + src_idx = input_k_multi_i + src_k; + auto& result_buf_item = result_buf_i[buf_k]; + auto& result_item = result_i[tar_k]; + if ((ascending && input_distance[src_idx] < result_item.second) || + (!ascending && input_distance[src_idx] > result_item.second)) { + result_buf_item.first = input_ids[src_idx]; + result_buf_item.second = input_distance[src_idx]; + src_k++; + } else { + result_buf_item = result_item; + tar_k++; } - id_distance.push_back(std::make_pair(output_ids[index], output_distance[index])); + buf_k++; } - result_set[i] = id_distance; - } - }; - // if (NeedParallelReduce(nq, topk)) { - // ParallelReduce(reduce_worker, nq); - // } else { - reduce_worker(0, nq); - // } - - return Status::OK(); -} - -Status -XSearchTask::MergeResult(scheduler::Id2DistanceMap& distance_src, scheduler::Id2DistanceMap& distance_target, - uint64_t topk, bool ascending) { - // Note: the score_src and score_target are already arranged by score in ascending order - if (distance_src.empty()) { - ENGINE_LOG_WARNING << "Empty distance source array"; - return Status::OK(); - } - - std::unique_lock lock(merge_mutex_); - if (distance_target.empty()) { - distance_target.swap(distance_src); - return Status::OK(); - } - - size_t src_count = distance_src.size(); - size_t target_count = distance_target.size(); - scheduler::Id2DistanceMap distance_merged; - distance_merged.reserve(topk); - size_t src_index = 0, target_index = 0; - while (true) { - // all score_src items are merged, if score_merged.size() still less than topk - // move items from score_target to score_merged until score_merged.size() equal topk - if (src_index >= src_count) { - for (size_t i = target_index; i < target_count && distance_merged.size() < topk; ++i) { - distance_merged.push_back(distance_target[i]); + if (buf_k < topk) { + if (src_k < input_k) { + while (buf_k < output_k && src_k < input_k) { + src_idx = input_k_multi_i + src_k; + auto& result_buf_item = result_buf_i[buf_k]; + result_buf_item.first = input_ids[src_idx]; + result_buf_item.second = input_distance[src_idx]; + src_k++; + buf_k++; + } + } else { + while (buf_k < output_k && tar_k < tar_size) { + result_buf_i[buf_k] = result_i[tar_k]; + tar_k++; + buf_k++; + } + } } - break; - } - - // all score_target items are merged, if score_merged.size() still less than topk - // move items from score_src to score_merged until score_merged.size() equal topk - if (target_index >= target_count) { - for (size_t i = src_index; i < src_count && distance_merged.size() < topk; ++i) { - distance_merged.push_back(distance_src[i]); - } - break; - } - - // compare score, - // if ascending = true, put smallest score to score_merged one by one - // else, put largest score to score_merged one by one - auto& src_pair = distance_src[src_index]; - auto& target_pair = distance_target[target_index]; - if (ascending) { - if (src_pair.second > target_pair.second) { - distance_merged.push_back(target_pair); - target_index++; - } else { - distance_merged.push_back(src_pair); - src_index++; - } - } else { - if (src_pair.second < target_pair.second) { - distance_merged.push_back(target_pair); - target_index++; - } else { - distance_merged.push_back(src_pair); - src_index++; - } - } - - // score_merged.size() already equal topk - if (distance_merged.size() >= topk) { - break; } } - distance_target.swap(distance_merged); - - return Status::OK(); -} - -Status -XSearchTask::TopkResult(scheduler::ResultSet& result_src, uint64_t topk, bool ascending, - scheduler::ResultSet& result_target) { - if (result_target.empty()) { - result_target.swap(result_src); - return Status::OK(); - } - - if (result_src.size() != result_target.size()) { - std::string msg = "Invalid result set size"; - ENGINE_LOG_ERROR << msg; - return Status(DB_ERROR, msg); - } - - std::function ReduceWorker = [&](size_t from_index, size_t to_index) { - for (size_t i = from_index; i < to_index; i++) { - scheduler::Id2DistanceMap& score_src = result_src[i]; - scheduler::Id2DistanceMap& score_target = result_target[i]; - XSearchTask::MergeResult(score_src, score_target, topk, ascending); - } - }; - - // if (NeedParallelReduce(result_src.size(), topk)) { - // ParallelReduce(ReduceWorker, result_src.size()); - // } else { - ReduceWorker(0, result_src.size()); - // } + result.swap(result_buf); return Status::OK(); } diff --git a/cpp/src/scheduler/task/SearchTask.h b/cpp/src/scheduler/task/SearchTask.h index 3143799715..92d7235c6b 100644 --- a/cpp/src/scheduler/task/SearchTask.h +++ b/cpp/src/scheduler/task/SearchTask.h @@ -26,10 +26,10 @@ namespace milvus { namespace scheduler { -// TODO(wxy): rewrite +// TODO(wxyu): rewrite class XSearchTask : public Task { public: - explicit XSearchTask(TableFileSchemaPtr file); + explicit XSearchTask(TableFileSchemaPtr file, TaskLabelPtr label); void Load(LoadType type, uint8_t device_id) override; @@ -39,15 +39,13 @@ class XSearchTask : public Task { public: static Status - ClusterResult(const std::vector& output_ids, const std::vector& output_distence, uint64_t nq, - uint64_t topk, scheduler::ResultSet& result_set); - - static Status - MergeResult(scheduler::Id2DistanceMap& distance_src, scheduler::Id2DistanceMap& distance_target, uint64_t topk, - bool ascending); - - static Status - TopkResult(scheduler::ResultSet& result_src, uint64_t topk, bool ascending, scheduler::ResultSet& result_target); + TopkResult(const std::vector &input_ids, + const std::vector &input_distance, + uint64_t input_k, + uint64_t nq, + uint64_t topk, + bool ascending, + scheduler::ResultSet &result); public: TableFileSchemaPtr file_; diff --git a/cpp/src/scheduler/task/Task.h b/cpp/src/scheduler/task/Task.h index bb4b23ec1e..b77481c94d 100644 --- a/cpp/src/scheduler/task/Task.h +++ b/cpp/src/scheduler/task/Task.h @@ -24,6 +24,7 @@ #include #include +#include namespace milvus { namespace scheduler { @@ -49,7 +50,7 @@ using TaskPtr = std::shared_ptr; // TODO: re-design class Task { public: - explicit Task(TaskType type) : type_(type) { + explicit Task(TaskType type, TaskLabelPtr label) : type_(type), label_(std::move(label)) { } /* diff --git a/cpp/src/scheduler/task/TestTask.cpp b/cpp/src/scheduler/task/TestTask.cpp index 76e814a628..3ec3a8ab19 100644 --- a/cpp/src/scheduler/task/TestTask.cpp +++ b/cpp/src/scheduler/task/TestTask.cpp @@ -18,10 +18,12 @@ #include "scheduler/task/TestTask.h" #include "cache/GpuCacheMgr.h" +#include + namespace milvus { namespace scheduler { -TestTask::TestTask(TableFileSchemaPtr& file) : XSearchTask(file) { +TestTask::TestTask(TableFileSchemaPtr& file, TaskLabelPtr label) : XSearchTask(file, std::move(label)) { } void diff --git a/cpp/src/scheduler/task/TestTask.h b/cpp/src/scheduler/task/TestTask.h index 3ad9cb16e1..99b48a8afe 100644 --- a/cpp/src/scheduler/task/TestTask.h +++ b/cpp/src/scheduler/task/TestTask.h @@ -24,7 +24,7 @@ namespace scheduler { class TestTask : public XSearchTask { public: - explicit TestTask(TableFileSchemaPtr& file); + explicit TestTask(TableFileSchemaPtr& file, TaskLabelPtr label); public: void diff --git a/cpp/src/utils/Error.h b/cpp/src/utils/Error.h index 81403947c8..9cba18ef41 100644 --- a/cpp/src/utils/Error.h +++ b/cpp/src/utils/Error.h @@ -26,7 +26,7 @@ namespace milvus { using ErrorCode = int32_t; constexpr ErrorCode SERVER_SUCCESS = 0; -constexpr ErrorCode SERVER_ERROR_CODE_BASE = 0x30000; +constexpr ErrorCode SERVER_ERROR_CODE_BASE = 30000; constexpr ErrorCode ToServerErrorCode(const ErrorCode error_code) { @@ -34,7 +34,7 @@ ToServerErrorCode(const ErrorCode error_code) { } constexpr ErrorCode DB_SUCCESS = 0; -constexpr ErrorCode DB_ERROR_CODE_BASE = 0x40000; +constexpr ErrorCode DB_ERROR_CODE_BASE = 40000; constexpr ErrorCode ToDbErrorCode(const ErrorCode error_code) { @@ -42,7 +42,7 @@ ToDbErrorCode(const ErrorCode error_code) { } constexpr ErrorCode KNOWHERE_SUCCESS = 0; -constexpr ErrorCode KNOWHERE_ERROR_CODE_BASE = 0x50000; +constexpr ErrorCode KNOWHERE_ERROR_CODE_BASE = 50000; constexpr ErrorCode ToKnowhereErrorCode(const ErrorCode error_code) { diff --git a/cpp/src/wrapper/ConfAdapterMgr.h b/cpp/src/wrapper/ConfAdapterMgr.h index a88e090760..8d5fa22877 100644 --- a/cpp/src/wrapper/ConfAdapterMgr.h +++ b/cpp/src/wrapper/ConfAdapterMgr.h @@ -20,6 +20,7 @@ #include "ConfAdapter.h" #include "VecIndex.h" +#include #include #include diff --git a/cpp/unittest/db/test_search.cpp b/cpp/unittest/db/test_search.cpp index 12fc8e277a..0b13af0c51 100644 --- a/cpp/unittest/db/test_search.cpp +++ b/cpp/unittest/db/test_search.cpp @@ -22,13 +22,10 @@ #include "scheduler/task/SearchTask.h" #include "utils/TimeRecorder.h" +using namespace milvus::scheduler; + namespace { -namespace ms = milvus; - -static constexpr uint64_t NQ = 15; -static constexpr uint64_t TOP_K = 64; - void BuildResult(uint64_t nq, uint64_t topk, @@ -48,76 +45,36 @@ BuildResult(uint64_t nq, } } -void -CheckResult(const ms::scheduler::Id2DistanceMap &src_1, - const ms::scheduler::Id2DistanceMap &src_2, - const ms::scheduler::Id2DistanceMap &target, - bool ascending) { - for (uint64_t i = 0; i < target.size() - 1; i++) { +void CheckTopkResult(const std::vector &input_ids_1, + const std::vector &input_distance_1, + const std::vector &input_ids_2, + const std::vector &input_distance_2, + uint64_t nq, + uint64_t topk, + bool ascending, + const ResultSet& result) { + ASSERT_EQ(result.size(), nq); + ASSERT_EQ(input_ids_1.size(), input_distance_1.size()); + ASSERT_EQ(input_ids_2.size(), input_distance_2.size()); + + uint64_t input_k1 = input_ids_1.size() / nq; + uint64_t input_k2 = input_ids_2.size() / nq; + + for (int64_t i = 0; i < nq; i++) { + std::vector src_vec(input_distance_1.begin()+i*input_k1, input_distance_1.begin()+(i+1)*input_k1); + src_vec.insert(src_vec.end(), input_distance_2.begin()+i*input_k2, input_distance_2.begin()+(i+1)*input_k2); if (ascending) { - ASSERT_LE(target[i].second, target[i + 1].second); + std::sort(src_vec.begin(), src_vec.end()); } else { - ASSERT_GE(target[i].second, target[i + 1].second); - } - } - - using ID2DistMap = std::map; - ID2DistMap src_map_1, src_map_2; - for (const auto &pair : src_1) { - src_map_1.insert(pair); - } - for (const auto &pair : src_2) { - src_map_2.insert(pair); - } - - for (const auto &pair : target) { - ASSERT_TRUE(src_map_1.find(pair.first) != src_map_1.end() || src_map_2.find(pair.first) != src_map_2.end()); - - float dist = src_map_1.find(pair.first) != src_map_1.end() ? src_map_1[pair.first] : src_map_2[pair.first]; - ASSERT_LT(fabs(pair.second - dist), std::numeric_limits::epsilon()); - } -} - -void -CheckCluster(const std::vector &target_ids, - const std::vector &target_distence, - const ms::scheduler::ResultSet &src_result, - int64_t nq, - int64_t topk) { - ASSERT_EQ(src_result.size(), nq); - for (int64_t i = 0; i < nq; i++) { - auto &res = src_result[i]; - ASSERT_EQ(res.size(), topk); - - if (res.empty()) { - continue; + std::sort(src_vec.begin(), src_vec.end(), std::greater()); } - ASSERT_EQ(res[0].first, target_ids[i * topk]); - ASSERT_EQ(res[topk - 1].first, target_ids[i * topk + topk - 1]); - } -} - -void -CheckTopkResult(const ms::scheduler::ResultSet &src_result, - bool ascending, - int64_t nq, - int64_t topk) { - ASSERT_EQ(src_result.size(), nq); - for (int64_t i = 0; i < nq; i++) { - auto &res = src_result[i]; - ASSERT_EQ(res.size(), topk); - - if (res.empty()) { - continue; - } - - for (int64_t k = 0; k < topk - 1; k++) { - if (ascending) { - ASSERT_LE(res[k].second, res[k + 1].second); - } else { - ASSERT_GE(res[k].second, res[k + 1].second); + uint64_t n = std::min(topk, input_k1+input_k2); + for (uint64_t j = 0; j < n; j++) { + if (src_vec[j] != result[i][j].second) { + std::cout << src_vec[j] << " " << result[i][j].second << std::endl; } + ASSERT_TRUE(src_vec[j] == result[i][j].second); } } } @@ -125,179 +82,117 @@ CheckTopkResult(const ms::scheduler::ResultSet &src_result, } // namespace TEST(DBSearchTest, TOPK_TEST) { + uint64_t NQ = 15; + uint64_t TOP_K = 64; + bool ascending; + std::vector ids1, ids2; + std::vector dist1, dist2; + ResultSet result; + milvus::Status status; + + /* test1, id1/dist1 valid, id2/dist2 empty */ + ascending = true; + BuildResult(NQ, TOP_K, ascending, ids1, dist1); + status = XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test2, id1/dist1 valid, id2/dist2 valid */ + BuildResult(NQ, TOP_K, ascending, ids2, dist2); + status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test3, id1/dist1 small topk */ + ids1.clear(); + dist1.clear(); + result.clear(); + BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); + status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test4, id1/dist1 small topk, id2/dist2 small topk */ + ids2.clear(); + dist2.clear(); + result.clear(); + BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); + status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + status = XSearchTask::TopkResult(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + +///////////////////////////////////////////////////////////////////////////////////////// + ascending = false; + ids1.clear(); + dist1.clear(); + ids2.clear(); + dist2.clear(); + result.clear(); + + /* test1, id1/dist1 valid, id2/dist2 empty */ + BuildResult(NQ, TOP_K, ascending, ids1, dist1); + status = XSearchTask::TopkResult(ids1, dist1, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test2, id1/dist1 valid, id2/dist2 valid */ + BuildResult(NQ, TOP_K, ascending, ids2, dist2); + status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test3, id1/dist1 small topk */ + ids1.clear(); + dist1.clear(); + result.clear(); + BuildResult(NQ, TOP_K/2, ascending, ids1, dist1); + status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + status = XSearchTask::TopkResult(ids2, dist2, TOP_K, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); + + /* test4, id1/dist1 small topk, id2/dist2 small topk */ + ids2.clear(); + dist2.clear(); + result.clear(); + BuildResult(NQ, TOP_K/3, ascending, ids2, dist2); + status = XSearchTask::TopkResult(ids1, dist1, TOP_K/2, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + status = XSearchTask::TopkResult(ids2, dist2, TOP_K/3, NQ, TOP_K, ascending, result); + ASSERT_TRUE(status.ok()); + CheckTopkResult(ids1, dist1, ids2, dist2, NQ, TOP_K, ascending, result); +} + +TEST(DBSearchTest, REDUCE_PERF_TEST) { + int32_t nq = 100; + int32_t top_k = 1000; + int32_t index_file_num = 478; /* sift1B dataset, index files num */ bool ascending = true; - std::vector target_ids; - std::vector target_distence; - ms::scheduler::ResultSet src_result; - auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result); - ASSERT_FALSE(status.ok()); - ASSERT_TRUE(src_result.empty()); + std::vector input_ids; + std::vector input_distance; + ResultSet final_result; + milvus::Status status; - BuildResult(NQ, TOP_K, ascending, target_ids, target_distence); - status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, NQ, TOP_K, src_result); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(src_result.size(), NQ); + double span, reduce_cost = 0.0; + milvus::TimeRecorder rc(""); - ms::scheduler::ResultSet target_result; - status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, target_result); - ASSERT_TRUE(status.ok()); + for (int32_t i = 0; i < index_file_num; i++) { + BuildResult(nq, top_k, ascending, input_ids, input_distance); - status = ms::scheduler::XSearchTask::TopkResult(target_result, TOP_K, ascending, src_result); - ASSERT_FALSE(status.ok()); + rc.RecordSection("do search for context: " + std::to_string(i)); - status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result); - ASSERT_TRUE(status.ok()); - ASSERT_TRUE(src_result.empty()); - ASSERT_EQ(target_result.size(), NQ); + // pick up topk result + status = XSearchTask::TopkResult(input_ids, input_distance, top_k, nq, top_k, ascending, final_result); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(final_result.size(), nq); - std::vector src_ids; - std::vector src_distence; - uint64_t wrong_topk = TOP_K - 10; - BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence); - - status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, NQ, wrong_topk, src_result); - ASSERT_TRUE(status.ok()); - - status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result); - ASSERT_TRUE(status.ok()); - for (uint64_t i = 0; i < NQ; i++) { - ASSERT_EQ(target_result[i].size(), TOP_K); - } - - wrong_topk = TOP_K + 10; - BuildResult(NQ, wrong_topk, ascending, src_ids, src_distence); - - status = ms::scheduler::XSearchTask::TopkResult(src_result, TOP_K, ascending, target_result); - ASSERT_TRUE(status.ok()); - for (uint64_t i = 0; i < NQ; i++) { - ASSERT_EQ(target_result[i].size(), TOP_K); + span = rc.RecordSection("reduce topk for context: " + std::to_string(i)); + reduce_cost += span; } -} - -TEST(DBSearchTest, MERGE_TEST) { - bool ascending = true; - std::vector target_ids; - std::vector target_distence; - std::vector src_ids; - std::vector src_distence; - ms::scheduler::ResultSet src_result, target_result; - - uint64_t src_count = 5, target_count = 8; - BuildResult(1, src_count, ascending, src_ids, src_distence); - BuildResult(1, target_count, ascending, target_ids, target_distence); - auto status = ms::scheduler::XSearchTask::ClusterResult(src_ids, src_distence, 1, src_count, src_result); - ASSERT_TRUE(status.ok()); - status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, 1, target_count, target_result); - ASSERT_TRUE(status.ok()); - - { - ms::scheduler::Id2DistanceMap src = src_result[0]; - ms::scheduler::Id2DistanceMap target = target_result[0]; - status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(target.size(), 10); - CheckResult(src_result[0], target_result[0], target, ascending); - } - - { - ms::scheduler::Id2DistanceMap src = src_result[0]; - ms::scheduler::Id2DistanceMap target; - status = ms::scheduler::XSearchTask::MergeResult(src, target, 10, ascending); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(target.size(), src_count); - ASSERT_TRUE(src.empty()); - CheckResult(src_result[0], target_result[0], target, ascending); - } - - { - ms::scheduler::Id2DistanceMap src = src_result[0]; - ms::scheduler::Id2DistanceMap target = target_result[0]; - status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(target.size(), src_count + target_count); - CheckResult(src_result[0], target_result[0], target, ascending); - } - - { - ms::scheduler::Id2DistanceMap target = src_result[0]; - ms::scheduler::Id2DistanceMap src = target_result[0]; - status = ms::scheduler::XSearchTask::MergeResult(src, target, 30, ascending); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(target.size(), src_count + target_count); - CheckResult(src_result[0], target_result[0], target, ascending); - } -} - -TEST(DBSearchTest, PARALLEL_CLUSTER_TEST) { - bool ascending = true; - std::vector target_ids; - std::vector target_distence; - ms::scheduler::ResultSet src_result; - - auto DoCluster = [&](int64_t nq, int64_t topk) { - ms::TimeRecorder rc("DoCluster"); - src_result.clear(); - BuildResult(nq, topk, ascending, target_ids, target_distence); - rc.RecordSection("build id/dietance map"); - - auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result); - ASSERT_TRUE(status.ok()); - ASSERT_EQ(src_result.size(), nq); - - rc.RecordSection("cluster result"); - - CheckCluster(target_ids, target_distence, src_result, nq, topk); - rc.RecordSection("check result"); - }; - - DoCluster(10000, 1000); - DoCluster(333, 999); - DoCluster(1, 1000); - DoCluster(1, 1); - DoCluster(7, 0); - DoCluster(9999, 1); - DoCluster(10001, 1); - DoCluster(58273, 1234); -} - -TEST(DBSearchTest, PARALLEL_TOPK_TEST) { - std::vector target_ids; - std::vector target_distence; - ms::scheduler::ResultSet src_result; - - std::vector insufficient_ids; - std::vector insufficient_distence; - ms::scheduler::ResultSet insufficient_result; - - auto DoTopk = [&](int64_t nq, int64_t topk, int64_t insufficient_topk, bool ascending) { - src_result.clear(); - insufficient_result.clear(); - - ms::TimeRecorder rc("DoCluster"); - - BuildResult(nq, topk, ascending, target_ids, target_distence); - auto status = ms::scheduler::XSearchTask::ClusterResult(target_ids, target_distence, nq, topk, src_result); - rc.RecordSection("cluster result"); - - BuildResult(nq, insufficient_topk, ascending, insufficient_ids, insufficient_distence); - status = ms::scheduler::XSearchTask::ClusterResult(target_ids, - target_distence, - nq, - insufficient_topk, - insufficient_result); - rc.RecordSection("cluster result"); - - ms::scheduler::XSearchTask::TopkResult(insufficient_result, topk, ascending, src_result); - ASSERT_TRUE(status.ok()); - rc.RecordSection("topk"); - - CheckTopkResult(src_result, ascending, nq, topk); - rc.RecordSection("check result"); - }; - - DoTopk(5, 10, 4, false); - DoTopk(20005, 998, 123, true); -// DoTopk(9987, 12, 10, false); -// DoTopk(77777, 1000, 1, false); -// DoTopk(5432, 8899, 8899, true); + std::cout << "total reduce time: " << reduce_cost/1000 << " ms" << std::endl; } diff --git a/cpp/unittest/db/utils.cpp b/cpp/unittest/db/utils.cpp index 61c75a0933..c5874be694 100644 --- a/cpp/unittest/db/utils.cpp +++ b/cpp/unittest/db/utils.cpp @@ -123,6 +123,7 @@ DBTest::TearDown() { ms::scheduler::JobMgrInst::GetInstance()->Stop(); ms::scheduler::SchedInst::GetInstance()->Stop(); ms::scheduler::ResMgrInst::GetInstance()->Stop(); + ms::scheduler::ResMgrInst::GetInstance()->Clear(); BaseTest::TearDown(); diff --git a/cpp/unittest/scheduler/task_test.cpp b/cpp/unittest/scheduler/task_test.cpp index ce91a62a52..07e85c723c 100644 --- a/cpp/unittest/scheduler/task_test.cpp +++ b/cpp/unittest/scheduler/task_test.cpp @@ -24,7 +24,7 @@ namespace milvus { namespace scheduler { TEST(TaskTest, INVALID_INDEX) { - auto search_task = std::make_shared(nullptr); + auto search_task = std::make_shared(nullptr, nullptr); search_task->Load(LoadType::TEST, 10); } diff --git a/cpp/unittest/scheduler/test_normal.cpp b/cpp/unittest/scheduler/test_normal.cpp index fb59b04214..1dbd93e044 100644 --- a/cpp/unittest/scheduler/test_normal.cpp +++ b/cpp/unittest/scheduler/test_normal.cpp @@ -54,7 +54,8 @@ TEST(NormalTest, INST_TEST) { ASSERT_FALSE(disks.empty()); if (auto observe = disks[0].lock()) { for (uint64_t i = 0; i < NUM_TASK; ++i) { - auto task = std::make_shared(dummy); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); task->label() = std::make_shared(); tasks.push_back(task); observe->task_table().Put(task); diff --git a/cpp/unittest/scheduler/test_resource.cpp b/cpp/unittest/scheduler/test_resource.cpp index 7245761393..31fe425959 100644 --- a/cpp/unittest/scheduler/test_resource.cpp +++ b/cpp/unittest/scheduler/test_resource.cpp @@ -23,6 +23,7 @@ #include "scheduler/resource/TestResource.h" #include "scheduler/task/Task.h" #include "scheduler/task/TestTask.h" +#include "scheduler/tasklabel/DefaultLabel.h" #include "scheduler/ResourceFactory.h" #include @@ -185,7 +186,8 @@ TEST_F(ResourceAdvanceTest, DISK_RESOURCE_TEST) { std::vector> tasks; TableFileSchemaPtr dummy = nullptr; for (uint64_t i = 0; i < NUM; ++i) { - auto task = std::make_shared(dummy); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); tasks.push_back(task); disk_resource_->task_table().Put(task); } @@ -210,7 +212,8 @@ TEST_F(ResourceAdvanceTest, CPU_RESOURCE_TEST) { std::vector> tasks; TableFileSchemaPtr dummy = nullptr; for (uint64_t i = 0; i < NUM; ++i) { - auto task = std::make_shared(dummy); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); tasks.push_back(task); cpu_resource_->task_table().Put(task); } @@ -235,7 +238,8 @@ TEST_F(ResourceAdvanceTest, GPU_RESOURCE_TEST) { std::vector> tasks; TableFileSchemaPtr dummy = nullptr; for (uint64_t i = 0; i < NUM; ++i) { - auto task = std::make_shared(dummy); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); tasks.push_back(task); gpu_resource_->task_table().Put(task); } @@ -260,7 +264,8 @@ TEST_F(ResourceAdvanceTest, TEST_RESOURCE_TEST) { std::vector> tasks; TableFileSchemaPtr dummy = nullptr; for (uint64_t i = 0; i < NUM; ++i) { - auto task = std::make_shared(dummy); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); tasks.push_back(task); test_resource_->task_table().Put(task); } diff --git a/cpp/unittest/scheduler/test_resource_mgr.cpp b/cpp/unittest/scheduler/test_resource_mgr.cpp index 40633baa54..34e6b50c49 100644 --- a/cpp/unittest/scheduler/test_resource_mgr.cpp +++ b/cpp/unittest/scheduler/test_resource_mgr.cpp @@ -21,6 +21,7 @@ #include "scheduler/resource/DiskResource.h" #include "scheduler/resource/TestResource.h" #include "scheduler/task/TestTask.h" +#include "scheduler/tasklabel/DefaultLabel.h" #include "scheduler/ResourceMgr.h" #include @@ -184,7 +185,8 @@ TEST_F(ResourceMgrAdvanceTest, REGISTER_SUBSCRIBER) { }; mgr1_->RegisterSubscriber(callback); TableFileSchemaPtr dummy = nullptr; - disk_res->task_table().Put(std::make_shared(dummy)); + auto label = std::make_shared(); + disk_res->task_table().Put(std::make_shared(dummy, label)); sleep(1); ASSERT_TRUE(flag); } diff --git a/cpp/unittest/scheduler/test_scheduler.cpp b/cpp/unittest/scheduler/test_scheduler.cpp index 9666cc9812..1238f906d1 100644 --- a/cpp/unittest/scheduler/test_scheduler.cpp +++ b/cpp/unittest/scheduler/test_scheduler.cpp @@ -155,7 +155,8 @@ TEST_F(SchedulerTest, ON_LOAD_COMPLETED) { insert_dummy_index_into_gpu_cache(1); for (uint64_t i = 0; i < NUM; ++i) { - auto task = std::make_shared(dummy); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); task->label() = std::make_shared(); tasks.push_back(task); cpu_resource_.lock()->task_table().Put(task); @@ -174,7 +175,8 @@ TEST_F(SchedulerTest, PUSH_TASK_TO_NEIGHBOUR_RANDOMLY_TEST) { tasks.clear(); for (uint64_t i = 0; i < NUM; ++i) { - auto task = std::make_shared(dummy1); + auto label = std::make_shared(); + auto task = std::make_shared(dummy1, label); task->label() = std::make_shared(); tasks.push_back(task); cpu_resource_.lock()->task_table().Put(task); @@ -242,7 +244,8 @@ TEST_F(SchedulerTest2, SPECIFIED_RESOURCE_TEST) { dummy->location_ = "location"; for (uint64_t i = 0; i < NUM; ++i) { - std::shared_ptr task = std::make_shared(dummy); + auto label = std::make_shared(); + std::shared_ptr task = std::make_shared(dummy, label); task->label() = std::make_shared(disk_); tasks.push_back(task); disk_.lock()->task_table().Put(task); diff --git a/cpp/unittest/scheduler/test_tasktable.cpp b/cpp/unittest/scheduler/test_tasktable.cpp index 3cfbb5a27f..271826614d 100644 --- a/cpp/unittest/scheduler/test_tasktable.cpp +++ b/cpp/unittest/scheduler/test_tasktable.cpp @@ -18,6 +18,7 @@ #include "scheduler/TaskTable.h" #include "scheduler/task/TestTask.h" +#include "scheduler/tasklabel/DefaultLabel.h" #include namespace { @@ -172,8 +173,9 @@ class TaskTableBaseTest : public ::testing::Test { SetUp() override { ms::TableFileSchemaPtr dummy = nullptr; invalid_task_ = nullptr; - task1_ = std::make_shared(dummy); - task2_ = std::make_shared(dummy); + auto label = std::make_shared(); + task1_ = std::make_shared(dummy, label); + task2_ = std::make_shared(dummy, label); } ms::TaskPtr invalid_task_; @@ -340,7 +342,8 @@ class TaskTableAdvanceTest : public ::testing::Test { SetUp() override { ms::TableFileSchemaPtr dummy = nullptr; for (uint64_t i = 0; i < 8; ++i) { - auto task = std::make_shared(dummy); + auto label = std::make_shared(); + auto task = std::make_shared(dummy, label); table1_.Put(task); }