diff --git a/cpp/CHANGELOG.md b/cpp/CHANGELOG.md index 67a17013f4..6dc98f20ac 100644 --- a/cpp/CHANGELOG.md +++ b/cpp/CHANGELOG.md @@ -19,6 +19,9 @@ Please mark all change in change log and use the ticket from JIRA. - MS-436 - Delete vectors failed if index created with index_type: IVF_FLAT/IVF_SQ8 - MS-450 - server hang after run stop_server.sh - MS-449 - Add vectors twice success, once with ids, the other no ids +- MS-461 - Mysql meta unittest failed +- MS-462 - Run milvus server twices, should display error +- MS-463 - Search timeout ## Improvement - MS-327 - Clean code for milvus @@ -74,6 +77,9 @@ Please mark all change in change log and use the ticket from JIRA. - MS-442 - Merge Knowhere - MS-445 - Rename CopyCompleted to LoadCompleted - MS-451 - Update server_config.template file, set GPU compute default +- MS-455 - Distribute tasks by minimal cost in scheduler +- MS-460 - Put transport speed as weight when choosing neighbour to execute task +- MS-459 - Add cache for pick function in tasktable ## New Feature - MS-343 - Implement ResourceMgr diff --git a/cpp/conf/server_config.template b/cpp/conf/server_config.template index 218dceed7a..8cee387765 100644 --- a/cpp/conf/server_config.template +++ b/cpp/conf/server_config.template @@ -33,7 +33,9 @@ cache_config: insert_cache_immediately: false # insert data will be load into cache immediately for hot query gpu_cache_capacity: 5 # how many memory are used as cache in gpu, unit: GB, RANGE: 0 ~ less than total memory gpu_cache_free_percent: 0.85 # old data will be erased from cache when cache is full, this value specify how much memory should be kept, range: greater than zero ~ 1.0 - gpu_ids: 0,1 # gpu id + gpu_ids: # gpu id + - 0 + - 1 engine_config: use_blas_threshold: 20 @@ -81,9 +83,13 @@ resource_config: # enable_executor: true # connection list, length: 0~N - # format: -${resource_name}===${resource_name} + # format: -${resource_name}===${resource_name} connections: - - ssda===cpu - - cpu===gpu0 + io: + speed: 500 + endpoint: ssda===cpu + pcie: + speed: 11000 + endpoint: cpu===gpu0 # - cpu===gtx1660 diff --git a/cpp/src/CMakeLists.txt b/cpp/src/CMakeLists.txt index 9f25cae8ef..a2ac3002b3 100644 --- a/cpp/src/CMakeLists.txt +++ b/cpp/src/CMakeLists.txt @@ -63,7 +63,7 @@ set(grpc_service_files grpc/gen-milvus/milvus.pb.cc grpc/gen-status/status.grpc.pb.cc grpc/gen-status/status.pb.cc - ) + scheduler/Utils.h) set(db_files ${CMAKE_CURRENT_SOURCE_DIR}/main.cpp diff --git a/cpp/src/cache/GpuCacheMgr.cpp b/cpp/src/cache/GpuCacheMgr.cpp index 4aa5626348..ef2f307c30 100644 --- a/cpp/src/cache/GpuCacheMgr.cpp +++ b/cpp/src/cache/GpuCacheMgr.cpp @@ -21,17 +21,14 @@ namespace { std::vector load() { server::ConfigNode& config = server::ServerConfig::GetInstance().GetConfig(server::CONFIG_CACHE); - std::string gpu_ids_str = config.GetValue(server::CONFIG_GPU_IDS, "0,1"); + auto conf_gpu_ids = config.GetSequence(server::CONFIG_GPU_IDS); std::vector gpu_ids; - std::stringstream ss(gpu_ids_str); - for (int i; ss >> i;) { - gpu_ids.push_back(i); - if (ss.peek() == ',') { - ss.ignore(); - } + for (auto gpu_id : conf_gpu_ids) { + gpu_ids.push_back(std::atoi(gpu_id.c_str())); } + return gpu_ids; } } diff --git a/cpp/src/core/src/knowhere/index/vector_index/cloner.cpp b/cpp/src/core/src/knowhere/index/vector_index/cloner.cpp index 7fd0df6664..4229bf0a88 100644 --- a/cpp/src/core/src/knowhere/index/vector_index/cloner.cpp +++ b/cpp/src/core/src/knowhere/index/vector_index/cloner.cpp @@ -29,7 +29,6 @@ VectorIndexPtr CopyCpuToGpu(const VectorIndexPtr &index, const int64_t &device_i if (auto cpu_index = std::dynamic_pointer_cast(index)) { return cpu_index->CopyCpuToGpu(device_id, config); - //KNOWHERE_THROW_MSG("IVFSQ not support tranfer to gpu"); } else if (auto cpu_index = std::dynamic_pointer_cast(index)) { KNOWHERE_THROW_MSG("IVFPQ not support tranfer to gpu"); } else if (auto cpu_index = std::dynamic_pointer_cast(index)) { diff --git a/cpp/src/core/src/knowhere/index/vector_index/gpu_ivf.cpp b/cpp/src/core/src/knowhere/index/vector_index/gpu_ivf.cpp index c06c58c40c..2771dd3533 100644 --- a/cpp/src/core/src/knowhere/index/vector_index/gpu_ivf.cpp +++ b/cpp/src/core/src/knowhere/index/vector_index/gpu_ivf.cpp @@ -31,10 +31,11 @@ IndexModelPtr GPUIVF::Train(const DatasetPtr &dataset, const Config &config) { GETTENSOR(dataset) - // TODO(linxj): use device_id auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_device); ResScope rs(gpu_device, res); - faiss::gpu::GpuIndexIVFFlat device_index(res.get(), dim, nlist, metric_type); + faiss::gpu::GpuIndexIVFFlatConfig idx_config; + idx_config.device = gpu_device; + faiss::gpu::GpuIndexIVFFlat device_index(res.get(), dim, nlist, metric_type, idx_config); device_index.train(rows, (float *) p_data); std::shared_ptr host_index = nullptr; diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index b744899d56..64dcc7275e 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -58,14 +58,14 @@ Status DBImpl::Start() { return Status::OK(); } + shutting_down_.store(false, std::memory_order_release); + //for distribute version, some nodes are read only if (options_.mode != Options::MODE::READ_ONLY) { ENGINE_LOG_TRACE << "StartTimerTasks"; bg_timer_thread_ = std::thread(&DBImpl::BackgroundTimerTask, this); } - shutting_down_.store(false, std::memory_order_release); - return Status::OK(); } @@ -163,7 +163,7 @@ Status DBImpl::PreloadTable(const std::string &table_id) { //step 1: load index engine->Load(true); } catch (std::exception &ex) { - std::string msg = "load to cache exception" + std::string(ex.what()); + std::string msg = "Pre-load table encounter exception: " + std::string(ex.what()); ENGINE_LOG_ERROR << msg; return Status::Error(msg); } @@ -198,8 +198,6 @@ Status DBImpl::InsertVectors(const std::string& table_id_, Status DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float *vectors, QueryResults &results) { - server::CollectQueryMetrics metrics(nq); - meta::DatesT dates = {utils::GetDate()}; Status result = Query(table_id, k, nq, nprobe, vectors, dates, results); @@ -208,7 +206,7 @@ Status DBImpl::Query(const std::string &table_id, uint64_t k, uint64_t nq, uint6 Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) { - ENGINE_LOG_DEBUG << "Query by vectors " << table_id; + ENGINE_LOG_DEBUG << "Query by dates for table: " << table_id; //get all table files from table meta::DatePartionedTableFilesSchema files; @@ -232,7 +230,7 @@ Status DBImpl::Query(const std::string& table_id, uint64_t k, uint64_t nq, uint6 Status DBImpl::Query(const std::string& table_id, const std::vector& file_ids, uint64_t k, uint64_t nq, uint64_t nprobe, const float* vectors, const meta::DatesT& dates, QueryResults& results) { - ENGINE_LOG_DEBUG << "Query by file ids"; + ENGINE_LOG_DEBUG << "Query by file ids for table: " << table_id; //get specified files std::vector ids; @@ -274,7 +272,7 @@ Status DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSch server::TimeRecorder rc(""); //step 1: get files to search - ENGINE_LOG_DEBUG << "Engine query begin, index file count:" << files.size() << " date range count:" << dates.size(); + ENGINE_LOG_DEBUG << "Engine query begin, index file count: " << files.size() << " date range count: " << dates.size(); SearchContextPtr context = std::make_shared(k, nq, nprobe, vectors); for (auto &file : files) { TableFileSchemaPtr file_ptr = std::make_shared(file); @@ -300,11 +298,11 @@ Status DBImpl::QueryAsync(const std::string& table_id, const meta::TableFilesSch double search_percent = search_cost/total_cost; double reduce_percent = reduce_cost/total_cost; - ENGINE_LOG_DEBUG << "Engine load index totally cost:" << load_info << " percent: " << load_percent*100 << "%"; - ENGINE_LOG_DEBUG << "Engine search index totally cost:" << search_info << " percent: " << search_percent*100 << "%"; - ENGINE_LOG_DEBUG << "Engine reduce topk totally cost:" << reduce_info << " percent: " << reduce_percent*100 << "%"; + ENGINE_LOG_DEBUG << "Engine load index totally cost: " << load_info << " percent: " << load_percent*100 << "%"; + ENGINE_LOG_DEBUG << "Engine search index totally cost: " << search_info << " percent: " << search_percent*100 << "%"; + ENGINE_LOG_DEBUG << "Engine reduce topk totally cost: " << reduce_info << " percent: " << reduce_percent*100 << "%"; } else { - ENGINE_LOG_DEBUG << "Engine load cost:" << load_info + ENGINE_LOG_DEBUG << "Engine load cost: " << load_info << " search cost: " << search_info << " reduce cost: " << reduce_info; } @@ -413,7 +411,7 @@ void DBImpl::StartCompactionTask() { Status DBImpl::MergeFiles(const std::string& table_id, const meta::DateT& date, const meta::TableFilesSchema& files) { - ENGINE_LOG_DEBUG << "Merge files for table " << table_id; + ENGINE_LOG_DEBUG << "Merge files for table: " << table_id; //step 1: create table file meta::TableFileSchema table_file; @@ -453,7 +451,7 @@ Status DBImpl::MergeFiles(const std::string& table_id, const meta::DateT& date, index->Serialize(); } catch (std::exception& ex) { //typical error: out of disk space or permition denied - std::string msg = "Serialize merged index encounter exception" + std::string(ex.what()); + std::string msg = "Serialize merged index encounter exception: " + std::string(ex.what()); ENGINE_LOG_ERROR << msg; table_file.file_type_ = meta::TableFileSchema::TO_DELETE; @@ -508,7 +506,7 @@ Status DBImpl::BackgroundMergeFiles(const std::string& table_id) { MergeFiles(table_id, kv.first, kv.second); if (shutting_down_.load(std::memory_order_acquire)){ - ENGINE_LOG_DEBUG << "Server will shutdown, skip merge action for table " << table_id; + ENGINE_LOG_DEBUG << "Server will shutdown, skip merge action for table: " << table_id; break; } } @@ -574,7 +572,7 @@ Status DBImpl::CreateIndex(const std::string& table_id, const TableIndex& index) TableIndex old_index; auto status = DescribeIndex(table_id, old_index); if(!status.ok()) { - ENGINE_LOG_ERROR << "Failed to get table index info"; + ENGINE_LOG_ERROR << "Failed to get table index info for table: " << table_id; return status; } @@ -584,7 +582,7 @@ Status DBImpl::CreateIndex(const std::string& table_id, const TableIndex& index) status = meta_ptr_->UpdateTableIndexParam(table_id, index); if (!status.ok()) { - ENGINE_LOG_ERROR << "Failed to update table index info"; + ENGINE_LOG_ERROR << "Failed to update table index info for table: " << table_id; return status; } } @@ -632,7 +630,7 @@ Status DBImpl::DescribeIndex(const std::string& table_id, TableIndex& index) { } Status DBImpl::DropIndex(const std::string& table_id) { - ENGINE_LOG_DEBUG << "drop index for table: " << table_id; + ENGINE_LOG_DEBUG << "Drop index for table: " << table_id; return meta_ptr_->DropTableIndex(table_id); } @@ -647,16 +645,20 @@ Status DBImpl::BuildIndex(const meta::TableFileSchema& file) { try { //step 1: load index - to_index->Load(options_.insert_cache_immediately_); + Status status = to_index->Load(options_.insert_cache_immediately_); + if (!status.ok()) { + ENGINE_LOG_ERROR << "Failed to load index file: " << status.ToString(); + return status; + } //step 2: create table file meta::TableFileSchema table_file; table_file.table_id_ = file.table_id_; table_file.date_ = file.date_; table_file.file_type_ = meta::TableFileSchema::NEW_INDEX; //for multi-db-path, distribute index file averagely to each path - Status status = meta_ptr_->CreateTableFile(table_file); + status = meta_ptr_->CreateTableFile(table_file); if (!status.ok()) { - ENGINE_LOG_ERROR << "Failed to create table: " << status.ToString(); + ENGINE_LOG_ERROR << "Failed to create table file: " << status.ToString(); return status; } @@ -666,9 +668,17 @@ Status DBImpl::BuildIndex(const meta::TableFileSchema& file) { try { server::CollectBuildIndexMetrics metrics; index = to_index->BuildIndex(table_file.location_, (EngineType)table_file.engine_type_); + if (index == nullptr) { + table_file.file_type_ = meta::TableFileSchema::TO_DELETE; + status = meta_ptr_->UpdateTableFile(table_file); + ENGINE_LOG_DEBUG << "Failed to update file to index, mark file: " << table_file.file_id_ << " to to_delete"; + + return status; + } + } catch (std::exception& ex) { //typical error: out of gpu memory - std::string msg = "BuildIndex encounter exception" + std::string(ex.what()); + std::string msg = "BuildIndex encounter exception: " + std::string(ex.what()); ENGINE_LOG_ERROR << msg; table_file.file_type_ = meta::TableFileSchema::TO_DELETE; @@ -693,7 +703,7 @@ Status DBImpl::BuildIndex(const meta::TableFileSchema& file) { index->Serialize(); } catch (std::exception& ex) { //typical error: out of disk space or permition denied - std::string msg = "Serialize index encounter exception" + std::string(ex.what()); + std::string msg = "Serialize index encounter exception: " + std::string(ex.what()); ENGINE_LOG_ERROR << msg; table_file.file_type_ = meta::TableFileSchema::TO_DELETE; @@ -736,7 +746,7 @@ Status DBImpl::BuildIndex(const meta::TableFileSchema& file) { } } catch (std::exception& ex) { - std::string msg = "Build index encounter exception" + std::string(ex.what()); + std::string msg = "Build index encounter exception: " + std::string(ex.what()); ENGINE_LOG_ERROR << msg; return Status::Error(msg); } @@ -745,7 +755,7 @@ Status DBImpl::BuildIndex(const meta::TableFileSchema& file) { } void DBImpl::BackgroundBuildIndex() { - ENGINE_LOG_TRACE << " Background build index thread start"; + ENGINE_LOG_TRACE << "Background build index thread start"; std::unique_lock lock(build_index_mutex_); meta::TableFilesSchema to_index_files; @@ -764,7 +774,7 @@ void DBImpl::BackgroundBuildIndex() { } } - ENGINE_LOG_TRACE << " Background build index thread exit"; + ENGINE_LOG_TRACE << "Background build index thread exit"; } Status DBImpl::DropAll() { diff --git a/cpp/src/db/engine/ExecutionEngineImpl.cpp b/cpp/src/db/engine/ExecutionEngineImpl.cpp index c2b0b35220..098823c482 100644 --- a/cpp/src/db/engine/ExecutionEngineImpl.cpp +++ b/cpp/src/db/engine/ExecutionEngineImpl.cpp @@ -132,7 +132,9 @@ Status ExecutionEngineImpl::Load(bool to_cache) { server::CollectExecutionEngineMetrics metrics(physical_size); index_ = read_index(location_); if(index_ == nullptr) { - ENGINE_LOG_ERROR << "Failed to load index from " << location_; + std::string msg = "Failed to load index from " + location_; + ENGINE_LOG_ERROR << msg; + return Status::Error(msg); } else { ENGINE_LOG_DEBUG << "Disk io from: " << location_; } diff --git a/cpp/src/db/meta/MySQLMetaImpl.cpp b/cpp/src/db/meta/MySQLMetaImpl.cpp index beae3fa1e9..dc1f931c03 100644 --- a/cpp/src/db/meta/MySQLMetaImpl.cpp +++ b/cpp/src/db/meta/MySQLMetaImpl.cpp @@ -144,7 +144,7 @@ Status MySQLMetaImpl::Initialize() { "dimension SMALLINT NOT NULL, " << "created_on BIGINT NOT NULL, " << "flag BIGINT DEFAULT 0 NOT NULL, " << - "index_file_size INT DEFAULT 1024 NOT NULL, " << + "index_file_size BIGINT DEFAULT 1024 NOT NULL, " << "engine_type INT DEFAULT 1 NOT NULL, " << "nlist INT DEFAULT 16384 NOT NULL, " << "metric_type INT DEFAULT 1 NOT NULL);"; @@ -291,11 +291,16 @@ Status MySQLMetaImpl::CreateTable(TableSchema &table_schema) { std::string state = std::to_string(table_schema.state_); std::string dimension = std::to_string(table_schema.dimension_); std::string created_on = std::to_string(table_schema.created_on_); + std::string flag = std::to_string(table_schema.flag_); + std::string index_file_size = std::to_string(table_schema.index_file_size_); std::string engine_type = std::to_string(table_schema.engine_type_); + std::string nlist = std::to_string(table_schema.nlist_); + std::string metric_type = std::to_string(table_schema.metric_type_); createTableQuery << "INSERT INTO Tables VALUES" << "(" << id << ", " << quote << table_id << ", " << state << ", " << dimension << ", " << - created_on << ", " << engine_type << ");"; + created_on << ", " << flag << ", " << index_file_size << ", " << engine_type << ", " << + nlist << ", " << metric_type << ");"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::CreateTable: " << createTableQuery.str(); @@ -904,6 +909,7 @@ Status MySQLMetaImpl::CreateTableFile(TableFileSchema &file_schema) { std::string engine_type = std::to_string(file_schema.engine_type_); std::string file_id = file_schema.file_id_; std::string file_type = std::to_string(file_schema.file_type_); + std::string file_size = std::to_string(file_schema.file_size_); std::string row_count = std::to_string(file_schema.row_count_); std::string updated_time = std::to_string(file_schema.updated_time_); std::string created_on = std::to_string(file_schema.created_on_); @@ -920,8 +926,8 @@ Status MySQLMetaImpl::CreateTableFile(TableFileSchema &file_schema) { createTableFileQuery << "INSERT INTO TableFiles VALUES" << "(" << id << ", " << quote << table_id << ", " << engine_type << ", " << - quote << file_id << ", " << file_type << ", " << row_count << ", " << - updated_time << ", " << created_on << ", " << date << ");"; + quote << file_id << ", " << file_type << ", " << file_size << ", " << + row_count << ", " << updated_time << ", " << created_on << ", " << date << ");"; ENGINE_LOG_DEBUG << "MySQLMetaImpl::CreateTableFile: " << createTableFileQuery.str(); @@ -1170,7 +1176,7 @@ Status MySQLMetaImpl::FilesToMerge(const std::string &table_id, } Query filesToMergeQuery = connectionPtr->query(); - filesToMergeQuery << "SELECT id, table_id, file_id, file_type, file_size, row_count, date, engine_type, create_on " << + filesToMergeQuery << "SELECT id, table_id, file_id, file_type, file_size, row_count, date, engine_type, created_on " << "FROM TableFiles " << "WHERE table_id = " << quote << table_id << " AND " << "file_type = " << std::to_string(TableFileSchema::RAW) << " " << diff --git a/cpp/src/metrics/PrometheusMetrics.cpp b/cpp/src/metrics/PrometheusMetrics.cpp index 08f1fe7dc2..5f98acfc84 100644 --- a/cpp/src/metrics/PrometheusMetrics.cpp +++ b/cpp/src/metrics/PrometheusMetrics.cpp @@ -170,16 +170,12 @@ void PrometheusMetrics::CPUTemperature() { void PrometheusMetrics::GpuCacheUsageGaugeSet() { if(!startup_) return; server::ConfigNode& config = server::ServerConfig::GetInstance().GetConfig(server::CONFIG_CACHE); - std::string gpu_ids_str = config.GetValue(server::CONFIG_GPU_IDS, "0,1"); + auto conf_gpu_ids = config.GetSequence(server::CONFIG_GPU_IDS); std::vector gpu_ids; - std::stringstream ss(gpu_ids_str); - for (int i; ss >> i;) { - gpu_ids.push_back(i); - if (ss.peek() == ',') { - ss.ignore(); - } + for (auto gpu_id : conf_gpu_ids) { + gpu_ids.push_back(std::atoi(gpu_id.c_str())); } for(auto i = 0; i < gpu_ids.size(); ++i) { diff --git a/cpp/src/scheduler/Algorithm.cpp b/cpp/src/scheduler/Algorithm.cpp new file mode 100644 index 0000000000..b861151ddf --- /dev/null +++ b/cpp/src/scheduler/Algorithm.cpp @@ -0,0 +1,98 @@ +/******************************************************************************* + * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved + * Unauthorized copying of this file, via any medium is strictly prohibited. + * Proprietary and confidential. + ******************************************************************************/ + +#include "Algorithm.h" + +namespace zilliz { +namespace milvus { +namespace engine { + +constexpr uint64_t MAXINT = 99999; + +uint64_t +ShortestPath(const ResourcePtr &src, + const ResourcePtr &dest, + const ResourceMgrPtr &res_mgr, + std::vector &path) { + + std::vector> paths; + + uint64_t num_of_resources = res_mgr->GetAllResouces().size(); + std::unordered_map id_name_map; + std::unordered_map name_id_map; + for (uint64_t i = 0; i < num_of_resources; ++i) { + id_name_map.insert(std::make_pair(i, res_mgr->GetAllResouces().at(i)->Name())); + name_id_map.insert(std::make_pair(res_mgr->GetAllResouces().at(i)->Name(), i)); + } + + std::vector > dis_matrix; + dis_matrix.resize(num_of_resources); + for (uint64_t i = 0; i < num_of_resources; ++i) { + dis_matrix[i].resize(num_of_resources); + for (uint64_t j = 0; j < num_of_resources; ++j) { + dis_matrix[i][j] = MAXINT; + } + dis_matrix[i][i] = 0; + } + + std::vector vis(num_of_resources, false); + std::vector dis(num_of_resources, MAXINT); + for (auto &res : res_mgr->GetAllResouces()) { + + auto cur_node = std::static_pointer_cast(res); + auto cur_neighbours = cur_node->GetNeighbours(); + + for (auto &neighbour : cur_neighbours) { + auto neighbour_res = std::static_pointer_cast(neighbour.neighbour_node.lock()); + dis_matrix[name_id_map.at(res->Name())][name_id_map.at(neighbour_res->Name())] = + neighbour.connection.transport_cost(); + } + } + + for (uint64_t i = 0; i < num_of_resources; ++i) { + dis[i] = dis_matrix[name_id_map.at(src->Name())][i]; + } + + vis[name_id_map.at(src->Name())] = true; + std::vector parent(num_of_resources, -1); + + for (uint64_t i = 0; i < num_of_resources; ++i) { + uint64_t minn = MAXINT; + uint64_t temp = 0; + for (uint64_t j = 0; j < num_of_resources; ++j) { + if (!vis[j] && dis[j] < minn) { + minn = dis[j]; + temp = j; + } + } + vis[temp] = true; + + if (i == 0) { + parent[temp] = name_id_map.at(src->Name()); + } + + for (uint64_t j = 0; j < num_of_resources; ++j) { + if (!vis[j] && dis_matrix[temp][j] != MAXINT && dis_matrix[temp][j] + dis[temp] < dis[j]) { + dis[j] = dis_matrix[temp][j] + dis[temp]; + parent[j] = temp; + } + } + } + + int64_t parent_idx = parent[name_id_map.at(dest->Name())]; + if (parent_idx != -1) { + path.push_back(dest->Name()); + } + while (parent_idx != -1) { + path.push_back(id_name_map.at(parent_idx)); + parent_idx = parent[parent_idx]; + } + return dis[name_id_map.at(dest->Name())]; +} + +} +} +} \ No newline at end of file diff --git a/cpp/src/scheduler/Algorithm.h b/cpp/src/scheduler/Algorithm.h new file mode 100644 index 0000000000..05d9ad71d8 --- /dev/null +++ b/cpp/src/scheduler/Algorithm.h @@ -0,0 +1,25 @@ +/******************************************************************************* + * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved + * Unauthorized copying of this file, via any medium is strictly prohibited. + * Proprietary and confidential. + ******************************************************************************/ + +#include "resource/Resource.h" +#include "ResourceMgr.h" + +#include +#include + +namespace zilliz { +namespace milvus { +namespace engine { + +uint64_t +ShortestPath(const ResourcePtr &src, + const ResourcePtr &dest, + const ResourceMgrPtr &res_mgr, + std::vector& path); + +} +} +} \ No newline at end of file diff --git a/cpp/src/scheduler/Cost.cpp b/cpp/src/scheduler/Cost.cpp deleted file mode 100644 index 724a717d2f..0000000000 --- a/cpp/src/scheduler/Cost.cpp +++ /dev/null @@ -1,54 +0,0 @@ -/******************************************************************************* - * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved - * Unauthorized copying of this file, via any medium is strictly prohibited. - * Proprietary and confidential. - ******************************************************************************/ - -#include "Cost.h" - - -namespace zilliz { -namespace milvus { -namespace engine { - -std::vector -PickToMove(TaskTable &task_table, const CacheMgr &cache_mgr, uint64_t limit) { - std::vector indexes; - for (uint64_t i = 0, count = 0; i < task_table.Size() && count < limit; ++i) { - if (task_table[i]->state == TaskTableItemState::LOADED) { - indexes.push_back(i); - ++count; - } - } - return indexes; -} - - -std::vector -PickToLoad(TaskTable &task_table, uint64_t limit) { - std::vector indexes; - for (uint64_t i = 0, count = 0; i < task_table.Size() && count < limit; ++i) { - if (task_table[i]->state == TaskTableItemState::START) { - indexes.push_back(i); - ++count; - } - } - return indexes; -} - - -std::vector -PickToExecute(TaskTable &task_table, uint64_t limit) { - std::vector indexes; - for (uint64_t i = 0, count = 0; i < task_table.Size() && count < limit; ++i) { - if (task_table[i]->state == TaskTableItemState::LOADED) { - indexes.push_back(i); - ++count; - } - } - return indexes; -} - -} -} -} diff --git a/cpp/src/scheduler/Cost.h b/cpp/src/scheduler/Cost.h deleted file mode 100644 index 76f16d4d1d..0000000000 --- a/cpp/src/scheduler/Cost.h +++ /dev/null @@ -1,48 +0,0 @@ -/******************************************************************************* - * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved - * Unauthorized copying of this file, via any medium is strictly prohibited. - * Proprietary and confidential. - ******************************************************************************/ -#pragma once - -#include -#include "task/Task.h" -#include "TaskTable.h" -#include "CacheMgr.h" - - -namespace zilliz { -namespace milvus { -namespace engine { - -// TODO: Policy interface -// TODO: collect statistics - -/* - * select tasks to move; - * call from scheduler; - */ -std::vector -PickToMove(TaskTable &task_table, const CacheMgr &cache_mgr, uint64_t limit); - - -/* - * select task to load - * call from resource; - * I DONT SURE NEED THIS; - */ -std::vector -PickToLoad(TaskTable &task_table, uint64_t limit); - -/* - * select task to execute; - * call from resource; - * I DONT SURE NEED THIS; - */ -std::vector -PickToExecute(TaskTable &task_table, uint64_t limit); - - -} -} -} diff --git a/cpp/src/scheduler/ResourceMgr.cpp b/cpp/src/scheduler/ResourceMgr.cpp index 649f840827..65373164e3 100644 --- a/cpp/src/scheduler/ResourceMgr.cpp +++ b/cpp/src/scheduler/ResourceMgr.cpp @@ -28,6 +28,17 @@ ResourceMgr::GetNumOfComputeResource() { return count; } +std::vector +ResourceMgr::GetComputeResource() { + std::vector result; + for (auto &resource : resources_) { + if (resource->HasExecutor()) { + result.emplace_back(resource); + } + } + return result; +} + uint64_t ResourceMgr::GetNumGpuResource() const { uint64_t num = 0; @@ -49,6 +60,21 @@ ResourceMgr::GetResource(ResourceType type, uint64_t device_id) { return nullptr; } +ResourcePtr +ResourceMgr::GetResourceByName(std::string name) { + for (auto &resource : resources_) { + if (resource->Name() == name) { + return resource; + } + } + return nullptr; +} + +std::vector +ResourceMgr::GetAllResouces() { + return resources_; +} + ResourceWPtr ResourceMgr::Add(ResourcePtr &&resource) { ResourceWPtr ret(resource); diff --git a/cpp/src/scheduler/ResourceMgr.h b/cpp/src/scheduler/ResourceMgr.h index 5083aa1b53..da8f34f87e 100644 --- a/cpp/src/scheduler/ResourceMgr.h +++ b/cpp/src/scheduler/ResourceMgr.h @@ -41,12 +41,21 @@ public: ResourcePtr GetResource(ResourceType type, uint64_t device_id); + ResourcePtr + GetResourceByName(std::string name); + + std::vector + GetAllResouces(); + /* * Return account of resource which enable executor; */ uint64_t GetNumOfComputeResource(); + std::vector + GetComputeResource(); + /* * Add resource into Resource Management; * Generate functions on events; diff --git a/cpp/src/scheduler/SchedInst.cpp b/cpp/src/scheduler/SchedInst.cpp index 3ee8cbfdb6..43204f0946 100644 --- a/cpp/src/scheduler/SchedInst.cpp +++ b/cpp/src/scheduler/SchedInst.cpp @@ -43,14 +43,21 @@ StartSchedulerService() { knowhere::FaissGpuResourceMgr::GetInstance().InitResource(); - auto default_connection = Connection("default_connection", 500.0); - auto connections = config.GetSequence(server::CONFIG_RESOURCE_CONNECTIONS); +// auto default_connection = Connection("default_connection", 500.0); + auto connections = config.GetChild(server::CONFIG_RESOURCE_CONNECTIONS).GetChildren(); for (auto &conn : connections) { - std::string delimiter = "==="; - std::string left = conn.substr(0, conn.find(delimiter)); - std::string right = conn.substr(conn.find(delimiter) + 3, conn.length()); + auto &connect_name = conn.first; + auto &connect_conf = conn.second; + auto connect_speed = connect_conf.GetInt64Value(server::CONFIG_SPEED_CONNECTIONS); + auto connect_endpoint = connect_conf.GetValue(server::CONFIG_ENDPOINT_CONNECTIONS); - ResMgrInst::GetInstance()->Connect(left, right, default_connection); + std::string delimiter = "==="; + std::string left = connect_endpoint.substr(0, connect_endpoint.find(delimiter)); + std::string right = connect_endpoint.substr(connect_endpoint.find(delimiter) + 3, + connect_endpoint.length()); + + auto connection = Connection(connect_name, connect_speed); + ResMgrInst::GetInstance()->Connect(left, right, connection); } ResMgrInst::GetInstance()->Start(); diff --git a/cpp/src/scheduler/Scheduler.cpp b/cpp/src/scheduler/Scheduler.cpp index 03fa479df6..fa67eef489 100644 --- a/cpp/src/scheduler/Scheduler.cpp +++ b/cpp/src/scheduler/Scheduler.cpp @@ -5,9 +5,10 @@ ******************************************************************************/ #include +#include "event/LoadCompletedEvent.h" #include "Scheduler.h" -#include "Cost.h" #include "action/Action.h" +#include "Algorithm.h" namespace zilliz { @@ -137,6 +138,54 @@ Scheduler::OnLoadCompleted(const EventPtr &event) { } break; } + case TaskLabelType::SPECIFIED_RESOURCE: { + auto self = event->resource_.lock(); + auto task = load_completed_event->task_table_item_->task; + + // if this resource is disk, assign it to smallest cost resource + if (self->Type() == ResourceType::DISK) { + // step 1: calculate shortest path per resource, from disk to compute resource + auto compute_resources = res_mgr_.lock()->GetComputeResource(); + std::vector> paths; + std::vector transport_costs; + for (auto &res : compute_resources) { + std::vector path; + uint64_t transport_cost = ShortestPath(self, res, res_mgr_.lock(), path); + transport_costs.push_back(transport_cost); + paths.emplace_back(path); + } + + // step 2: select min cost, cost(resource) = avg_cost * task_to_do + transport_cost + uint64_t min_cost = std::numeric_limits::max(); + uint64_t min_cost_idx = 0; + for (uint64_t i = 0; i < compute_resources.size(); ++i) { + if (compute_resources[i]->TotalTasks() == 0) { + min_cost_idx = i; + break; + } + uint64_t cost = compute_resources[i]->TaskAvgCost() * compute_resources[i]->NumOfTaskToExec() + + transport_costs[i]; + if (min_cost > cost) { + min_cost = cost; + min_cost_idx = i; + } + } + + // step 3: set path in task + Path task_path(paths[min_cost_idx], paths[min_cost_idx].size() - 1); + task->path() = task_path; + } + + if(self->Name() == task->path().Last()) { + self->WakeupLoader(); + } else { + auto next_res_name = task->path().Next(); + auto next_res = res_mgr_.lock()->GetResourceByName(next_res_name); + load_completed_event->task_table_item_->Move(); + next_res->task_table().Put(task); + } + break; + } case TaskLabelType::BROADCAST: { Action::PushTaskToAllNeighbour(load_completed_event->task_table_item_->task, resource); break; diff --git a/cpp/src/scheduler/TaskTable.cpp b/cpp/src/scheduler/TaskTable.cpp index 2d309d591c..56d31e299a 100644 --- a/cpp/src/scheduler/TaskTable.cpp +++ b/cpp/src/scheduler/TaskTable.cpp @@ -53,6 +53,11 @@ ToString(const TaskTimestamp ×tamp) { return ss.str(); } +bool +TaskTableItem::IsFinish() { + return state == TaskTableItemState::MOVED || state == TaskTableItemState::EXECUTED; +} + bool TaskTableItem::Load() { std::unique_lock lock(mutex); @@ -133,6 +138,38 @@ TaskTableItem::Dump() { return ss.str(); } +std::vector +TaskTable::PickToLoad(uint64_t limit) { + std::vector indexes; + bool cross = false; + for (uint64_t i = last_finish_, count = 0; i < table_.size() && count < limit; ++i) { + if (not cross && table_[i]->IsFinish()) { + last_finish_ = i; + } else if (table_[i]->state == TaskTableItemState::START) { + cross = true; + indexes.push_back(i); + ++count; + } + } + return indexes; +} + +std::vector +TaskTable::PickToExecute(uint64_t limit) { + std::vector indexes; + bool cross = false; + for (uint64_t i = last_finish_, count = 0; i < table_.size() && count < limit; ++i) { + if (not cross && table_[i]->IsFinish()) { + last_finish_ = i; + } else if (table_[i]->state == TaskTableItemState::LOADED) { + cross = true; + indexes.push_back(i); + ++count; + } + } + return indexes; +} + void TaskTable::Put(TaskPtr task) { std::lock_guard lock(id_mutex_); diff --git a/cpp/src/scheduler/TaskTable.h b/cpp/src/scheduler/TaskTable.h index 886259957c..ee6d3b56cb 100644 --- a/cpp/src/scheduler/TaskTable.h +++ b/cpp/src/scheduler/TaskTable.h @@ -54,6 +54,9 @@ struct TaskTableItem { uint8_t priority; // just a number, meaningless; + bool + IsFinish(); + bool Load(); @@ -141,6 +144,13 @@ public: std::deque::iterator begin() { return table_.begin(); } std::deque::iterator end() { return table_.end(); } +public: + std::vector + PickToLoad(uint64_t limit); + + std::vector + PickToExecute(uint64_t limit); + public: /******** Action ********/ @@ -182,7 +192,7 @@ public: * Called by executor; */ inline bool - Executed(uint64_t index){ + Executed(uint64_t index) { return table_[index]->Executed(); } @@ -193,7 +203,7 @@ public: */ inline bool - Move(uint64_t index){ + Move(uint64_t index) { return table_[index]->Move(); } @@ -203,7 +213,7 @@ public: * Called by scheduler; */ inline bool - Moved(uint64_t index){ + Moved(uint64_t index) { return table_[index]->Moved(); } @@ -220,6 +230,9 @@ private: mutable std::mutex id_mutex_; std::deque table_; std::function subscriber_ = nullptr; + + // cache last finish avoid Pick task from begin always + uint64_t last_finish_ = 0; }; diff --git a/cpp/src/scheduler/Utils.cpp b/cpp/src/scheduler/Utils.cpp new file mode 100644 index 0000000000..074c035e8e --- /dev/null +++ b/cpp/src/scheduler/Utils.cpp @@ -0,0 +1,25 @@ +/******************************************************************************* + * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved + * Unauthorized copying of this file, via any medium is strictly prohibited. + * Proprietary and confidential. + ******************************************************************************/ + +#include +#include "Utils.h" + +namespace zilliz { +namespace milvus { +namespace engine { + +uint64_t +get_current_timestamp() +{ + std::chrono::time_point now = std::chrono::system_clock::now(); + auto duration = now.time_since_epoch(); + auto millis = std::chrono::duration_cast(duration).count(); + return millis; +} + +} +} +} \ No newline at end of file diff --git a/cpp/src/scheduler/Utils.h b/cpp/src/scheduler/Utils.h new file mode 100644 index 0000000000..7a5bf1874d --- /dev/null +++ b/cpp/src/scheduler/Utils.h @@ -0,0 +1,18 @@ +/******************************************************************************* + * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved + * Unauthorized copying of this file, via any medium is strictly prohibited. + * Proprietary and confidential. + ******************************************************************************/ +#include + + +namespace zilliz { +namespace milvus { +namespace engine { + +uint64_t +get_current_timestamp(); + +} +} +} \ No newline at end of file diff --git a/cpp/src/scheduler/action/PushTaskToNeighbour.cpp b/cpp/src/scheduler/action/PushTaskToNeighbour.cpp index 1939cbc127..200f6214fe 100644 --- a/cpp/src/scheduler/action/PushTaskToNeighbour.cpp +++ b/cpp/src/scheduler/action/PushTaskToNeighbour.cpp @@ -28,17 +28,48 @@ get_neighbours(const ResourcePtr &self) { return neighbours; } +std::vector> +get_neighbours_with_connetion(const ResourcePtr &self) { + std::vector> neighbours; + for (auto &neighbour_node : self->GetNeighbours()) { + auto node = neighbour_node.neighbour_node.lock(); + if (not node) continue; + + auto resource = std::static_pointer_cast(node); +// if (not resource->HasExecutor()) continue; + Connection conn = neighbour_node.connection; + neighbours.emplace_back(std::make_pair(resource, conn)); + } + return neighbours; +} + void Action::PushTaskToNeighbourRandomly(const TaskPtr &task, const ResourcePtr &self) { - auto neighbours = get_neighbours(self); + auto neighbours = get_neighbours_with_connetion(self); if (not neighbours.empty()) { + std::vector speeds; + uint64_t total_speed = 0; + for (auto &neighbour : neighbours) { + uint64_t speed = neighbour.second.speed(); + speeds.emplace_back(speed); + total_speed += speed; + } + std::random_device rd; std::mt19937 mt(rd()); - std::uniform_int_distribution dist(0, neighbours.size() - 1); + std::uniform_int_distribution dist(0, total_speed); + uint64_t index = 0; + int64_t rd_speed = dist(mt); + for (uint64_t i = 0; i < speeds.size(); ++i) { + rd_speed -= speeds[i]; + if (rd_speed <= 0) { + neighbours[i].first->task_table().Put(task); + return; + } + } - neighbours[dist(mt)]->task_table().Put(task); } else { //TODO: process } diff --git a/cpp/src/scheduler/resource/Connection.h b/cpp/src/scheduler/resource/Connection.h index 0f1088e7fe..83c9cc529c 100644 --- a/cpp/src/scheduler/resource/Connection.h +++ b/cpp/src/scheduler/resource/Connection.h @@ -19,15 +19,20 @@ public: : name_(std::move(name)), speed_(speed) {} const std::string & - get_name() const { + name() const { return name_; } - const double - get_speed() const { + uint64_t + speed() const { return speed_; } + uint64_t + transport_cost() { + return 1024 / speed_; + } + public: std::string Dump() const { @@ -38,7 +43,7 @@ public: private: std::string name_; - double speed_; + uint64_t speed_; }; diff --git a/cpp/src/scheduler/resource/Resource.cpp b/cpp/src/scheduler/resource/Resource.cpp index 75c3b4d784..b4a6cb5b66 100644 --- a/cpp/src/scheduler/resource/Resource.cpp +++ b/cpp/src/scheduler/resource/Resource.cpp @@ -4,6 +4,7 @@ * Proprietary and confidential. ******************************************************************************/ #include +#include "../Utils.h" #include "Resource.h" @@ -80,7 +81,7 @@ void Resource::WakeupExecutor() { } TaskTableItemPtr Resource::pick_task_load() { - auto indexes = PickToLoad(task_table_, 10); + auto indexes = task_table_.PickToLoad(10); for (auto index : indexes) { // try to set one task loading, then return if (task_table_.Load(index)) @@ -91,7 +92,7 @@ TaskTableItemPtr Resource::pick_task_load() { } TaskTableItemPtr Resource::pick_task_execute() { - auto indexes = PickToExecute(task_table_, 3); + auto indexes = task_table_.PickToExecute(3); for (auto index : indexes) { // try to set one task executing, then return if (task_table_.Execute(index)) @@ -138,7 +139,13 @@ void Resource::executor_function() { if (task_item == nullptr) { break; } + + auto start = get_current_timestamp(); Process(task_item->task); + auto finish = get_current_timestamp(); + ++total_task_; + total_cost_ += finish - start; + task_item->Executed(); if (subscriber_) { auto event = std::make_shared(shared_from_this(), task_item); diff --git a/cpp/src/scheduler/resource/Resource.h b/cpp/src/scheduler/resource/Resource.h index dcd5fb5d8f..9169a67cf9 100644 --- a/cpp/src/scheduler/resource/Resource.h +++ b/cpp/src/scheduler/resource/Resource.h @@ -19,7 +19,6 @@ #include "../event/TaskTableUpdatedEvent.h" #include "../TaskTable.h" #include "../task/Task.h" -#include "../Cost.h" #include "Connection.h" #include "Node.h" #include "RegisterHandler.h" @@ -44,7 +43,7 @@ enum class RegisterType { }; class Resource : public Node, public std::enable_shared_from_this { -public: + public: /* * Start loader and executor if enable; */ @@ -69,7 +68,7 @@ public: void WakeupExecutor(); -public: + public: template void Register_T(const RegisterType &type) { register_table_.emplace(type, [] { return std::make_shared(); }); @@ -110,6 +109,27 @@ public: return enable_executor_; } + // TODO: const + uint64_t + NumOfTaskToExec() { + uint64_t count = 0; + for (auto &task : task_table_) { + if (task->state == TaskTableItemState::LOADED) ++count; + } + return count; + } + + // TODO: need double ? + inline uint64_t + TaskAvgCost() const { + return total_cost_ / total_task_; + } + + inline uint64_t + TotalTasks() const { + return total_task_; + } + TaskTable & task_table(); @@ -120,7 +140,7 @@ public: friend std::ostream &operator<<(std::ostream &out, const Resource &resource); -protected: + protected: Resource(std::string name, ResourceType type, uint64_t device_id, @@ -142,7 +162,7 @@ protected: virtual void Process(TaskPtr task) = 0; -private: + private: /* * These function should move to cost.h ??? * COST.H ??? @@ -162,7 +182,7 @@ private: TaskTableItemPtr pick_task_execute(); -private: + private: /* * Only called by load thread; */ @@ -175,14 +195,17 @@ private: void executor_function(); -protected: + protected: uint64_t device_id_; std::string name_; -private: + private: ResourceType type_; TaskTable task_table_; + uint64_t total_cost_ = 0; + uint64_t total_task_ = 0; + std::map> register_table_; std::function subscriber_ = nullptr; diff --git a/cpp/src/scheduler/task/Path.h b/cpp/src/scheduler/task/Path.h new file mode 100644 index 0000000000..388a7b9c82 --- /dev/null +++ b/cpp/src/scheduler/task/Path.h @@ -0,0 +1,68 @@ +/******************************************************************************* + * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved + * Unauthorized copying of this file, via any medium is strictly prohibited. + * Proprietary and confidential. + ******************************************************************************/ +#pragma once + +#include +#include + + +namespace zilliz { +namespace milvus { +namespace engine { + +class Path { + public: + Path() = default; + + Path(std::vector& path, uint64_t index) : path_(path), index_(index) {} + + void + push_back(const std::string &str) { + path_.push_back(str); + } + + std::vector + Dump() { + return path_; + } + + std::string + Next() { + if (index_ > 0 && !path_.empty()) { + --index_; + return path_[index_]; + } else { + return nullptr; + } + + } + + std::string + Last() { + if (!path_.empty()) { + return path_[0]; + } else { + return nullptr; + } + } + + public: + std::string & + operator[](uint64_t index) { + return path_[index]; + } + + std::vector::iterator begin() { return path_.begin(); } + std::vector::iterator end() { return path_.end(); } + + public: + std::vector path_; + uint64_t index_ = 0; +}; + +} +} +} \ No newline at end of file diff --git a/cpp/src/scheduler/task/Task.h b/cpp/src/scheduler/task/Task.h index 31a1a88404..7431679e13 100644 --- a/cpp/src/scheduler/task/Task.h +++ b/cpp/src/scheduler/task/Task.h @@ -8,6 +8,7 @@ #include "db/scheduler/context/SearchContext.h" #include "db/scheduler/task/IScheduleTask.h" #include "scheduler/tasklabel/TaskLabel.h" +#include "Path.h" #include #include @@ -44,6 +45,14 @@ public: inline TaskType Type() const { return type_; } + /* + * Transport path; + */ + inline Path& + path() { + return task_path_; + } + /* * Getter and Setter; */ @@ -64,6 +73,7 @@ public: Clone() = 0; public: + Path task_path_; std::vector search_contexts_; ScheduleTaskPtr task_; TaskType type_; diff --git a/cpp/src/scheduler/tasklabel/SpecResLabel.h b/cpp/src/scheduler/tasklabel/SpecResLabel.h index 9f69f5752f..51468bf28b 100644 --- a/cpp/src/scheduler/tasklabel/SpecResLabel.h +++ b/cpp/src/scheduler/tasklabel/SpecResLabel.h @@ -22,24 +22,24 @@ namespace engine { class SpecResLabel : public TaskLabel { public: SpecResLabel(const ResourceWPtr &resource) - : TaskLabel(TaskLabelType::SPECIAL_RESOURCE), resource_(resource) {} + : TaskLabel(TaskLabelType::SPECIFIED_RESOURCE), resource_(resource) {} inline ResourceWPtr & - resource() const { + resource() { return resource_; } inline std::string & - resource_name() const { + resource_name() { return resource_name_; } private: ResourceWPtr resource_; std::string resource_name_; -} +}; -using SpecResLabelPtr = std::make_shared; +using SpecResLabelPtr = std::shared_ptr(); } } diff --git a/cpp/src/scheduler/tasklabel/TaskLabel.h b/cpp/src/scheduler/tasklabel/TaskLabel.h index 3f39b8ec12..84fd5ee77b 100644 --- a/cpp/src/scheduler/tasklabel/TaskLabel.h +++ b/cpp/src/scheduler/tasklabel/TaskLabel.h @@ -13,7 +13,7 @@ namespace engine { enum class TaskLabelType { DEFAULT, // means can be executed in any resource - SPECIAL_RESOURCE, // means must executing in special resource + SPECIFIED_RESOURCE, // means must executing in special resource BROADCAST, // means all enable-executor resource must execute task }; diff --git a/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp b/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp index 06b6e45fce..8198d5a232 100644 --- a/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp +++ b/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp @@ -18,175 +18,173 @@ using namespace milvus; //#define SET_VECTOR_IDS; namespace { - std::string GetTableName(); +std::string GetTableName(); - const std::string TABLE_NAME = GetTableName(); - constexpr int64_t TABLE_DIMENSION = 512; - constexpr int64_t TABLE_INDEX_FILE_SIZE = 768; - constexpr int64_t BATCH_ROW_COUNT = 1000000; - constexpr int64_t NQ = 100; - constexpr int64_t TOP_K = 10; - constexpr int64_t SEARCH_TARGET = 5000; //change this value, result is different - constexpr int64_t ADD_VECTOR_LOOP = 1; - constexpr int64_t SECONDS_EACH_HOUR = 3600; +const std::string TABLE_NAME = GetTableName(); +constexpr int64_t TABLE_DIMENSION = 512; +constexpr int64_t TABLE_INDEX_FILE_SIZE = 768; +constexpr int64_t BATCH_ROW_COUNT = 100000; +constexpr int64_t NQ = 100; +constexpr int64_t TOP_K = 10; +constexpr int64_t SEARCH_TARGET = 5000; //change this value, result is different +constexpr int64_t ADD_VECTOR_LOOP = 1; +constexpr int64_t SECONDS_EACH_HOUR = 3600; #define BLOCK_SPLITER std::cout << "===========================================" << std::endl; - void PrintTableSchema(const TableSchema& tb_schema) { - BLOCK_SPLITER - std::cout << "Table name: " << tb_schema.table_name << std::endl; - std::cout << "Table dimension: " << tb_schema.dimension << std::endl; - BLOCK_SPLITER - } +void PrintTableSchema(const TableSchema& tb_schema) { + BLOCK_SPLITER + std::cout << "Table name: " << tb_schema.table_name << std::endl; + std::cout << "Table dimension: " << tb_schema.dimension << std::endl; + BLOCK_SPLITER +} - void PrintSearchResult(const std::vector>& search_record_array, - const std::vector& topk_query_result_array) { - BLOCK_SPLITER - std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl; +void PrintSearchResult(const std::vector>& search_record_array, + const std::vector& topk_query_result_array) { + BLOCK_SPLITER + std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl; - int32_t index = 0; - for(auto& result : topk_query_result_array) { - auto search_id = search_record_array[index].first; - index++; - std::cout << "No." << std::to_string(index) << " vector " << std::to_string(search_id) - << " top " << std::to_string(result.query_result_arrays.size()) - << " search result:" << std::endl; - for(auto& item : result.query_result_arrays) { - std::cout << "\t" << std::to_string(item.id) << "\tdistance:" << std::to_string(item.distance); - std::cout << std::endl; - } - } - - BLOCK_SPLITER - } - - std::string CurrentTime() { - time_t tt; - time( &tt ); - tt = tt + 8*SECONDS_EACH_HOUR; - tm* t= gmtime( &tt ); - - std::string str = std::to_string(t->tm_year + 1900) + "_" + std::to_string(t->tm_mon + 1) - + "_" + std::to_string(t->tm_mday) + "_" + std::to_string(t->tm_hour) - + "_" + std::to_string(t->tm_min) + "_" + std::to_string(t->tm_sec); - - return str; - } - - std::string CurrentTmDate(int64_t offset_day = 0) { - time_t tt; - time( &tt ); - tt = tt + 8*SECONDS_EACH_HOUR; - tt = tt + 24*SECONDS_EACH_HOUR*offset_day; - tm* t= gmtime( &tt ); - - std::string str = std::to_string(t->tm_year + 1900) + "-" + std::to_string(t->tm_mon + 1) - + "-" + std::to_string(t->tm_mday); - - return str; - } - - std::string GetTableName() { - static std::string s_id(CurrentTime()); - return "tbl_" + s_id; - } - - TableSchema BuildTableSchema() { - TableSchema tb_schema; - tb_schema.table_name = TABLE_NAME; - tb_schema.dimension = TABLE_DIMENSION; - tb_schema.index_file_size = TABLE_INDEX_FILE_SIZE; - - return tb_schema; - } - - void BuildVectors(int64_t from, int64_t to, - std::vector& vector_record_array) { - if(to <= from){ - return; - } - - vector_record_array.clear(); - for (int64_t k = from; k < to; k++) { - RowRecord record; - record.data.resize(TABLE_DIMENSION); - for(int64_t i = 0; i < TABLE_DIMENSION; i++) { - record.data[i] = (float)(k%(i+1)); - } - - vector_record_array.emplace_back(record); + int32_t index = 0; + for(auto& result : topk_query_result_array) { + auto search_id = search_record_array[index].first; + index++; + std::cout << "No." << std::to_string(index) << " vector " << std::to_string(search_id) + << " top " << std::to_string(result.query_result_arrays.size()) + << " search result:" << std::endl; + for(auto& item : result.query_result_arrays) { + std::cout << "\t" << std::to_string(item.id) << "\tdistance:" << std::to_string(item.distance); + std::cout << std::endl; } } - void Sleep(int seconds) { - std::cout << "Waiting " << seconds << " seconds ..." << std::endl; - sleep(seconds); + BLOCK_SPLITER +} + +std::string CurrentTime() { + time_t tt; + time( &tt ); + tt = tt + 8*SECONDS_EACH_HOUR; + tm* t= gmtime( &tt ); + + std::string str = std::to_string(t->tm_year + 1900) + "_" + std::to_string(t->tm_mon + 1) + + "_" + std::to_string(t->tm_mday) + "_" + std::to_string(t->tm_hour) + + "_" + std::to_string(t->tm_min) + "_" + std::to_string(t->tm_sec); + + return str; +} + +std::string CurrentTmDate(int64_t offset_day = 0) { + time_t tt; + time( &tt ); + tt = tt + 8*SECONDS_EACH_HOUR; + tt = tt + 24*SECONDS_EACH_HOUR*offset_day; + tm* t= gmtime( &tt ); + + std::string str = std::to_string(t->tm_year + 1900) + "-" + std::to_string(t->tm_mon + 1) + + "-" + std::to_string(t->tm_mday); + + return str; +} + +std::string GetTableName() { + static std::string s_id(CurrentTime()); + return "tbl_" + s_id; +} + +TableSchema BuildTableSchema() { + TableSchema tb_schema; + tb_schema.table_name = TABLE_NAME; + tb_schema.dimension = TABLE_DIMENSION; + tb_schema.index_file_size = TABLE_INDEX_FILE_SIZE; + + return tb_schema; +} + +void BuildVectors(int64_t from, int64_t to, + std::vector& vector_record_array) { + if(to <= from){ + return; } - class TimeRecorder { - public: - explicit TimeRecorder(const std::string& title) - : title_(title) { - start_ = std::chrono::system_clock::now(); + vector_record_array.clear(); + for (int64_t k = from; k < to; k++) { + RowRecord record; + record.data.resize(TABLE_DIMENSION); + for(int64_t i = 0; i < TABLE_DIMENSION; i++) { + record.data[i] = (float)(k%(i+1)); } - ~TimeRecorder() { - std::chrono::system_clock::time_point end = std::chrono::system_clock::now(); - long span = (std::chrono::duration_cast (end - start_)).count(); - std::cout << title_ << " totally cost: " << span << " ms" << std::endl; - } + vector_record_array.emplace_back(record); + } +} - private: - std::string title_; - std::chrono::system_clock::time_point start_; - }; +void Sleep(int seconds) { + std::cout << "Waiting " << seconds << " seconds ..." << std::endl; + sleep(seconds); +} - void CheckResult(const std::vector>& search_record_array, - const std::vector& topk_query_result_array) { - BLOCK_SPLITER - int64_t index = 0; - for(auto& result : topk_query_result_array) { - auto result_id = result.query_result_arrays[0].id; - auto search_id = search_record_array[index++].first; - if(result_id != search_id) { - std::cout << "The top 1 result is wrong: " << result_id - << " vs. " << search_id << std::endl; - } else { - std::cout << "Check result sucessfully" << std::endl; - } - } - BLOCK_SPLITER +class TimeRecorder { + public: + explicit TimeRecorder(const std::string& title) + : title_(title) { + start_ = std::chrono::system_clock::now(); } - void DoSearch(std::shared_ptr conn, - const std::vector>& search_record_array, - const std::string& phase_name) { - std::vector query_range_array; - Range rg; - rg.start_value = CurrentTmDate(); - rg.end_value = CurrentTmDate(1); - query_range_array.emplace_back(rg); - - std::vector record_array; - for(auto& pair : search_record_array) { - record_array.push_back(pair.second); - } - - auto start = std::chrono::high_resolution_clock::now(); - for (auto i = 0; i < 5; ++i) { - std::vector topk_query_result_array; - { - TimeRecorder rc(phase_name); - Status stat = conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 32, topk_query_result_array); - std::cout << "SearchVector function call status: " << stat.ToString() << std::endl; - } - } - auto finish = std::chrono::high_resolution_clock::now(); - std::cout << "SEARCHVECTOR COST: " << std::chrono::duration_cast>(finish - start).count() << "s\n"; - -// PrintSearchResult(search_record_array, topk_query_result_array); -// CheckResult(search_record_array, topk_query_result_array); + ~TimeRecorder() { + std::chrono::system_clock::time_point end = std::chrono::system_clock::now(); + long span = (std::chrono::duration_cast (end - start_)).count(); + std::cout << title_ << " totally cost: " << span << " ms" << std::endl; } + + private: + std::string title_; + std::chrono::system_clock::time_point start_; +}; + +void CheckResult(const std::vector>& search_record_array, + const std::vector& topk_query_result_array) { + BLOCK_SPLITER + int64_t index = 0; + for(auto& result : topk_query_result_array) { + auto result_id = result.query_result_arrays[0].id; + auto search_id = search_record_array[index++].first; + if(result_id != search_id) { + std::cout << "The top 1 result is wrong: " << result_id + << " vs. " << search_id << std::endl; + } else { + std::cout << "Check result sucessfully" << std::endl; + } + } + BLOCK_SPLITER +} + +void DoSearch(std::shared_ptr conn, + const std::vector>& search_record_array, + const std::string& phase_name) { + std::vector query_range_array; + Range rg; + rg.start_value = CurrentTmDate(); + rg.end_value = CurrentTmDate(1); + query_range_array.emplace_back(rg); + + std::vector record_array; + for(auto& pair : search_record_array) { + record_array.push_back(pair.second); + } + + auto start = std::chrono::high_resolution_clock::now(); + std::vector topk_query_result_array; + { + TimeRecorder rc(phase_name); + Status stat = conn->Search(TABLE_NAME, record_array, query_range_array, TOP_K, 32, topk_query_result_array); + std::cout << "SearchVector function call status: " << stat.ToString() << std::endl; + } + auto finish = std::chrono::high_resolution_clock::now(); + std::cout << "SEARCHVECTOR COST: " << std::chrono::duration_cast>(finish - start).count() << "s\n"; + + PrintSearchResult(search_record_array, topk_query_result_array); + CheckResult(search_record_array, topk_query_result_array); +} } void @@ -216,9 +214,9 @@ ClientTest::Test(const std::string& address, const std::string& port) { std::cout << "All tables: " << std::endl; for(auto& table : tables) { int64_t row_count = 0; -// conn->DropTable(table); - stat = conn->CountTable(table, row_count); - std::cout << "\t" << table << "(" << row_count << " rows)" << std::endl; + conn->DropTable(table); +// stat = conn->CountTable(table, row_count); +// std::cout << "\t" << table << "(" << row_count << " rows)" << std::endl; } } @@ -273,7 +271,7 @@ ClientTest::Test(const std::string& address, const std::string& port) { if(search_record_array.size() < NQ) { search_record_array.push_back( - std::make_pair(record_ids[SEARCH_TARGET], record_array[SEARCH_TARGET])); + std::make_pair(record_ids[SEARCH_TARGET], record_array[SEARCH_TARGET])); } } } @@ -345,4 +343,4 @@ ClientTest::Test(const std::string& address, const std::string& port) { std::string status = conn->ServerStatus(); std::cout << "Server status after disconnect: " << status << std::endl; } -} \ No newline at end of file +} diff --git a/cpp/src/server/ServerConfig.h b/cpp/src/server/ServerConfig.h index f8f6deea98..f60617ba26 100644 --- a/cpp/src/server/ServerConfig.h +++ b/cpp/src/server/ServerConfig.h @@ -45,9 +45,9 @@ static const char* CONFIG_METRIC_COLLECTOR = "collector"; static const char* CONFIG_PROMETHEUS = "prometheus_config"; static const char* CONFIG_METRIC_PROMETHEUS_PORT = "port"; -static const std::string CONFIG_ENGINE = "engine_config"; -static const std::string CONFIG_DCBT = "use_blas_threshold"; -static const std::string CONFIG_OMP_THREAD_NUM = "omp_thread_num"; +static const char* CONFIG_ENGINE = "engine_config"; +static const char* CONFIG_DCBT = "use_blas_threshold"; +static const char* CONFIG_OMP_THREAD_NUM = "omp_thread_num"; static const char* CONFIG_RESOURCE = "resource_config"; static const char* CONFIG_RESOURCES = "resources"; @@ -57,6 +57,8 @@ static const char* CONFIG_RESOURCE_DEVICE_ID = "device_id"; static const char* CONFIG_RESOURCE_ENABLE_LOADER = "enable_loader"; static const char* CONFIG_RESOURCE_ENABLE_EXECUTOR = "enable_executor"; static const char* CONFIG_RESOURCE_CONNECTIONS = "connections"; +static const char* CONFIG_SPEED_CONNECTIONS = "speed"; +static const char* CONFIG_ENDPOINT_CONNECTIONS = "endpoint"; class ServerConfig { diff --git a/cpp/src/server/grpc_impl/GrpcMilvusServer.cpp b/cpp/src/server/grpc_impl/GrpcMilvusServer.cpp index 737f3dab95..935f53b8a8 100644 --- a/cpp/src/server/grpc_impl/GrpcMilvusServer.cpp +++ b/cpp/src/server/grpc_impl/GrpcMilvusServer.cpp @@ -34,6 +34,17 @@ static std::unique_ptr<::grpc::Server> server; constexpr long MESSAGE_SIZE = -1; +class NoReusePortOption : public ::grpc::ServerBuilderOption { + public: + void UpdateArguments(::grpc::ChannelArguments* args) override { + args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0); + } + + void UpdatePlugins(std::vector>* + plugins) override {} +}; + + void GrpcMilvusServer::StartService() { if (server != nullptr) { @@ -52,6 +63,7 @@ GrpcMilvusServer::StartService() { std::string server_address(address + ":" + std::to_string(port)); ::grpc::ServerBuilder builder; + builder.SetOption(std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption)); builder.SetMaxReceiveMessageSize(MESSAGE_SIZE); //default 4 * 1024 * 1024 builder.SetMaxSendMessageSize(MESSAGE_SIZE); diff --git a/cpp/src/wrapper/knowhere/vec_impl.cpp b/cpp/src/wrapper/knowhere/vec_impl.cpp index 0b9855c639..0989178783 100644 --- a/cpp/src/wrapper/knowhere/vec_impl.cpp +++ b/cpp/src/wrapper/knowhere/vec_impl.cpp @@ -114,6 +114,7 @@ server::KnowhereError VecIndexImpl::Search(const long &nq, const float *xq, floa } zilliz::knowhere::BinarySet VecIndexImpl::Serialize() { + type = ConvertToCpuIndexType(type); return index_->Serialize(); } @@ -136,26 +137,23 @@ IndexType VecIndexImpl::GetType() { } VecIndexPtr VecIndexImpl::CopyToGpu(const int64_t &device_id, const Config &cfg) { - //if (auto new_type = GetGpuIndexType(type)) { - // auto device_index = index_->CopyToGpu(device_id); - // return std::make_shared(device_index, new_type); - //} - //return nullptr; - - // TODO(linxj): update type + // TODO(linxj): exception handle auto gpu_index = zilliz::knowhere::CopyCpuToGpu(index_, device_id, cfg); - auto new_index = std::make_shared(gpu_index, type); + auto new_index = std::make_shared(gpu_index, ConvertToGpuIndexType(type)); new_index->dim = dim; return new_index; } -// TODO(linxj): rename copytocpu => copygputocpu VecIndexPtr VecIndexImpl::CopyToCpu(const Config &cfg) { + // TODO(linxj): exception handle auto cpu_index = zilliz::knowhere::CopyGpuToCpu(index_, cfg); - return std::make_shared(cpu_index, type); + auto new_index = std::make_shared(cpu_index, ConvertToCpuIndexType(type)); + new_index->dim = dim; + return new_index; } VecIndexPtr VecIndexImpl::Clone() { + // TODO(linxj): exception handle auto clone_index = std::make_shared(index_->Clone(), type); clone_index->dim = dim; return clone_index; @@ -165,10 +163,8 @@ int64_t VecIndexImpl::GetDeviceId() { if (auto device_idx = std::dynamic_pointer_cast(index_)){ return device_idx->GetGpuDevice(); } - else { - return -1; // -1 == cpu - } - return 0; + // else + return -1; // -1 == cpu } float *BFIndex::GetRawVectors() { @@ -243,9 +239,10 @@ server::KnowhereError IVFMixIndex::BuildAll(const long &nb, if (auto device_index = std::dynamic_pointer_cast(index_)) { auto host_index = device_index->CopyGpuToCpu(Config()); index_ = host_index; - type = TransferToCpuIndexType(type); + type = ConvertToCpuIndexType(type); } else { WRAPPER_LOG_ERROR << "Build IVFMIXIndex Failed"; + return server::KNOWHERE_ERROR; } } catch (KnowhereException &e) { WRAPPER_LOG_ERROR << e.what(); @@ -261,7 +258,7 @@ server::KnowhereError IVFMixIndex::BuildAll(const long &nb, } server::KnowhereError IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) { - index_ = std::make_shared(); + //index_ = std::make_shared(); index_->Load(index_binary); dim = Dimension(); return server::KNOWHERE_SUCCESS; diff --git a/cpp/src/wrapper/knowhere/vec_index.cpp b/cpp/src/wrapper/knowhere/vec_index.cpp index 9ac0d8b3ad..0665ffc166 100644 --- a/cpp/src/wrapper/knowhere/vec_index.cpp +++ b/cpp/src/wrapper/knowhere/vec_index.cpp @@ -71,8 +71,9 @@ size_t FileIOWriter::operator()(void *ptr, size_t size) { } -VecIndexPtr GetVecIndexFactory(const IndexType &type) { +VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config& cfg) { std::shared_ptr index; + auto gpu_device = cfg.get_with_default("gpu_id", 0); switch (type) { case IndexType::FAISS_IDMAP: { index = std::make_shared(); @@ -83,7 +84,8 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) { break; } case IndexType::FAISS_IVFFLAT_GPU: { - index = std::make_shared(0); + // TODO(linxj): 规范化参数 + index = std::make_shared(gpu_device); break; } case IndexType::FAISS_IVFFLAT_MIX: { @@ -95,7 +97,7 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) { break; } case IndexType::FAISS_IVFPQ_GPU: { - index = std::make_shared(0); + index = std::make_shared(gpu_device); break; } case IndexType::SPTAG_KDT_RNT_CPU: { @@ -103,15 +105,19 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) { break; } case IndexType::FAISS_IVFSQ8_MIX: { - index = std::make_shared(0); + index = std::make_shared(gpu_device); return std::make_shared(index, IndexType::FAISS_IVFSQ8_MIX); } - case IndexType::FAISS_IVFSQ8: { + case IndexType::FAISS_IVFSQ8_CPU: { index = std::make_shared(); break; } + case IndexType::FAISS_IVFSQ8_GPU: { + index = std::make_shared(gpu_device); + break; + } case IndexType::NSG_MIX: { // TODO(linxj): bug. - index = std::make_shared(0); + index = std::make_shared(gpu_device); break; } default: { @@ -229,20 +235,40 @@ void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Co } } -IndexType TransferToCpuIndexType(const IndexType &type) { +IndexType ConvertToCpuIndexType(const IndexType &type) { + // TODO(linxj): add IDMAP switch (type) { + case IndexType::FAISS_IVFFLAT_GPU: case IndexType::FAISS_IVFFLAT_MIX: { return IndexType::FAISS_IVFFLAT_CPU; } + case IndexType::FAISS_IVFSQ8_GPU: case IndexType::FAISS_IVFSQ8_MIX: { - return IndexType::FAISS_IVFSQ8; + return IndexType::FAISS_IVFSQ8_CPU; } default: { - return IndexType::INVALID; + return type; } } } +IndexType ConvertToGpuIndexType(const IndexType &type) { + switch (type) { + case IndexType::FAISS_IVFFLAT_MIX: + case IndexType::FAISS_IVFFLAT_CPU: { + return IndexType::FAISS_IVFFLAT_GPU; + } + case IndexType::FAISS_IVFSQ8_MIX: + case IndexType::FAISS_IVFSQ8_CPU: { + return IndexType::FAISS_IVFSQ8_GPU; + } + default: { + return type; + } + } +} + + } } } diff --git a/cpp/src/wrapper/knowhere/vec_index.h b/cpp/src/wrapper/knowhere/vec_index.h index 1c45ce89fc..c69106159a 100644 --- a/cpp/src/wrapper/knowhere/vec_index.h +++ b/cpp/src/wrapper/knowhere/vec_index.h @@ -32,7 +32,8 @@ enum class IndexType { FAISS_IVFPQ_GPU, SPTAG_KDT_RNT_CPU, FAISS_IVFSQ8_MIX, - FAISS_IVFSQ8, + FAISS_IVFSQ8_CPU, + FAISS_IVFSQ8_GPU, NSG_MIX, }; @@ -83,13 +84,14 @@ extern server::KnowhereError write_index(VecIndexPtr index, const std::string &l extern VecIndexPtr read_index(const std::string &location); -extern VecIndexPtr GetVecIndexFactory(const IndexType &type); +extern VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config& cfg = Config()); extern VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary); extern void AutoGenParams(const IndexType& type, const long& size, Config& cfg); -extern IndexType TransferToCpuIndexType(const IndexType& type); +extern IndexType ConvertToCpuIndexType(const IndexType& type); +extern IndexType ConvertToGpuIndexType(const IndexType& type); } } diff --git a/cpp/unittest/db/db_tests.cpp b/cpp/unittest/db/db_tests.cpp index f37896df33..acf69a9aa1 100644 --- a/cpp/unittest/db/db_tests.cpp +++ b/cpp/unittest/db/db_tests.cpp @@ -141,7 +141,6 @@ TEST_F(DBTest, CONFIG_TEST) { TEST_F(DBTest, DB_TEST) { - db_->Open(GetOptions(), &db_); engine::meta::TableSchema table_info = BuildTableSchema(); engine::Status stat = db_->CreateTable(table_info); diff --git a/cpp/unittest/db/mysql_db_test.cpp b/cpp/unittest/db/mysql_db_test.cpp index 78adf9f0f5..cdad9b2275 100644 --- a/cpp/unittest/db/mysql_db_test.cpp +++ b/cpp/unittest/db/mysql_db_test.cpp @@ -46,11 +46,7 @@ namespace { } -TEST_F(DISABLED_MySQLDBTest, DB_TEST) { - - auto options = GetOptions(); - auto db_ = engine::DBFactory::Build(options); - +TEST_F(MySQLDBTest, DB_TEST) { engine::meta::TableSchema table_info = BuildTableSchema(); engine::Status stat = db_->CreateTable(table_info); @@ -115,6 +111,8 @@ TEST_F(DISABLED_MySQLDBTest, DB_TEST) { ASSERT_TRUE(count >= prev_count); std::this_thread::sleep_for(std::chrono::seconds(3)); } + + std::cout << "Search AAA done" << std::endl; }); int loop = INSERT_LOOP; @@ -131,18 +129,9 @@ TEST_F(DISABLED_MySQLDBTest, DB_TEST) { } search.join(); - - delete db_; - - auto dummyDB = engine::DBFactory::Build(options); - dummyDB->DropAll(); - delete dummyDB; }; -TEST_F(DISABLED_MySQLDBTest, SEARCH_TEST) { - auto options = GetOptions(); - auto db_ = engine::DBFactory::Build(options); - +TEST_F(MySQLDBTest, SEARCH_TEST) { engine::meta::TableSchema table_info = BuildTableSchema(); engine::Status stat = db_->CreateTable(table_info); @@ -192,22 +181,9 @@ TEST_F(DISABLED_MySQLDBTest, SEARCH_TEST) { engine::QueryResults results; stat = db_->Query(TABLE_NAME, k, nq, 10, xq.data(), results); ASSERT_STATS(stat); - - delete db_; - - auto dummyDB = engine::DBFactory::Build(options); - dummyDB->DropAll(); - delete dummyDB; - - // TODO(linxj): add groundTruth assert }; -TEST_F(DISABLED_MySQLDBTest, ARHIVE_DISK_CHECK) { - - auto options = GetOptions(); - options.meta.archive_conf = engine::ArchiveConf("delete", "disk:1"); - auto db_ = engine::DBFactory::Build(options); - +TEST_F(MySQLDBTest, ARHIVE_DISK_CHECK) { engine::meta::TableSchema table_info = BuildTableSchema(); engine::Status stat = db_->CreateTable(table_info); @@ -250,20 +226,9 @@ TEST_F(DISABLED_MySQLDBTest, ARHIVE_DISK_CHECK) { db_->Size(size); LOG(DEBUG) << "size=" << size; ASSERT_LE(size, 1 * engine::meta::G); - - delete db_; - - auto dummyDB = engine::DBFactory::Build(options); - dummyDB->DropAll(); - delete dummyDB; }; -TEST_F(DISABLED_MySQLDBTest, DELETE_TEST) { - - auto options = GetOptions(); - options.meta.archive_conf = engine::ArchiveConf("delete", "disk:1"); - auto db_ = engine::DBFactory::Build(options); - +TEST_F(MySQLDBTest, DELETE_TEST) { engine::meta::TableSchema table_info = BuildTableSchema(); engine::Status stat = db_->CreateTable(table_info); // std::cout << stat.ToString() << std::endl; @@ -301,10 +266,4 @@ TEST_F(DISABLED_MySQLDBTest, DELETE_TEST) { db_->HasTable(TABLE_NAME, has_table); ASSERT_FALSE(has_table); - - delete db_; - - auto dummyDB = engine::DBFactory::Build(options); - dummyDB->DropAll(); - delete dummyDB; }; diff --git a/cpp/unittest/db/mysql_meta_test.cpp b/cpp/unittest/db/mysql_meta_test.cpp index 2ad842a223..4960b20309 100644 --- a/cpp/unittest/db/mysql_meta_test.cpp +++ b/cpp/unittest/db/mysql_meta_test.cpp @@ -21,7 +21,7 @@ using namespace zilliz::milvus::engine; -TEST_F(DISABLED_MySQLTest, TABLE_TEST) { +TEST_F(MySQLTest, TABLE_TEST) { DBMetaOptions options; try { options = getDBMetaOptions(); @@ -53,7 +53,7 @@ TEST_F(DISABLED_MySQLTest, TABLE_TEST) { table.table_id_ = table_id; status = impl.CreateTable(table); - ASSERT_TRUE(status.ok()); + ASSERT_TRUE(status.IsAlreadyExist()); table.table_id_ = ""; status = impl.CreateTable(table); @@ -63,7 +63,7 @@ TEST_F(DISABLED_MySQLTest, TABLE_TEST) { ASSERT_TRUE(status.ok()); } -TEST_F(DISABLED_MySQLTest, TABLE_FILE_TEST) { +TEST_F(MySQLTest, TABLE_FILE_TEST) { DBMetaOptions options; try { options = getDBMetaOptions(); @@ -92,7 +92,7 @@ TEST_F(DISABLED_MySQLTest, TABLE_FILE_TEST) { meta::DatesT dates; dates.push_back(utils::GetDate()); status = impl.DropPartitionsByDates(table_file.table_id_, dates); - ASSERT_FALSE(status.ok()); + ASSERT_TRUE(status.ok()); uint64_t cnt = 0; status = impl.Count(table_id, cnt); @@ -139,7 +139,7 @@ TEST_F(DISABLED_MySQLTest, TABLE_FILE_TEST) { ASSERT_TRUE(status.ok()); } -TEST_F(DISABLED_MySQLTest, ARCHIVE_TEST_DAYS) { +TEST_F(MySQLTest, ARCHIVE_TEST_DAYS) { srand(time(0)); DBMetaOptions options; try { @@ -211,7 +211,7 @@ TEST_F(DISABLED_MySQLTest, ARCHIVE_TEST_DAYS) { ASSERT_TRUE(status.ok()); } -TEST_F(DISABLED_MySQLTest, ARCHIVE_TEST_DISK) { +TEST_F(MySQLTest, ARCHIVE_TEST_DISK) { DBMetaOptions options; try { options = getDBMetaOptions(); @@ -269,7 +269,7 @@ TEST_F(DISABLED_MySQLTest, ARCHIVE_TEST_DISK) { ASSERT_TRUE(status.ok()); } -TEST_F(DISABLED_MySQLTest, TABLE_FILES_TEST) { +TEST_F(MySQLTest, TABLE_FILES_TEST) { DBMetaOptions options; try { options = getDBMetaOptions(); diff --git a/cpp/unittest/db/utils.cpp b/cpp/unittest/db/utils.cpp index 9654245e40..bc9bf7dbb3 100644 --- a/cpp/unittest/db/utils.cpp +++ b/cpp/unittest/db/utils.cpp @@ -12,6 +12,7 @@ #include "utils.h" #include "db/Factories.h" #include "db/Options.h" +#include "server/ServerConfig.h" INITIALIZE_EASYLOGGINGPP @@ -60,6 +61,9 @@ engine::Options DBTest::GetOptions() { void DBTest::SetUp() { InitLog(); + server::ConfigNode& config = server::ServerConfig::GetInstance().GetConfig(server::CONFIG_CACHE); + config.AddSequenceItem(server::CONFIG_GPU_IDS, "0"); + auto res_mgr = engine::ResMgrInst::GetInstance(); res_mgr->Clear(); res_mgr->Add(engine::ResourceFactory::Create("disk", "DISK", 0, true, false)); @@ -103,7 +107,7 @@ void MetaTest::TearDown() { impl_->DropAll(); } -zilliz::milvus::engine::DBMetaOptions DISABLED_MySQLTest::getDBMetaOptions() { +zilliz::milvus::engine::DBMetaOptions MySQLTest::getDBMetaOptions() { // std::string path = "/tmp/milvus_test"; // engine::DBMetaOptions options = engine::DBMetaOptionsFactory::Build(path); zilliz::milvus::engine::DBMetaOptions options; @@ -111,17 +115,16 @@ zilliz::milvus::engine::DBMetaOptions DISABLED_MySQLTest::getDBMetaOptions() { options.backend_uri = DBTestEnvironment::getURI(); if(options.backend_uri.empty()) { -// throw std::exception(); options.backend_uri = "mysql://root:Fantast1c@192.168.1.194:3306/"; } return options; } -zilliz::milvus::engine::Options DISABLED_MySQLDBTest::GetOptions() { +zilliz::milvus::engine::Options MySQLDBTest::GetOptions() { auto options = engine::OptionsFactory::Build(); options.meta.path = "/tmp/milvus_test"; - options.meta.backend_uri = DBTestEnvironment::getURI(); + options.meta.backend_uri = "mysql://root:Fantast1c@192.168.1.194:3306/"; return options; } diff --git a/cpp/unittest/db/utils.h b/cpp/unittest/db/utils.h index 83f7abef5a..8432374ada 100644 --- a/cpp/unittest/db/utils.h +++ b/cpp/unittest/db/utils.h @@ -79,13 +79,13 @@ class MetaTest : public DBTest { virtual void TearDown() override; }; -class DISABLED_MySQLTest : public ::testing::Test { +class MySQLTest : public ::testing::Test { protected: // std::shared_ptr impl_; zilliz::milvus::engine::DBMetaOptions getDBMetaOptions(); }; -class DISABLED_MySQLDBTest : public ::testing::Test { +class MySQLDBTest : public DBTest { protected: zilliz::milvus::engine::Options GetOptions(); }; diff --git a/cpp/unittest/knowhere/knowhere_test.cpp b/cpp/unittest/knowhere/knowhere_test.cpp index 0b9e6717e1..e1eccb5c35 100644 --- a/cpp/unittest/knowhere/knowhere_test.cpp +++ b/cpp/unittest/knowhere/knowhere_test.cpp @@ -21,15 +21,15 @@ using ::testing::TestWithParam; using ::testing::Values; using ::testing::Combine; -constexpr int64_t DIM = 512; -constexpr int64_t NB = 1000000; +constexpr int64_t DIM = 128; +constexpr int64_t NB = 100000; +constexpr int64_t DEVICE_ID = 0; class KnowhereWrapperTest : public TestWithParam<::std::tuple> { protected: void SetUp() override { - zilliz::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(0); - zilliz::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(1); + zilliz::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICE_ID); std::string generator_type; std::tie(index_type, generator_type, dim, nb, nq, k, train_cfg, search_cfg) = GetParam(); @@ -90,29 +90,40 @@ class KnowhereWrapperTest INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest, Values( //["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"] - //std::make_tuple(IndexType::FAISS_IVFFLAT_CPU, "Default", - // 64, 100000, 10, 10, - // Config::object{{"nlist", 100}, {"dim", 64}}, - // Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 10}} - //), - //std::make_tuple(IndexType::FAISS_IVFFLAT_GPU, "Default", - // 64, 10000, 10, 10, - // Config::object{{"nlist", 100}, {"dim", 64}}, - // Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 40}} - //), -// std::make_tuple(IndexType::FAISS_IVFFLAT_MIX, "Default", -// 64, 100000, 10, 10, -// Config::object{{"nlist", 1000}, {"dim", 64}, {"metric_type", "L2"}}, -// Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 5}} -// ), -// std::make_tuple(IndexType::FAISS_IDMAP, "Default", -// 64, 100000, 10, 10, -// Config::object{{"dim", 64}, {"metric_type", "L2"}}, -// Config::object{{"dim", 64}, {"k", 10}} -// ), + std::make_tuple(IndexType::FAISS_IVFFLAT_CPU, "Default", + 64, 100000, 10, 10, + Config::object{{"nlist", 100}, {"dim", 64}, {"metric_type", "L2"}}, + Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 10}} + ), + // to_gpu_test Failed + std::make_tuple(IndexType::FAISS_IVFFLAT_GPU, "Default", + DIM, NB, 10, 10, + Config::object{{"nlist", 100}, {"dim", DIM}, {"metric_type", "L2"}, {"gpu_id", DEVICE_ID}}, + Config::object{{"dim", DIM}, {"k", 10}, {"nprobe", 40}} + ), + std::make_tuple(IndexType::FAISS_IVFFLAT_MIX, "Default", + 64, 100000, 10, 10, + Config::object{{"nlist", 1000}, {"dim", 64}, {"metric_type", "L2"}}, + Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 5}} + ), + std::make_tuple(IndexType::FAISS_IDMAP, "Default", + 64, 100000, 10, 10, + Config::object{{"dim", 64}, {"metric_type", "L2"}}, + Config::object{{"dim", 64}, {"k", 10}} + ), + std::make_tuple(IndexType::FAISS_IVFSQ8_CPU, "Default", + DIM, NB, 10, 10, + Config::object{{"dim", DIM}, {"nlist", 1000}, {"nbits", 8}, {"metric_type", "L2"}, {"gpu_id", DEVICE_ID}}, + Config::object{{"dim", DIM}, {"k", 10}, {"nprobe", 5}} + ), + std::make_tuple(IndexType::FAISS_IVFSQ8_GPU, "Default", + DIM, NB, 10, 10, + Config::object{{"dim", DIM}, {"nlist", 1000}, {"nbits", 8}, {"metric_type", "L2"}, {"gpu_id", DEVICE_ID}}, + Config::object{{"dim", DIM}, {"k", 10}, {"nprobe", 5}} + ), std::make_tuple(IndexType::FAISS_IVFSQ8_MIX, "Default", DIM, NB, 10, 10, - Config::object{{"dim", DIM}, {"nlist", 1000}, {"nbits", 8}, {"metric_type", "L2"}}, + Config::object{{"dim", DIM}, {"nlist", 1000}, {"nbits", 8}, {"metric_type", "L2"}, {"gpu_id", DEVICE_ID}}, Config::object{{"dim", DIM}, {"k", 10}, {"nprobe", 5}} ) // std::make_tuple(IndexType::NSG_MIX, "Default", @@ -151,19 +162,30 @@ TEST_P(KnowhereWrapperTest, to_gpu_test) { index_->BuildAll(nb, xb.data(), ids.data(), train_cfg); index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg); AssertResult(res_ids, res_dis); + { - index_->CopyToGpu(1); + auto dev_idx = index_->CopyToGpu(DEVICE_ID); + for (int i = 0; i < 10; ++i) { + dev_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg); + } + AssertResult(res_ids, res_dis); } - std::string file_location = "/tmp/whatever"; - write_index(index_, file_location); - auto new_index = read_index(file_location); + { + std::string file_location = "/tmp/test_gpu_file"; + write_index(index_, file_location); + auto new_index = read_index(file_location); - auto dev_idx = new_index->CopyToGpu(1); - for (int i = 0; i < 10000; ++i) { - dev_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg); + auto dev_idx = new_index->CopyToGpu(DEVICE_ID); + for (int i = 0; i < 10; ++i) { + dev_idx->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg); + } + AssertResult(res_ids, res_dis); } - AssertResult(res_ids, res_dis); +} + +TEST_P(KnowhereWrapperTest, to_cpu_test) { + // dev } TEST_P(KnowhereWrapperTest, serialize) { @@ -194,7 +216,7 @@ TEST_P(KnowhereWrapperTest, serialize) { std::string file_location = "/tmp/whatever"; write_index(index_, file_location); auto new_index = read_index(file_location); - EXPECT_EQ(new_index->GetType(), index_type); + EXPECT_EQ(new_index->GetType(), ConvertToCpuIndexType(index_type)); EXPECT_EQ(new_index->Dimension(), index_->Dimension()); EXPECT_EQ(new_index->Count(), index_->Count()); diff --git a/cpp/unittest/scheduler/algorithm_test.cpp b/cpp/unittest/scheduler/algorithm_test.cpp new file mode 100644 index 0000000000..43e6c4eae2 --- /dev/null +++ b/cpp/unittest/scheduler/algorithm_test.cpp @@ -0,0 +1,99 @@ +/******************************************************************************* + * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved + * Unauthorized copying of this file, via any medium is strictly prohibited. + * Proprietary and confidential. + ******************************************************************************/ + +#include + +#include "scheduler/resource/Resource.h" +#include "scheduler/ResourceMgr.h" +#include "scheduler/resource/CpuResource.h" +#include "scheduler/ResourceFactory.h" +#include "scheduler/Algorithm.h" + +namespace zilliz { +namespace milvus { +namespace engine { + +class AlgorithmTest : public testing::Test { + protected: + void + SetUp() override { + ResourcePtr disk = ResourceFactory::Create("disk", "DISK", 0, true, false); + ResourcePtr cpu0 = ResourceFactory::Create("cpu0", "CPU", 0, true, true); + ResourcePtr cpu1 = ResourceFactory::Create("cpu1", "CPU", 1); + ResourcePtr cpu2 = ResourceFactory::Create("cpu2", "CPU", 2); + ResourcePtr gpu0 = ResourceFactory::Create("gpu0", "GPU", 0); + ResourcePtr gpu1 = ResourceFactory::Create("gpu1", "GPU", 1); + + res_mgr_ = std::make_shared(); + disk_ = res_mgr_->Add(std::move(disk)); + cpu_0_ = res_mgr_->Add(std::move(cpu0)); + cpu_1_ = res_mgr_->Add(std::move(cpu1)); + cpu_2_ = res_mgr_->Add(std::move(cpu2)); + gpu_0_ = res_mgr_->Add(std::move(gpu0)); + gpu_1_ = res_mgr_->Add(std::move(gpu1)); + auto IO = Connection("IO", 5.0); + auto PCIE = Connection("PCIE", 11.0); + res_mgr_->Connect("disk", "cpu0", IO); + res_mgr_->Connect("cpu0", "cpu1", IO); + res_mgr_->Connect("cpu1", "cpu2", IO); + res_mgr_->Connect("cpu0", "cpu2", IO); + res_mgr_->Connect("cpu1", "gpu0", PCIE); + res_mgr_->Connect("cpu2", "gpu1", PCIE); + } + + ResourceWPtr disk_; + ResourceWPtr cpu_0_; + ResourceWPtr cpu_1_; + ResourceWPtr cpu_2_; + ResourceWPtr gpu_0_; + ResourceWPtr gpu_1_; + ResourceMgrPtr res_mgr_; +}; + +TEST_F(AlgorithmTest, ShortestPath_test) { + std::vector sp; + uint64_t cost; + cost = ShortestPath(disk_.lock(), gpu_0_.lock(), res_mgr_, sp); + while (!sp.empty()) { + std::cout << sp[sp.size() - 1] << std::endl; + sp.pop_back(); + } + + std::cout << "************************************\n"; + cost = ShortestPath(cpu_0_.lock(), gpu_0_.lock(), res_mgr_, sp); + while (!sp.empty()) { + std::cout << sp[sp.size() - 1] << std::endl; + sp.pop_back(); + } + + std::cout << "************************************\n"; + cost = ShortestPath(disk_.lock(), disk_.lock(), res_mgr_, sp); + while (!sp.empty()) { + std::cout << sp[sp.size() - 1] << std::endl; + sp.pop_back(); + } + + std::cout << "************************************\n"; + cost = ShortestPath(cpu_0_.lock(), disk_.lock(), res_mgr_, sp); + while (!sp.empty()) { + std::cout << sp[sp.size() - 1] << std::endl; + sp.pop_back(); + } + + std::cout << "************************************\n"; + cost = ShortestPath(cpu_2_.lock(), gpu_0_.lock(), res_mgr_, sp); + while (!sp.empty()) { + std::cout << sp[sp.size() - 1] << std::endl; + sp.pop_back(); + } + + +} + + +} +} +} \ No newline at end of file diff --git a/cpp/unittest/scheduler/cost_test.cpp b/cpp/unittest/scheduler/cost_test.cpp deleted file mode 100644 index 1a625d786e..0000000000 --- a/cpp/unittest/scheduler/cost_test.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include "scheduler/TaskTable.h" -#include "scheduler/Cost.h" -#include -#include "scheduler/task/TestTask.h" - - -using namespace zilliz::milvus::engine; - -class CostTest : public ::testing::Test { -protected: - void - SetUp() override { - TableFileSchemaPtr dummy = nullptr; - for (uint64_t i = 0; i < 8; ++i) { - auto task = std::make_shared(dummy); - table_.Put(task); - } - table_.Get(0)->state = TaskTableItemState::INVALID; - table_.Get(1)->state = TaskTableItemState::START; - table_.Get(2)->state = TaskTableItemState::LOADING; - table_.Get(3)->state = TaskTableItemState::LOADED; - table_.Get(4)->state = TaskTableItemState::EXECUTING; - table_.Get(5)->state = TaskTableItemState::EXECUTED; - table_.Get(6)->state = TaskTableItemState::MOVING; - table_.Get(7)->state = TaskTableItemState::MOVED; - } - - - TaskTable table_; -}; - -TEST_F(CostTest, pick_to_move) { - CacheMgr cache; - auto indexes = PickToMove(table_, cache, 10); - ASSERT_EQ(indexes.size(), 1); - ASSERT_EQ(indexes[0], 3); -} - -TEST_F(CostTest, pick_to_load) { - auto indexes = PickToLoad(table_, 10); - ASSERT_EQ(indexes.size(), 1); - ASSERT_EQ(indexes[0], 1); -} - -TEST_F(CostTest, pick_to_executed) { - auto indexes = PickToExecute(table_, 10); - ASSERT_EQ(indexes.size(), 1); - ASSERT_EQ(indexes[0], 3); -} diff --git a/cpp/unittest/scheduler/resource_test.cpp b/cpp/unittest/scheduler/resource_test.cpp index fd6017fadd..d1e7114ccb 100644 --- a/cpp/unittest/scheduler/resource_test.cpp +++ b/cpp/unittest/scheduler/resource_test.cpp @@ -30,7 +30,7 @@ protected: resources_.push_back(gpu_resource_); auto subscriber = [&](EventPtr event) { - if (event->Type() == EventType::COPY_COMPLETED) { + if (event->Type() == EventType::LOAD_COMPLETED) { std::lock_guard lock(load_mutex_); ++load_count_; cv_.notify_one(); diff --git a/cpp/unittest/scheduler/scheduler_test.cpp b/cpp/unittest/scheduler/scheduler_test.cpp index 5335dc8de6..b7d2ba3be3 100644 --- a/cpp/unittest/scheduler/scheduler_test.cpp +++ b/cpp/unittest/scheduler/scheduler_test.cpp @@ -13,6 +13,7 @@ #include "scheduler/resource/Resource.h" #include "utils/Error.h" #include "wrapper/knowhere/vec_index.h" +#include "scheduler/tasklabel/SpecResLabel.h" namespace zilliz { namespace milvus { @@ -122,9 +123,6 @@ protected: ResourceMgrPtr res_mgr_; std::shared_ptr scheduler_; - uint64_t load_count_ = 0; - std::mutex load_mutex_; - std::condition_variable cv_; }; void @@ -155,6 +153,94 @@ TEST_F(SchedulerTest, OnCopyCompleted) { sleep(3); ASSERT_EQ(res_mgr_->GetResource(ResourceType::GPU, 1)->task_table().Size(), NUM); + +} + +TEST_F(SchedulerTest, PushTaskToNeighbourRandomlyTest) { + const uint64_t NUM = 10; + std::vector> tasks; + TableFileSchemaPtr dummy1 = std::make_shared(); + dummy1->location_ = "location"; + + tasks.clear(); + + for (uint64_t i = 0; i < NUM; ++i) { + auto task = std::make_shared(dummy1); + task->label() = std::make_shared(); + tasks.push_back(task); + cpu_resource_.lock()->task_table().Put(task); + } + + sleep(3); +// ASSERT_EQ(res_mgr_->GetResource(ResourceType::GPU, 1)->task_table().Size(), NUM); +} + +class SchedulerTest2 : public testing::Test { + protected: + void + SetUp() override { + ResourcePtr disk = ResourceFactory::Create("disk", "DISK", 0, true, false); + ResourcePtr cpu0 = ResourceFactory::Create("cpu0", "CPU", 0, true, false); + ResourcePtr cpu1 = ResourceFactory::Create("cpu1", "CPU", 1, true, false); + ResourcePtr cpu2 = ResourceFactory::Create("cpu2", "CPU", 2, true, false); + ResourcePtr gpu0 = ResourceFactory::Create("gpu0", "GPU", 0, true, true); + ResourcePtr gpu1 = ResourceFactory::Create("gpu1", "GPU", 1, true, true); + + res_mgr_ = std::make_shared(); + disk_ = res_mgr_->Add(std::move(disk)); + cpu_0_ = res_mgr_->Add(std::move(cpu0)); + cpu_1_ = res_mgr_->Add(std::move(cpu1)); + cpu_2_ = res_mgr_->Add(std::move(cpu2)); + gpu_0_ = res_mgr_->Add(std::move(gpu0)); + gpu_1_ = res_mgr_->Add(std::move(gpu1)); + auto IO = Connection("IO", 5.0); + auto PCIE1 = Connection("PCIE", 11.0); + auto PCIE2 = Connection("PCIE", 20.0); + res_mgr_->Connect("disk", "cpu0", IO); + res_mgr_->Connect("cpu0", "cpu1", IO); + res_mgr_->Connect("cpu1", "cpu2", IO); + res_mgr_->Connect("cpu0", "cpu2", IO); + res_mgr_->Connect("cpu1", "gpu0", PCIE1); + res_mgr_->Connect("cpu2", "gpu1", PCIE2); + + scheduler_ = std::make_shared(res_mgr_); + + res_mgr_->Start(); + scheduler_->Start(); + } + + void + TearDown() override { + scheduler_->Stop(); + res_mgr_->Stop(); + } + + ResourceWPtr disk_; + ResourceWPtr cpu_0_; + ResourceWPtr cpu_1_; + ResourceWPtr cpu_2_; + ResourceWPtr gpu_0_; + ResourceWPtr gpu_1_; + ResourceMgrPtr res_mgr_; + + std::shared_ptr scheduler_; +}; + + +TEST_F(SchedulerTest2, SpecifiedResourceTest) { + const uint64_t NUM = 10; + std::vector> tasks; + TableFileSchemaPtr dummy = std::make_shared(); + dummy->location_ = "location"; + + for (uint64_t i = 0; i < NUM; ++i) { + std::shared_ptr task = std::make_shared(dummy); + task->label() = std::make_shared(disk_); + tasks.push_back(task); + disk_.lock()->task_table().Put(task); + } + +// ASSERT_EQ(res_mgr_->GetResource(ResourceType::GPU, 1)->task_table().Size(), NUM); } }