From 11cb43ac2ed0fb55ed81e1b7fdf58a6818921296 Mon Sep 17 00:00:00 2001 From: groot Date: Thu, 25 Apr 2019 09:38:27 +0800 Subject: [PATCH 1/2] add scheduler Former-commit-id: aead7396cc627fc680408188d3182fd098b5271d --- cpp/src/server/VecServiceHandler.cpp | 106 ++++-------- cpp/src/server/VecServiceScheduler.cpp | 107 +++++++++++- cpp/src/server/VecServiceScheduler.h | 56 ++++++ cpp/src/server/VecServiceTask.cpp | 229 +++++++++++++++++++++++++ cpp/src/server/VecServiceTask.h | 116 +++++++++++++ 5 files changed, 537 insertions(+), 77 deletions(-) create mode 100644 cpp/src/server/VecServiceTask.cpp create mode 100644 cpp/src/server/VecServiceTask.h diff --git a/cpp/src/server/VecServiceHandler.cpp b/cpp/src/server/VecServiceHandler.cpp index c7892976c2..798f0de0e1 100644 --- a/cpp/src/server/VecServiceHandler.cpp +++ b/cpp/src/server/VecServiceHandler.cpp @@ -4,6 +4,7 @@ * Proprietary and confidential. ******************************************************************************/ #include "VecServiceHandler.h" +#include "VecServiceTask.h" #include "ServerConfig.h" #include "VecIdMapper.h" #include "utils/Log.h" @@ -34,19 +35,11 @@ VecServiceHandler::add_group(const VecGroup &group) { SERVER_LOG_TRACE << "group.id = " << group.id << ", group.dimension = " << group.dimension << ", group.index_type = " << group.index_type; - try { - engine::meta::GroupSchema group_info; - group_info.dimension = (size_t)group.dimension; - group_info.group_id = group.id; - engine::Status stat = db_->add_group(group_info); - if(!stat.ok()) { - SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); - } + BaseTaskPtr task_ptr = AddGroupTask::Create(group.dimension, group.id); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); - SERVER_LOG_INFO << "add_group() finished"; - } catch (std::exception& ex) { - SERVER_LOG_ERROR << ex.what(); - } + SERVER_LOG_INFO << "add_group() finished"; } void @@ -54,21 +47,12 @@ VecServiceHandler::get_group(VecGroup &_return, const std::string &group_id) { SERVER_LOG_INFO << "get_group() called"; SERVER_LOG_TRACE << "group_id = " << group_id; - try { - engine::meta::GroupSchema group_info; - group_info.group_id = group_id; - engine::Status stat = db_->get_group(group_info); - if(!stat.ok()) { - SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); - } else { - _return.id = group_info.group_id; - _return.dimension = (int32_t)group_info.dimension; - } + _return.id = group_id; + BaseTaskPtr task_ptr = GetGroupTask::Create(group_id, _return.dimension); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); - SERVER_LOG_INFO << "get_group() finished"; - } catch (std::exception& ex) { - SERVER_LOG_ERROR << ex.what(); - } + SERVER_LOG_INFO << "get_group() finished"; } void @@ -76,12 +60,11 @@ VecServiceHandler::del_group(const std::string &group_id) { SERVER_LOG_INFO << "del_group() called"; SERVER_LOG_TRACE << "group_id = " << group_id; - try { + BaseTaskPtr task_ptr = DeleteGroupTask::Create(group_id); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); - SERVER_LOG_INFO << "del_group() not implemented"; - } catch (std::exception& ex) { - SERVER_LOG_ERROR << ex.what(); - } + SERVER_LOG_INFO << "del_group() not implemented"; } @@ -90,25 +73,13 @@ VecServiceHandler::add_vector(const std::string &group_id, const VecTensor &tens SERVER_LOG_INFO << "add_vector() called"; SERVER_LOG_TRACE << "group_id = " << group_id << ", vector size = " << tensor.tensor.size(); - try { - engine::IDNumbers vector_ids; - std::vector vec_f(tensor.tensor.begin(), tensor.tensor.end()); - engine::Status stat = db_->add_vectors(group_id, 1, vec_f.data(), vector_ids); - if(!stat.ok()) { - SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); - } else { - if(vector_ids.size() != 1) { - SERVER_LOG_ERROR << "Vector ID not returned"; - } else { - std::string nid = group_id + "_" + std::to_string(vector_ids[0]); - IVecIdMapper::GetInstance()->Put(nid, tensor.uid); - } - } + VecTensorList tensor_list; + tensor_list.tensor_list.push_back(tensor); + BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, tensor_list); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); - SERVER_LOG_INFO << "add_vector() finished"; - } catch (std::exception& ex) { - SERVER_LOG_ERROR << ex.what(); - } + SERVER_LOG_INFO << "add_vector() finished"; } void @@ -118,32 +89,11 @@ VecServiceHandler::add_vector_batch(const std::string &group_id, SERVER_LOG_TRACE << "group_id = " << group_id << ", vector list size = " << tensor_list.tensor_list.size(); - try { - std::vector vec_f; - for(const VecTensor& tensor : tensor_list.tensor_list) { - vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end()); - } + BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, tensor_list); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); - engine::IDNumbers vector_ids; - engine::Status stat = db_->add_vectors(group_id, tensor_list.tensor_list.size(), vec_f.data(), vector_ids); - if(!stat.ok()) { - SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); - } else { - if(vector_ids.size() != tensor_list.tensor_list.size()) { - SERVER_LOG_ERROR << "Vector ID not returned"; - } else { - std::string nid_prefix = group_id + "_"; - for(size_t i = 0; i < vector_ids.size(); i++) { - std::string nid = nid_prefix + std::to_string(vector_ids[i]); - IVecIdMapper::GetInstance()->Put(nid, tensor_list.tensor_list[i].uid); - } - } - } - - SERVER_LOG_INFO << "add_vector_batch() finished"; - } catch (std::exception& ex) { - SERVER_LOG_ERROR << ex.what(); - } + SERVER_LOG_INFO << "add_vector_batch() finished"; } @@ -177,10 +127,12 @@ VecServiceHandler::search_vector(VecSearchResult &_return, } } - SERVER_LOG_INFO << "search_vector() finished"; + } catch (std::exception& ex) { SERVER_LOG_ERROR << ex.what(); } + + SERVER_LOG_INFO << "search_vector() finished"; } void @@ -220,10 +172,12 @@ VecServiceHandler::search_vector_batch(VecSearchResultList &_return, } } - SERVER_LOG_INFO << "search_vector_batch() finished"; + } catch (std::exception& ex) { SERVER_LOG_ERROR << ex.what(); } + + SERVER_LOG_INFO << "search_vector_batch() finished"; } VecServiceHandler::~VecServiceHandler() { diff --git a/cpp/src/server/VecServiceScheduler.cpp b/cpp/src/server/VecServiceScheduler.cpp index b4dd506397..5acadc08a1 100644 --- a/cpp/src/server/VecServiceScheduler.cpp +++ b/cpp/src/server/VecServiceScheduler.cpp @@ -4,17 +4,122 @@ * Proprietary and confidential. ******************************************************************************/ #include "VecServiceScheduler.h" +#include "utils/Log.h" namespace zilliz { namespace vecwise { namespace server { -VecServiceScheduler::VecServiceScheduler() { +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +BaseTask::BaseTask(const std::string& task_group) + : task_group_(task_group), + done_(false), + error_code_(SERVER_SUCCESS) { } +BaseTask::~BaseTask() { + WaitToFinish(); +} + +ServerError BaseTask::Execute() { + error_code_ = OnExecute(); + done_ = true; + return error_code_; +} + +ServerError BaseTask::WaitToFinish() { + std::unique_lock lock(finish_mtx_); + finish_cond_.wait(lock, [this] { return done_; }); + + return error_code_; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +VecServiceScheduler::VecServiceScheduler() +: stopped_(false) { + Start(); +} + VecServiceScheduler::~VecServiceScheduler() { + Stop(); +} +void VecServiceScheduler::Start() { + if(!stopped_) { + return; + } + + stopped_ = false; +} + +void VecServiceScheduler::Stop() { + { + std::lock_guard lock(queue_mtx_); + for(auto iter : task_groups_) { + if(iter.second != nullptr) { + iter.second->Put(nullptr); + } + } + } + + for(auto iter : execute_threads_) { + if(iter == nullptr) + continue; + + iter->join(); + } + stopped_ = true; +} + +ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) { + if(task_ptr == nullptr) { + return SERVER_NULL_POINTER; + } + + return PutTaskToQueue(task_ptr); +} + +namespace { + void TakeTaskToExecute(TaskQueuePtr task_queue) { + if(task_queue == nullptr) { + return; + } + + while(true) { + BaseTaskPtr task = task_queue->Take(); + if (task == nullptr) { + break;//stop the thread + } + + try { + ServerError err = task->Execute(); + if(err != SERVER_SUCCESS) { + SERVER_LOG_ERROR << "Task failed with code: " << err; + } + } catch (std::exception& ex) { + SERVER_LOG_ERROR << "Task failed to execute: " << ex.what(); + } + } + } +} + +ServerError VecServiceScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) { + std::lock_guard lock(queue_mtx_); + + std::string group_name = task_ptr->TaskGroup(); + if(task_groups_.count(group_name) > 0) { + task_groups_[group_name]->Put(task_ptr); + } else { + TaskQueuePtr queue = std::make_shared(); + queue->Put(task_ptr); + task_groups_.insert(std::make_pair(group_name, queue)); + + //start a thread + ThreadPtr thread = std::make_shared(&TakeTaskToExecute, queue); + execute_threads_.push_back(thread); + SERVER_LOG_INFO << "Create new thread for task group: " << group_name; + } } } diff --git a/cpp/src/server/VecServiceScheduler.h b/cpp/src/server/VecServiceScheduler.h index e5819aab23..bbbcd151ac 100644 --- a/cpp/src/server/VecServiceScheduler.h +++ b/cpp/src/server/VecServiceScheduler.h @@ -5,15 +5,71 @@ ******************************************************************************/ #pragma once +#include "utils/BlockingQueue.h" + +#include +#include +#include + namespace zilliz { namespace vecwise { namespace server { +class BaseTask { +protected: + BaseTask(const std::string& task_group); + virtual ~BaseTask(); + +public: + ServerError Execute(); + ServerError WaitToFinish(); + + std::string TaskGroup() const { return task_group_; } + + ServerError ErrorCode() const { return error_code_; } +protected: + virtual ServerError OnExecute() = 0; + +protected: + mutable std::mutex finish_mtx_; + std::condition_variable finish_cond_; + + std::string task_group_; + bool done_; + ServerError error_code_; +}; + +using BaseTaskPtr = std::shared_ptr; +using TaskQueue = BlockingQueue; +using TaskQueuePtr = std::shared_ptr; +using ThreadPtr = std::shared_ptr; + class VecServiceScheduler { public: + static VecServiceScheduler& GetInstance() { + static VecServiceScheduler scheduler; + return scheduler; + } + + void Start(); + void Stop(); + + ServerError ExecuteTask(const BaseTaskPtr& task_ptr); + +protected: VecServiceScheduler(); virtual ~VecServiceScheduler(); + ServerError PutTaskToQueue(const BaseTaskPtr& task_ptr); + +private: + mutable std::mutex queue_mtx_; + + std::map task_groups_; + + std::vector execute_threads_; + + bool stopped_; }; diff --git a/cpp/src/server/VecServiceTask.cpp b/cpp/src/server/VecServiceTask.cpp new file mode 100644 index 0000000000..5e80132fbd --- /dev/null +++ b/cpp/src/server/VecServiceTask.cpp @@ -0,0 +1,229 @@ +/******************************************************************************* + * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved + * Unauthorized copying of this file, via any medium is strictly prohibited. + * Proprietary and confidential. + ******************************************************************************/ +#include "VecServiceTask.h" +#include "ServerConfig.h" +#include "VecIdMapper.h" +#include "utils/CommonUtil.h" +#include "utils/Log.h" +#include "db/DB.h" +#include "db/Env.h" + +namespace zilliz { +namespace vecwise { +namespace server { + +static const std::string NORMAL_TASK_GROUP = "normal"; + +namespace { + class DBWrapper { + public: + DBWrapper() { + zilliz::vecwise::engine::Options opt; + ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_SERVER); + opt.meta.backend_uri = config.GetValue(CONFIG_SERVER_DB_URL); + std::string db_path = config.GetValue(CONFIG_SERVER_DB_PATH); + opt.meta.path = db_path + "/db"; + + CommonUtil::CreateDirectory(opt.meta.path); + + zilliz::vecwise::engine::DB::Open(opt, &db_); + if(db_ == nullptr) { + SERVER_LOG_ERROR << "Failed to open db"; + throw ServerException(SERVER_NULL_POINTER, "Failed to open db"); + } + } + + zilliz::vecwise::engine::DB* DB() { return db_; } + + private: + zilliz::vecwise::engine::DB* db_ = nullptr; + }; + + zilliz::vecwise::engine::DB* DB() { + static DBWrapper db_wrapper; + return db_wrapper.DB(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +AddGroupTask::AddGroupTask(int32_t dimension, + const std::string& group_id) +: BaseTask(NORMAL_TASK_GROUP), + dimension_(dimension), + group_id_(group_id) { + +} + +BaseTaskPtr AddGroupTask::Create(int32_t dimension, + const std::string& group_id) { + return std::shared_ptr(new AddGroupTask(dimension,group_id)); +} + +ServerError AddGroupTask::OnExecute() { + try { + engine::meta::GroupSchema group_info; + group_info.dimension = (size_t)dimension_; + group_info.group_id = group_id_; + engine::Status stat = DB()->add_group(group_info); + if(!stat.ok()) { + SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); + } + + } catch (std::exception& ex) { + SERVER_LOG_ERROR << ex.what(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +GetGroupTask::GetGroupTask(const std::string& group_id, int32_t& dimension) + : BaseTask(NORMAL_TASK_GROUP), + group_id_(group_id), + dimension_(dimension) { + +} + +BaseTaskPtr GetGroupTask::Create(const std::string& group_id, int32_t& dimension) { + return std::shared_ptr(new GetGroupTask(group_id, dimension)); +} + +ServerError GetGroupTask::OnExecute() { + try { + dimension_ = 0; + + engine::meta::GroupSchema group_info; + group_info.group_id = group_id_; + engine::Status stat = DB()->get_group(group_info); + if(!stat.ok()) { + SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); + } else { + dimension_ = (int32_t)group_info.dimension; + } + + } catch (std::exception& ex) { + SERVER_LOG_ERROR << ex.what(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +DeleteGroupTask::DeleteGroupTask(const std::string& group_id) + : BaseTask(NORMAL_TASK_GROUP), + group_id_(group_id) { + +} + +BaseTaskPtr DeleteGroupTask::Create(const std::string& group_id) { + return std::shared_ptr(new DeleteGroupTask(group_id)); +} + +ServerError DeleteGroupTask::OnExecute() { + try { + + + } catch (std::exception& ex) { + SERVER_LOG_ERROR << ex.what(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +AddVectorTask::AddVectorTask(const std::string& group_id, + const VecTensorList &tensor_list) + : BaseTask(NORMAL_TASK_GROUP), + group_id_(group_id), + tensor_list_(tensor_list) { + +} + +BaseTaskPtr AddVectorTask::Create(const std::string& group_id, + const VecTensorList &tensor_list) { + return std::shared_ptr(new AddVectorTask(group_id, tensor_list)); +} + +ServerError AddVectorTask::OnExecute() { + try { + std::vector vec_f; + for(const VecTensor& tensor : tensor_list_.tensor_list) { + vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end()); + } + + engine::IDNumbers vector_ids; + engine::Status stat = DB()->add_vectors(group_id_, tensor_list_.tensor_list.size(), vec_f.data(), vector_ids); + if(!stat.ok()) { + SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); + } else { + if(vector_ids.size() != tensor_list_.tensor_list.size()) { + SERVER_LOG_ERROR << "Vector ID not returned"; + } else { + std::string nid_prefix = group_id_ + "_"; + for(size_t i = 0; i < vector_ids.size(); i++) { + std::string nid = nid_prefix + std::to_string(vector_ids[i]); + IVecIdMapper::GetInstance()->Put(nid, tensor_list_.tensor_list[i].uid); + } + } + } + + } catch (std::exception& ex) { + SERVER_LOG_ERROR << ex.what(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +SearchVectorTask::SearchVectorTask(VecSearchResultList& result, + const std::string& group_id, + const int64_t top_k, + const VecTensorList& tensor_list, + const VecTimeRangeList& time_range_list) + : BaseTask(NORMAL_TASK_GROUP), + result_(result), + group_id_(group_id), + top_k_(top_k), + tensor_list_(tensor_list), + time_range_list_(time_range_list) { + +} + +BaseTaskPtr SearchVectorTask::Create(VecSearchResultList& result, + const std::string& group_id, + const int64_t top_k, + const VecTensorList& tensor_list, + const VecTimeRangeList& time_range_list) { + return std::shared_ptr(new SearchVectorTask(result, group_id, top_k, tensor_list, time_range_list)); +} + +ServerError SearchVectorTask::OnExecute() { + try { + std::vector vec_f; + for(const VecTensor& tensor : tensor_list_.tensor_list) { + vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end()); + } + + engine::QueryResults results; + engine::Status stat = DB()->search(group_id_, (size_t)top_k_, tensor_list_.tensor_list.size(), vec_f.data(), results); + if(!stat.ok()) { + SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); + } else { + for(engine::QueryResult& res : results){ + VecSearchResult v_res; + std::string nid_prefix = group_id_ + "_"; + for(auto id : results[0]) { + std::string sid; + std::string nid = nid_prefix + std::to_string(id); + IVecIdMapper::GetInstance()->Get(nid, sid); + v_res.id_list.push_back(sid); + v_res.distance_list.push_back(0.0);//TODO: return distance + } + + result_.result_list.push_back(v_res); + } + } + + } catch (std::exception& ex) { + SERVER_LOG_ERROR << ex.what(); + } +} + +} +} +} diff --git a/cpp/src/server/VecServiceTask.h b/cpp/src/server/VecServiceTask.h new file mode 100644 index 0000000000..cf1b60bebe --- /dev/null +++ b/cpp/src/server/VecServiceTask.h @@ -0,0 +1,116 @@ +/******************************************************************************* + * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved + * Unauthorized copying of this file, via any medium is strictly prohibited. + * Proprietary and confidential. + ******************************************************************************/ +#pragma once + +#include "VecServiceScheduler.h" +#include "utils/Error.h" +#include "db/Types.h" + +#include "thrift/gen-cpp/VectorService_types.h" + +#include +#include + +namespace zilliz { +namespace vecwise { +namespace server { + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class AddGroupTask : public BaseTask { +public: + static BaseTaskPtr Create(int32_t dimension, + const std::string& group_id); + +protected: + AddGroupTask(int32_t dimension, + const std::string& group_id); + + ServerError OnExecute() override; + +private: + int32_t dimension_; + std::string group_id_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class GetGroupTask : public BaseTask { +public: + static BaseTaskPtr Create(const std::string& group_id, int32_t& dimension); + +protected: + GetGroupTask(const std::string& group_id, int32_t& dimension); + + ServerError OnExecute() override; + + +private: + std::string group_id_; + int32_t& dimension_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class DeleteGroupTask : public BaseTask { +public: + static BaseTaskPtr Create(const std::string& group_id); + +protected: + DeleteGroupTask(const std::string& group_id); + + ServerError OnExecute() override; + + +private: + std::string group_id_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class AddVectorTask : public BaseTask { +public: + static BaseTaskPtr Create(const std::string& group_id, + const VecTensorList &tensor_list); + +protected: + AddVectorTask(const std::string& group_id, + const VecTensorList &tensor_list); + + ServerError OnExecute() override; + + +private: + std::string group_id_; + const VecTensorList& tensor_list_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class SearchVectorTask : public BaseTask { +public: + static BaseTaskPtr Create(VecSearchResultList& result, + const std::string& group_id, + const int64_t top_k, + const VecTensorList& tensor_list, + const VecTimeRangeList& time_range_list); + +protected: + SearchVectorTask(VecSearchResultList& result, + const std::string& group_id, + const int64_t top_k, + const VecTensorList& tensor_list, + const VecTimeRangeList& time_range_list); + + ServerError OnExecute() override; + + +private: + VecSearchResultList& result_; + std::string group_id_; + int64_t top_k_; + const VecTensorList& tensor_list_; + const VecTimeRangeList& time_range_list_; +}; + +} +} +} \ No newline at end of file From b03aafdc417e49020e5f6c235ec160b92d04a511 Mon Sep 17 00:00:00 2001 From: groot Date: Thu, 25 Apr 2019 12:17:56 +0800 Subject: [PATCH 2/2] implement scheduler Former-commit-id: 1be5a738138a626ddb4a7412e798c74debbc4c3a --- cpp/build.sh | 3 + cpp/src/server/VecServiceHandler.cpp | 74 +++++------------ cpp/src/server/VecServiceScheduler.cpp | 18 ++++- cpp/src/server/VecServiceScheduler.h | 3 + cpp/src/server/VecServiceTask.cpp | 107 +++++++++++++++++++------ cpp/src/server/VecServiceTask.h | 36 ++++++--- cpp/test_client/src/ClientApp.cpp | 43 +++++++--- 7 files changed, 186 insertions(+), 98 deletions(-) diff --git a/cpp/build.sh b/cpp/build.sh index e10545fa30..6353e59e2d 100755 --- a/cpp/build.sh +++ b/cpp/build.sh @@ -40,8 +40,11 @@ rm -rf ./cmake_build mkdir cmake_build cd cmake_build +CUDA_COMPILER=/usr/local/cuda/bin/nvcc + CMAKE_CMD="cmake -DBUILD_UNIT_TEST=${BUILD_UNITTEST} \ -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ +-DCMAKE_CUDA_COMPILER=${CUDA_COMPILER} \ $@ ../" echo ${CMAKE_CMD} diff --git a/cpp/src/server/VecServiceHandler.cpp b/cpp/src/server/VecServiceHandler.cpp index 798f0de0e1..d2b56ee753 100644 --- a/cpp/src/server/VecServiceHandler.cpp +++ b/cpp/src/server/VecServiceHandler.cpp @@ -9,6 +9,7 @@ #include "VecIdMapper.h" #include "utils/Log.h" #include "utils/CommonUtil.h" +#include "utils/TimeRecorder.h" #include "db/DB.h" #include "db/Env.h" @@ -73,9 +74,7 @@ VecServiceHandler::add_vector(const std::string &group_id, const VecTensor &tens SERVER_LOG_INFO << "add_vector() called"; SERVER_LOG_TRACE << "group_id = " << group_id << ", vector size = " << tensor.tensor.size(); - VecTensorList tensor_list; - tensor_list.tensor_list.push_back(tensor); - BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, tensor_list); + BaseTaskPtr task_ptr = AddSingleVectorTask::Create(group_id, tensor); VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); scheduler.ExecuteTask(task_ptr); @@ -88,10 +87,11 @@ VecServiceHandler::add_vector_batch(const std::string &group_id, SERVER_LOG_INFO << "add_vector_batch() called"; SERVER_LOG_TRACE << "group_id = " << group_id << ", vector list size = " << tensor_list.tensor_list.size(); - - BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, tensor_list); + TimeRecorder rc("Add VECTOR BATCH"); + BaseTaskPtr task_ptr = AddBatchVectorTask::Create(group_id, tensor_list); VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); scheduler.ExecuteTask(task_ptr); + rc.Elapse("DONE!"); SERVER_LOG_INFO << "add_vector_batch() finished"; } @@ -108,28 +108,17 @@ VecServiceHandler::search_vector(VecSearchResult &_return, << ", vector size = " << tensor.tensor.size() << ", time range list size = " << time_range_list.range_list.size(); - try { - engine::QueryResults results; - std::vector vec_f(tensor.tensor.begin(), tensor.tensor.end()); - engine::Status stat = db_->search(group_id, (size_t)top_k, 1, vec_f.data(), results); - if(!stat.ok()) { - SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); - } else { - if(!results.empty()) { - std::string nid_prefix = group_id + "_"; - for(auto id : results[0]) { - std::string sid; - std::string nid = nid_prefix + std::to_string(id); - IVecIdMapper::GetInstance()->Get(nid, sid); - _return.id_list.push_back(sid); - _return.distance_list.push_back(0.0);//TODO: return distance - } - } - } + VecTensorList tensor_list; + tensor_list.tensor_list.push_back(tensor); + VecSearchResultList result; + BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, tensor_list, time_range_list, result); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); - - } catch (std::exception& ex) { - SERVER_LOG_ERROR << ex.what(); + if(!result.result_list.empty()) { + _return = result.result_list[0]; + } else { + SERVER_LOG_ERROR << "No search result returned"; } SERVER_LOG_INFO << "search_vector() finished"; @@ -146,36 +135,9 @@ VecServiceHandler::search_vector_batch(VecSearchResultList &_return, << ", vector list size = " << tensor_list.tensor_list.size() << ", time range list size = " << time_range_list.range_list.size(); - try { - std::vector vec_f; - for(const VecTensor& tensor : tensor_list.tensor_list) { - vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end()); - } - - engine::QueryResults results; - engine::Status stat = db_->search(group_id, (size_t)top_k, tensor_list.tensor_list.size(), vec_f.data(), results); - if(!stat.ok()) { - SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); - } else { - for(engine::QueryResult& res : results){ - VecSearchResult v_res; - std::string nid_prefix = group_id + "_"; - for(auto id : results[0]) { - std::string sid; - std::string nid = nid_prefix + std::to_string(id); - IVecIdMapper::GetInstance()->Get(nid, sid); - v_res.id_list.push_back(sid); - v_res.distance_list.push_back(0.0);//TODO: return distance - } - - _return.result_list.push_back(v_res); - } - } - - - } catch (std::exception& ex) { - SERVER_LOG_ERROR << ex.what(); - } + BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, tensor_list, time_range_list, _return); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); SERVER_LOG_INFO << "search_vector_batch() finished"; } diff --git a/cpp/src/server/VecServiceScheduler.cpp b/cpp/src/server/VecServiceScheduler.cpp index 5acadc08a1..5826ff5be2 100644 --- a/cpp/src/server/VecServiceScheduler.cpp +++ b/cpp/src/server/VecServiceScheduler.cpp @@ -25,6 +25,7 @@ BaseTask::~BaseTask() { ServerError BaseTask::Execute() { error_code_ = OnExecute(); done_ = true; + finish_cond_.notify_all(); return error_code_; } @@ -72,7 +73,7 @@ void VecServiceScheduler::Stop() { stopped_ = true; } -ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) { +ServerError VecServiceScheduler::PushTask(const BaseTaskPtr& task_ptr) { if(task_ptr == nullptr) { return SERVER_NULL_POINTER; } @@ -80,6 +81,19 @@ ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) { return PutTaskToQueue(task_ptr); } +ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) { + if(task_ptr == nullptr) { + return SERVER_NULL_POINTER; + } + + ServerError err = PutTaskToQueue(task_ptr); + if(err != SERVER_SUCCESS) { + return err; + } + + return task_ptr->WaitToFinish(); +} + namespace { void TakeTaskToExecute(TaskQueuePtr task_queue) { if(task_queue == nullptr) { @@ -120,6 +134,8 @@ ServerError VecServiceScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) { execute_threads_.push_back(thread); SERVER_LOG_INFO << "Create new thread for task group: " << group_name; } + + return SERVER_SUCCESS; } } diff --git a/cpp/src/server/VecServiceScheduler.h b/cpp/src/server/VecServiceScheduler.h index bbbcd151ac..d3fec4cc80 100644 --- a/cpp/src/server/VecServiceScheduler.h +++ b/cpp/src/server/VecServiceScheduler.h @@ -54,6 +54,9 @@ public: void Start(); void Stop(); + //async + ServerError PushTask(const BaseTaskPtr& task_ptr); + //sync ServerError ExecuteTask(const BaseTaskPtr& task_ptr); protected: diff --git a/cpp/src/server/VecServiceTask.cpp b/cpp/src/server/VecServiceTask.cpp index 5e80132fbd..7254a69aa6 100644 --- a/cpp/src/server/VecServiceTask.cpp +++ b/cpp/src/server/VecServiceTask.cpp @@ -8,6 +8,7 @@ #include "VecIdMapper.h" #include "utils/CommonUtil.h" #include "utils/Log.h" +#include "utils/TimeRecorder.h" #include "db/DB.h" #include "db/Env.h" @@ -16,6 +17,7 @@ namespace vecwise { namespace server { static const std::string NORMAL_TASK_GROUP = "normal"; +static const std::string SEARCH_TASK_GROUP = "search"; namespace { class DBWrapper { @@ -128,7 +130,44 @@ ServerError DeleteGroupTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -AddVectorTask::AddVectorTask(const std::string& group_id, +AddSingleVectorTask::AddSingleVectorTask(const std::string& group_id, + const VecTensor &tensor) + : BaseTask(NORMAL_TASK_GROUP), + group_id_(group_id), + tensor_(tensor) { + +} + +BaseTaskPtr AddSingleVectorTask::Create(const std::string& group_id, + const VecTensor &tensor) { + return std::shared_ptr(new AddSingleVectorTask(group_id, tensor)); +} + +ServerError AddSingleVectorTask::OnExecute() { + try { + engine::IDNumbers vector_ids; + std::vector vec_f(tensor_.tensor.begin(), tensor_.tensor.end()); + engine::Status stat = DB()->add_vectors(group_id_, 1, vec_f.data(), vector_ids); + if(!stat.ok()) { + SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); + } else { + if(vector_ids.empty()) { + SERVER_LOG_ERROR << "Vector ID not returned"; + } else { + std::string nid = group_id_ + "_" + std::to_string(vector_ids[0]); + IVecIdMapper::GetInstance()->Put(nid, tensor_.uid); + SERVER_LOG_TRACE << "nid = " << vector_ids[0] << ", sid = " << tensor_.uid; + } + } + + } catch (std::exception& ex) { + SERVER_LOG_ERROR << ex.what(); + } +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +AddBatchVectorTask::AddBatchVectorTask(const std::string& group_id, const VecTensorList &tensor_list) : BaseTask(NORMAL_TASK_GROUP), group_id_(group_id), @@ -136,31 +175,50 @@ AddVectorTask::AddVectorTask(const std::string& group_id, } -BaseTaskPtr AddVectorTask::Create(const std::string& group_id, +BaseTaskPtr AddBatchVectorTask::Create(const std::string& group_id, const VecTensorList &tensor_list) { - return std::shared_ptr(new AddVectorTask(group_id, tensor_list)); + return std::shared_ptr(new AddBatchVectorTask(group_id, tensor_list)); } -ServerError AddVectorTask::OnExecute() { +ServerError AddBatchVectorTask::OnExecute() { try { - std::vector vec_f; - for(const VecTensor& tensor : tensor_list_.tensor_list) { - vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end()); + TimeRecorder rc("Add vector batch"); + + engine::meta::GroupSchema group_info; + group_info.group_id = group_id_; + engine::Status stat = DB()->get_group(group_info); + if(!stat.ok()) { + SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); + return SERVER_UNEXPECTED_ERROR; } + std::vector vec_f; + vec_f.reserve(tensor_list_.tensor_list.size()*group_info.dimension*4); + for(const VecTensor& tensor : tensor_list_.tensor_list) { + if(tensor.tensor.size() != group_info.dimension) { + SERVER_LOG_ERROR << "Invalid vector data size: " << tensor.tensor.size() + << " vs. group dimension:" << group_info.dimension; + return SERVER_UNEXPECTED_ERROR; + } + vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end()); + } + rc.Record("prepare vectors data"); + engine::IDNumbers vector_ids; - engine::Status stat = DB()->add_vectors(group_id_, tensor_list_.tensor_list.size(), vec_f.data(), vector_ids); + stat = DB()->add_vectors(group_id_, tensor_list_.tensor_list.size(), vec_f.data(), vector_ids); + rc.Record("add vectors to engine"); if(!stat.ok()) { SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); } else { - if(vector_ids.size() != tensor_list_.tensor_list.size()) { + if(vector_ids.size() < tensor_list_.tensor_list.size()) { SERVER_LOG_ERROR << "Vector ID not returned"; } else { std::string nid_prefix = group_id_ + "_"; - for(size_t i = 0; i < vector_ids.size(); i++) { + for(size_t i = 0; i < tensor_list_.tensor_list.size(); i++) { std::string nid = nid_prefix + std::to_string(vector_ids[i]); IVecIdMapper::GetInstance()->Put(nid, tensor_list_.tensor_list[i].uid); } + rc.Record("build id mapping"); } } @@ -170,26 +228,26 @@ ServerError AddVectorTask::OnExecute() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -SearchVectorTask::SearchVectorTask(VecSearchResultList& result, - const std::string& group_id, +SearchVectorTask::SearchVectorTask(const std::string& group_id, const int64_t top_k, const VecTensorList& tensor_list, - const VecTimeRangeList& time_range_list) - : BaseTask(NORMAL_TASK_GROUP), - result_(result), - group_id_(group_id), - top_k_(top_k), - tensor_list_(tensor_list), - time_range_list_(time_range_list) { + const VecTimeRangeList& time_range_list, + VecSearchResultList& result) + : BaseTask(SEARCH_TASK_GROUP), + group_id_(group_id), + top_k_(top_k), + tensor_list_(tensor_list), + time_range_list_(time_range_list), + result_(result) { } -BaseTaskPtr SearchVectorTask::Create(VecSearchResultList& result, - const std::string& group_id, +BaseTaskPtr SearchVectorTask::Create(const std::string& group_id, const int64_t top_k, const VecTensorList& tensor_list, - const VecTimeRangeList& time_range_list) { - return std::shared_ptr(new SearchVectorTask(result, group_id, top_k, tensor_list, time_range_list)); + const VecTimeRangeList& time_range_list, + VecSearchResultList& result) { + return std::shared_ptr(new SearchVectorTask(group_id, top_k, tensor_list, time_range_list, result)); } ServerError SearchVectorTask::OnExecute() { @@ -213,6 +271,9 @@ ServerError SearchVectorTask::OnExecute() { IVecIdMapper::GetInstance()->Get(nid, sid); v_res.id_list.push_back(sid); v_res.distance_list.push_back(0.0);//TODO: return distance + + SERVER_LOG_TRACE << "nid = " << nid << ", string id = " << sid; + } result_.result_list.push_back(v_res); diff --git a/cpp/src/server/VecServiceTask.h b/cpp/src/server/VecServiceTask.h index cf1b60bebe..26be5f6b64 100644 --- a/cpp/src/server/VecServiceTask.h +++ b/cpp/src/server/VecServiceTask.h @@ -67,13 +67,31 @@ private: }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -class AddVectorTask : public BaseTask { +class AddSingleVectorTask : public BaseTask { +public: + static BaseTaskPtr Create(const std::string& group_id, + const VecTensor &tensor); + +protected: + AddSingleVectorTask(const std::string& group_id, + const VecTensor &tensor); + + ServerError OnExecute() override; + + +private: + std::string group_id_; + const VecTensor& tensor_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +class AddBatchVectorTask : public BaseTask { public: static BaseTaskPtr Create(const std::string& group_id, const VecTensorList &tensor_list); protected: - AddVectorTask(const std::string& group_id, + AddBatchVectorTask(const std::string& group_id, const VecTensorList &tensor_list); ServerError OnExecute() override; @@ -87,28 +105,28 @@ private: //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// class SearchVectorTask : public BaseTask { public: - static BaseTaskPtr Create(VecSearchResultList& result, - const std::string& group_id, + static BaseTaskPtr Create(const std::string& group_id, const int64_t top_k, const VecTensorList& tensor_list, - const VecTimeRangeList& time_range_list); + const VecTimeRangeList& time_range_list, + VecSearchResultList& result); protected: - SearchVectorTask(VecSearchResultList& result, - const std::string& group_id, + SearchVectorTask(const std::string& group_id, const int64_t top_k, const VecTensorList& tensor_list, - const VecTimeRangeList& time_range_list); + const VecTimeRangeList& time_range_list, + VecSearchResultList& result); ServerError OnExecute() override; private: - VecSearchResultList& result_; std::string group_id_; int64_t top_k_; const VecTensorList& tensor_list_; const VecTimeRangeList& time_range_list_; + VecSearchResultList& result_; }; } diff --git a/cpp/test_client/src/ClientApp.cpp b/cpp/test_client/src/ClientApp.cpp index f599ce714f..b951908efa 100644 --- a/cpp/test_client/src/ClientApp.cpp +++ b/cpp/test_client/src/ClientApp.cpp @@ -3,6 +3,7 @@ * Unauthorized copying of this file, via any medium is strictly prohibited. * Proprietary and confidential. ******************************************************************************/ +#include #include "ClientApp.h" #include "ClientSession.h" #include "server/ServerConfig.h" @@ -37,21 +38,44 @@ void ClientApp::Run(const std::string &config_file) { group.index_type = 0; session.interface()->add_group(group); - //add vectors - for(int64_t k = 0; k < 10000; k++) { - VecTensor tensor; - for(int32_t i = 0; i < dim; i++) { - tensor.tensor.push_back((double)(i + k)); + const int64_t count = 500; + //add vectors one by one + { + + server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one"); + for (int64_t k = 0; k < count; k++) { + VecTensor tensor; + for (int32_t i = 0; i < dim; i++) { + tensor.tensor.push_back((double) (i + k)); + } + tensor.uid = "vec_" + std::to_string(k); + + session.interface()->add_vector(group.id, tensor); + + CLIENT_LOG_INFO << "add vector no." << k; } - tensor.uid = "vec_" + std::to_string(k); + rc.Elapse("done!"); + } - session.interface()->add_vector(group.id, tensor); - - CLIENT_LOG_INFO << "add vector no." << k; + //add vectors in one batch + { + server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch"); + VecTensorList vec_list; + for (int64_t k = 0; k < count; k++) { + VecTensor tensor; + for (int32_t i = 0; i < dim; i++) { + tensor.tensor.push_back((double) (i + k)); + } + tensor.uid = "vec_" + std::to_string(k); + vec_list.tensor_list.push_back(tensor); + } + session.interface()->add_vector_batch(group.id, vec_list); + rc.Elapse("done!"); } //search vector { + server::TimeRecorder rc("Search top_k"); VecTensor tensor; for (int32_t i = 0; i < dim; i++) { tensor.tensor.push_back((double) (i + 100)); @@ -65,6 +89,7 @@ void ClientApp::Run(const std::string &config_file) { for(auto id : res.id_list) { std::cout << id << std::endl; } + rc.Elapse("done!"); } } catch (std::exception& ex) {