diff --git a/cpp/src/server/VecServiceHandler.cpp b/cpp/src/server/VecServiceHandler.cpp index 58deb88079..79ffa56b9a 100644 --- a/cpp/src/server/VecServiceHandler.cpp +++ b/cpp/src/server/VecServiceHandler.cpp @@ -126,13 +126,13 @@ VecServiceHandler::search_vector(VecSearchResult &_return, const VecTimeRangeList &time_range_list) { TimeRecordWrapper rc("search_vector()"); SERVER_LOG_TRACE << "group_id = " << group_id << ", top_k = " << top_k - << ", vector size = " << tensor.tensor.size() + << ", vector dimension = " << tensor.tensor.size() << ", time range list size = " << time_range_list.range_list.size(); VecTensorList tensor_list; tensor_list.tensor_list.push_back(tensor); VecSearchResultList result; - BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, tensor_list, time_range_list, result); + BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, time_range_list, result); VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); scheduler.ExecuteTask(task_ptr); @@ -154,7 +154,48 @@ VecServiceHandler::search_vector_batch(VecSearchResultList &_return, << ", vector list size = " << tensor_list.tensor_list.size() << ", time range list size = " << time_range_list.range_list.size(); - BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, tensor_list, time_range_list, _return); + BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, time_range_list, _return); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); +} + +void +VecServiceHandler::search_binary_vector(VecSearchResult& _return, + const std::string& group_id, + const int64_t top_k, + const VecBinaryTensor& tensor, + const VecTimeRangeList& time_range_list) { + TimeRecordWrapper rc("search_binary_vector()"); + SERVER_LOG_TRACE << "group_id = " << group_id << ", top_k = " << top_k + << ", vector dimension = " << tensor.tensor.size() + << ", time range list size = " << time_range_list.range_list.size(); + + VecBinaryTensorList tensor_list; + tensor_list.tensor_list.push_back(tensor); + VecSearchResultList result; + BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, time_range_list, result); + VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); + scheduler.ExecuteTask(task_ptr); + + if(!result.result_list.empty()) { + _return = result.result_list[0]; + } else { + SERVER_LOG_ERROR << "No search result returned"; + } +} + +void +VecServiceHandler::search_binary_vector_batch(VecSearchResultList& _return, + const std::string& group_id, + const int64_t top_k, + const VecBinaryTensorList& tensor_list, + const VecTimeRangeList& time_range_list) { + TimeRecordWrapper rc("search_binary_vector_batch()"); + SERVER_LOG_TRACE << "group_id = " << group_id << ", top_k = " << top_k + << ", vector list size = " << tensor_list.tensor_list.size() + << ", time range list size = " << time_range_list.range_list.size(); + + BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, time_range_list, _return); VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance(); scheduler.ExecuteTask(task_ptr); } diff --git a/cpp/src/server/VecServiceHandler.h b/cpp/src/server/VecServiceHandler.h index c818f694a5..b7eea5c017 100644 --- a/cpp/src/server/VecServiceHandler.h +++ b/cpp/src/server/VecServiceHandler.h @@ -67,6 +67,9 @@ public: void search_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecTensorList& tensor_list, const VecTimeRangeList& time_range_list); + void search_binary_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensor& tensor, const VecTimeRangeList& time_range_list); + + void search_binary_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensorList& tensor_list, const VecTimeRangeList& time_range_list); }; diff --git a/cpp/src/server/VecServiceTask.cpp b/cpp/src/server/VecServiceTask.cpp index f5211d6c9b..201451b94f 100644 --- a/cpp/src/server/VecServiceTask.cpp +++ b/cpp/src/server/VecServiceTask.cpp @@ -301,7 +301,7 @@ uint64_t AddBatchVectorTask::GetVecDimension(uint64_t index) const { if(index >= bin_tensor_list_->tensor_list.size()){ return 0; } - return (uint64_t) bin_tensor_list_->tensor_list[index].tensor.size(); + return (uint64_t) bin_tensor_list_->tensor_list[index].tensor.size()/8; } else { return 0; } @@ -341,8 +341,6 @@ std::string AddBatchVectorTask::GetVecID(uint64_t index) const { ServerError AddBatchVectorTask::OnExecute() { try { - TimeRecorder rc("Add vector batch"); - engine::meta::GroupSchema group_info; group_info.group_id = group_id_; engine::Status stat = DB()->get_group(group_info); @@ -351,6 +349,7 @@ ServerError AddBatchVectorTask::OnExecute() { return SERVER_UNEXPECTED_ERROR; } + TimeRecorder rc("Add vector batch"); uint64_t group_dim = group_info.dimension; uint64_t vec_count = GetVecListCount(); std::vector vec_f; @@ -402,13 +401,29 @@ ServerError AddBatchVectorTask::OnExecute() { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// SearchVectorTask::SearchVectorTask(const std::string& group_id, const int64_t top_k, - const VecTensorList& tensor_list, + const VecTensorList* tensor_list, const VecTimeRangeList& time_range_list, VecSearchResultList& result) : BaseTask(DQL_TASK_GROUP), group_id_(group_id), top_k_(top_k), tensor_list_(tensor_list), + bin_tensor_list_(nullptr), + time_range_list_(time_range_list), + result_(result) { + +} + +SearchVectorTask::SearchVectorTask(const std::string& group_id, + const int64_t top_k, + const VecBinaryTensorList* bin_tensor_list, + const VecTimeRangeList& time_range_list, + VecSearchResultList& result) + : BaseTask(DQL_TASK_GROUP), + group_id_(group_id), + top_k_(top_k), + tensor_list_(nullptr), + bin_tensor_list_(bin_tensor_list), time_range_list_(time_range_list), result_(result) { @@ -416,21 +431,101 @@ SearchVectorTask::SearchVectorTask(const std::string& group_id, BaseTaskPtr SearchVectorTask::Create(const std::string& group_id, const int64_t top_k, - const VecTensorList& tensor_list, + const VecTensorList* tensor_list, const VecTimeRangeList& time_range_list, VecSearchResultList& result) { return std::shared_ptr(new SearchVectorTask(group_id, top_k, tensor_list, time_range_list, result)); } +BaseTaskPtr SearchVectorTask::Create(const std::string& group_id, + const int64_t top_k, + const VecBinaryTensorList* bin_tensor_list, + const VecTimeRangeList& time_range_list, + VecSearchResultList& result) { + return std::shared_ptr(new SearchVectorTask(group_id, top_k, bin_tensor_list, time_range_list, result)); +} + + +ServerError SearchVectorTask::GetTargetData(std::vector& data) const { + if(tensor_list_ && !tensor_list_->tensor_list.empty()) { + uint64_t count = tensor_list_->tensor_list.size(); + uint64_t dim = tensor_list_->tensor_list[0].tensor.size(); + data.resize(count*dim); + for(size_t i = 0; i < count; i++) { + if(tensor_list_->tensor_list[i].tensor.size() != dim) { + SERVER_LOG_ERROR << "Invalid vector dimension: " << tensor_list_->tensor_list[i].tensor.size(); + return SERVER_INVALID_ARGUMENT; + } + const double* d_p = tensor_list_->tensor_list[i].tensor.data(); + for(int64_t k = 0; k < dim; k++) { + data[i*dim + k] = (float)(d_p[k]); + } + } + } else if(bin_tensor_list_ && !bin_tensor_list_->tensor_list.empty()) { + uint64_t count = bin_tensor_list_->tensor_list.size(); + uint64_t dim = bin_tensor_list_->tensor_list[0].tensor.size()/8; + data.resize(count*dim); + for(size_t i = 0; i < count; i++) { + if(bin_tensor_list_->tensor_list[i].tensor.size()/8 != dim) { + SERVER_LOG_ERROR << "Invalid vector dimension: " << bin_tensor_list_->tensor_list[i].tensor.size()/8; + return SERVER_INVALID_ARGUMENT; + } + const double* d_p = (const double*)(bin_tensor_list_->tensor_list[i].tensor.data()); + for(int64_t k = 0; k < dim; k++) { + data[i*dim + k] = (float)(d_p[k]); + } + } + } + + return SERVER_SUCCESS; +} + +uint64_t SearchVectorTask::GetTargetDimension() const { + if(tensor_list_ && !tensor_list_->tensor_list.empty()) { + return tensor_list_->tensor_list[0].tensor.size(); + } else if(bin_tensor_list_ && !bin_tensor_list_->tensor_list.empty()) { + return bin_tensor_list_->tensor_list[0].tensor.size()/8; + } + + return 0; +} + +uint64_t SearchVectorTask::GetTargetCount() const { + if(tensor_list_) { + return tensor_list_->tensor_list.size(); + } else if(bin_tensor_list_) { + return bin_tensor_list_->tensor_list.size(); + } +} + ServerError SearchVectorTask::OnExecute() { try { - std::vector vec_f; - for(const VecTensor& tensor : tensor_list_.tensor_list) { - vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end()); + engine::meta::GroupSchema group_info; + group_info.group_id = group_id_; + engine::Status stat = DB()->get_group(group_info); + if(!stat.ok()) { + SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); + return SERVER_UNEXPECTED_ERROR; } + uint64_t vec_dim = GetTargetDimension(); + if(vec_dim != group_info.dimension) { + SERVER_LOG_ERROR << "Invalid vector dimension: " << vec_dim + << " vs. group dimension:" << group_info.dimension; + return SERVER_INVALID_ARGUMENT; + } + + TimeRecorder rc("Search vector"); + std::vector vec_f; + ServerError err = GetTargetData(vec_f); + if(err != SERVER_SUCCESS) { + return err; + } + + uint64_t vec_count = GetTargetCount(); + engine::QueryResults results; - engine::Status stat = DB()->search(group_id_, (size_t)top_k_, tensor_list_.tensor_list.size(), vec_f.data(), results); + stat = DB()->search(group_id_, (size_t)top_k_, vec_count, vec_f.data(), results); if(!stat.ok()) { SERVER_LOG_ERROR << "Engine failed: " << stat.ToString(); return SERVER_UNEXPECTED_ERROR; @@ -438,7 +533,7 @@ ServerError SearchVectorTask::OnExecute() { for(engine::QueryResult& res : results){ VecSearchResult v_res; std::string nid_prefix = group_id_ + "_"; - for(auto id : results[0]) { + for(auto id : res) { std::string sid; std::string nid = nid_prefix + std::to_string(id); IVecIdMapper::GetInstance()->Get(nid, sid); diff --git a/cpp/src/server/VecServiceTask.h b/cpp/src/server/VecServiceTask.h index e1cca36d05..b5d91e3aa7 100644 --- a/cpp/src/server/VecServiceTask.h +++ b/cpp/src/server/VecServiceTask.h @@ -129,24 +129,40 @@ class SearchVectorTask : public BaseTask { public: static BaseTaskPtr Create(const std::string& group_id, const int64_t top_k, - const VecTensorList& tensor_list, + const VecTensorList* tensor_list, + const VecTimeRangeList& time_range_list, + VecSearchResultList& result); + + static BaseTaskPtr Create(const std::string& group_id, + const int64_t top_k, + const VecBinaryTensorList* bin_tensor_list, const VecTimeRangeList& time_range_list, VecSearchResultList& result); protected: SearchVectorTask(const std::string& group_id, const int64_t top_k, - const VecTensorList& tensor_list, + const VecTensorList* tensor_list, const VecTimeRangeList& time_range_list, VecSearchResultList& result); - ServerError OnExecute() override; + SearchVectorTask(const std::string& group_id, + const int64_t top_k, + const VecBinaryTensorList* bin_tensor_list, + const VecTimeRangeList& time_range_list, + VecSearchResultList& result); + ServerError GetTargetData(std::vector& data) const; + uint64_t GetTargetDimension() const; + uint64_t GetTargetCount() const; + + ServerError OnExecute() override; private: std::string group_id_; int64_t top_k_; - const VecTensorList& tensor_list_; + const VecTensorList* tensor_list_; + const VecBinaryTensorList* bin_tensor_list_; const VecTimeRangeList& time_range_list_; VecSearchResultList& result_; }; diff --git a/cpp/test_client/src/ClientApp.cpp b/cpp/test_client/src/ClientApp.cpp index 58dbba110f..9d0230b510 100644 --- a/cpp/test_client/src/ClientApp.cpp +++ b/cpp/test_client/src/ClientApp.cpp @@ -57,7 +57,7 @@ void ClientApp::Run(const std::string &config_file) { session.interface()->add_group(group); //prepare data - const int64_t count = 100000; + const int64_t count = 10000; VecTensorList tensor_list; VecBinaryTensorList bin_tensor_list; for (int64_t k = 0; k < count; k++) { @@ -79,36 +79,36 @@ void ClientApp::Run(const std::string &config_file) { bin_tensor_list.tensor_list.emplace_back(bin_tensor); } -// //add vectors one by one -// { -// server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one"); -// for (int64_t k = 0; k < count; k++) { -// session.interface()->add_vector(group.id, tensor_list.tensor_list[k]); -// if(k%1000 == 0) { -// CLIENT_LOG_INFO << "add normal vector no." << k; -// } -// } -// rc.Elapse("done!"); -// } -// -// //add vectors in one batch -// { -// server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch"); -// session.interface()->add_vector_batch(group.id, tensor_list); -// rc.Elapse("done!"); -// } -// -// //add binary vectors one by one -// { -// server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one"); -// for (int64_t k = 0; k < count; k++) { -// session.interface()->add_binary_vector(group.id, bin_tensor_list.tensor_list[k]); -// if(k%1000 == 0) { -// CLIENT_LOG_INFO << "add binary vector no." << k; -// } -// } -// rc.Elapse("done!"); -// } + //add vectors one by one + { + server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one"); + for (int64_t k = 0; k < count; k++) { + session.interface()->add_vector(group.id, tensor_list.tensor_list[k]); + if(k%1000 == 0) { + CLIENT_LOG_INFO << "add normal vector no." << k; + } + } + rc.Elapse("done!"); + } + + //add vectors in one batch + { + server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch"); + session.interface()->add_vector_batch(group.id, tensor_list); + rc.Elapse("done!"); + } + + //add binary vectors one by one + { + server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one"); + for (int64_t k = 0; k < count; k++) { + session.interface()->add_binary_vector(group.id, bin_tensor_list.tensor_list[k]); + if(k%1000 == 0) { + CLIENT_LOG_INFO << "add binary vector no." << k; + } + } + rc.Elapse("done!"); + } //add binary vectors in one batch { @@ -134,11 +134,41 @@ void ClientApp::Run(const std::string &config_file) { std::cout << "Search result: " << std::endl; for(auto id : res.id_list) { - std::cout << id << std::endl; + std::cout << "\t" << id << std::endl; } rc.Elapse("done!"); } + //search binary vector + { + server::TimeRecorder rc("Search binary batch top_k"); + VecBinaryTensorList tensor_list; + for(int32_t k = 350; k < 360; k++) { + VecBinaryTensor bin_tensor; + bin_tensor.tensor.resize(dim * sizeof(double)); + double* d_p = new double[dim]; + for (int32_t i = 0; i < dim; i++) { + d_p[i] = (double)(i + k); + } + memcpy(const_cast(bin_tensor.tensor.data()), d_p, dim * sizeof(double)); + tensor_list.tensor_list.emplace_back(bin_tensor); + } + + VecSearchResultList res; + VecTimeRangeList range; + session.interface()->search_binary_vector_batch(res, group.id, 5, tensor_list, range); + + std::cout << "Search binary batch result: " << std::endl; + for(size_t i = 0 ; i < res.result_list.size(); i++) { + std::cout << "No " << i << ":" << std::endl; + for(auto id : res.result_list[i].id_list) { + std::cout << "\t" << id << std::endl; + } + } + + rc.Elapse("done!"); + } + } catch (std::exception& ex) { CLIENT_LOG_ERROR << "request encounter exception: " << ex.what(); }