diff --git a/cpp/CHANGELOG.md b/cpp/CHANGELOG.md index e4c6ce04f6..731ff30630 100644 --- a/cpp/CHANGELOG.md +++ b/cpp/CHANGELOG.md @@ -12,7 +12,8 @@ Please mark all change in change log and use the ticket from JIRA. - MS-57 - Implement index load/search pipeline - MS-56 - Add version information when server is started -- Ms-64 - Different table can have different index type +- MS-64 - Different table can have different index type +- MS-52 - Return search score ## Task diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index dcc397345e..6211c688fb 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -240,7 +240,7 @@ Status DBImpl::QuerySync(const std::string& table_id, size_t k, size_t nq, int inner_k = dis.size() < k ? dis.size() : k; for (int i = 0; i < inner_k; ++i) { - res.emplace_back(nns[output_ids[i]]); // mapping + res.emplace_back(std::make_pair(nns[output_ids[i]], output_distence[i])); // mapping } results.push_back(res); // append to result list res.clear(); @@ -267,6 +267,8 @@ Status DBImpl::QuerySync(const std::string& table_id, size_t k, size_t nq, Status DBImpl::QueryAsync(const std::string& table_id, size_t k, size_t nq, const float* vectors, const meta::DatesT& dates, QueryResults& results) { + + //step 1: get files to search meta::DatePartionedTableFilesSchema files; auto status = pMeta_->FilesToSearch(table_id, dates, files); if (!status.ok()) { return status; } @@ -282,18 +284,15 @@ Status DBImpl::QueryAsync(const std::string& table_id, size_t k, size_t nq, } } + //step 2: put search task to scheduler SearchScheduler& scheduler = SearchScheduler::GetInstance(); scheduler.ScheduleSearchTask(context); context->WaitResult(); + + //step 3: construct results auto& context_result = context->GetResult(); - for(auto& topk_result : context_result) { - QueryResult ids; - for(auto& pair : topk_result) { - ids.push_back(pair.second); - } - results.emplace_back(ids); - } + results.swap(context_result); return Status::OK(); } diff --git a/cpp/src/db/ExecutionEngine.h b/cpp/src/db/ExecutionEngine.h index fe1acd913d..ad4355786f 100644 --- a/cpp/src/db/ExecutionEngine.h +++ b/cpp/src/db/ExecutionEngine.h @@ -32,6 +32,8 @@ public: virtual size_t Size() const = 0; + virtual size_t Dimension() const = 0; + virtual size_t PhysicalSize() const = 0; virtual Status Serialize() = 0; diff --git a/cpp/src/db/FaissExecutionEngine.cpp b/cpp/src/db/FaissExecutionEngine.cpp index a338ddf5cb..b25a3150ed 100644 --- a/cpp/src/db/FaissExecutionEngine.cpp +++ b/cpp/src/db/FaissExecutionEngine.cpp @@ -54,6 +54,10 @@ size_t FaissExecutionEngine::Size() const { return (size_t)(Count() * pIndex_->d)*sizeof(float); } +size_t FaissExecutionEngine::Dimension() const { + return pIndex_->d; +} + size_t FaissExecutionEngine::PhysicalSize() const { return (size_t)(Count() * pIndex_->d)*sizeof(float); } diff --git a/cpp/src/db/FaissExecutionEngine.h b/cpp/src/db/FaissExecutionEngine.h index d1981502b9..e41fe06456 100644 --- a/cpp/src/db/FaissExecutionEngine.h +++ b/cpp/src/db/FaissExecutionEngine.h @@ -38,6 +38,8 @@ public: size_t Size() const override; + size_t Dimension() const override; + size_t PhysicalSize() const override; Status Serialize() override; diff --git a/cpp/src/db/Types.h b/cpp/src/db/Types.h index f9a432fd94..73ecc81fa8 100644 --- a/cpp/src/db/Types.h +++ b/cpp/src/db/Types.h @@ -15,7 +15,7 @@ typedef long IDNumber; typedef IDNumber* IDNumberPtr; typedef std::vector IDNumbers; -typedef std::vector QueryResult; +typedef std::vector> QueryResult; typedef std::vector QueryResults; diff --git a/cpp/src/db/scheduler/SearchContext.h b/cpp/src/db/scheduler/SearchContext.h index ae7327fd68..b212ea34d9 100644 --- a/cpp/src/db/scheduler/SearchContext.h +++ b/cpp/src/db/scheduler/SearchContext.h @@ -31,8 +31,8 @@ public: using Id2IndexMap = std::unordered_map; const Id2IndexMap& GetIndexMap() const { return map_index_files_; } - using Score2IdMap = std::map; - using ResultSet = std::vector; + using Id2ScoreMap = std::vector>; + using ResultSet = std::vector; const ResultSet& GetResult() const { return result_; } ResultSet& GetResult() { return result_; } diff --git a/cpp/src/db/scheduler/SearchTaskQueue.cpp b/cpp/src/db/scheduler/SearchTaskQueue.cpp index 86478477d1..38db5fd7a7 100644 --- a/cpp/src/db/scheduler/SearchTaskQueue.cpp +++ b/cpp/src/db/scheduler/SearchTaskQueue.cpp @@ -19,12 +19,29 @@ void ClusterResult(const std::vector &output_ids, SearchContext::ResultSet &result_set) { result_set.clear(); for (auto i = 0; i < nq; i++) { - SearchContext::Score2IdMap score2id; + SearchContext::Id2ScoreMap id_score; for (auto k = 0; k < topk; k++) { uint64_t index = i * nq + k; - score2id.insert(std::make_pair(output_distence[index], output_ids[index])); + id_score.push_back(std::make_pair(output_ids[index], output_distence[index])); } - result_set.emplace_back(score2id); + result_set.emplace_back(id_score); + } +} + +void MergeResult(SearchContext::Id2ScoreMap &score_src, + SearchContext::Id2ScoreMap &score_target, + uint64_t topk) { + for (auto& pair_src : score_src) { + for (auto iter = score_target.begin(); iter != score_target.end(); ++iter) { + if(pair_src.second > iter->second) { + score_target.insert(iter, pair_src); + } + } + } + + //remove unused items + while (score_target.size() > topk) { + score_target.pop_back(); } } @@ -42,18 +59,39 @@ void TopkResult(SearchContext::ResultSet &result_src, } for (size_t i = 0; i < result_src.size(); i++) { - SearchContext::Score2IdMap &score2id_src = result_src[i]; - SearchContext::Score2IdMap &score2id_target = result_target[i]; - for (auto iter = score2id_src.begin(); iter != score2id_src.end(); ++iter) { - score2id_target.insert(std::make_pair(iter->first, iter->second)); - } - - //remove unused items - while (score2id_target.size() > topk) { - score2id_target.erase(score2id_target.rbegin()->first); - } + SearchContext::Id2ScoreMap &score_src = result_src[i]; + SearchContext::Id2ScoreMap &score_target = result_target[i]; + MergeResult(score_src, score_target, topk); } } + +void CalcScore(uint64_t vector_count, + const float *vectors_data, + uint64_t dimension, + const SearchContext::ResultSet &result_src, + SearchContext::ResultSet &result_target) { + result_target.clear(); + if(result_src.empty()){ + return; + } + + int vec_index = 0; + for(auto& result : result_src) { + const float * vec_data = vectors_data + vec_index*dimension; + double vec_len = 0; + for(uint64_t i = 0; i < dimension; i++) { + vec_len += vec_data[i]*vec_data[i]; + } + vec_index++; + + SearchContext::Id2ScoreMap score_array; + for(auto& pair : result) { + score_array.push_back(std::make_pair(pair.first, (1 - pair.second/vec_len)*100.0)); + } + result_target.emplace_back(score_array); + } +} + } @@ -78,10 +116,12 @@ bool SearchTask::DoSearch() { std::vector output_ids; std::vector output_distence; for(auto& context : search_contexts_) { + //step 1: allocate memory auto inner_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk(); output_ids.resize(inner_k*context->nq()); output_distence.resize(inner_k*context->nq()); + //step 2: search try { index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(), output_ids.data()); @@ -93,11 +133,21 @@ bool SearchTask::DoSearch() { rc.Record("do search"); + //step 3: cluster result SearchContext::ResultSet result_set; ClusterResult(output_ids, output_distence, context->nq(), inner_k, result_set); rc.Record("cluster result"); + + //step 4: pick up topk result TopkResult(result_set, inner_k, context->GetResult()); rc.Record("reduce topk"); + + //step 5: calculate score between 0 ~ 100 + CalcScore(context->nq(), context->vectors(), index_engine_->Dimension(), context->GetResult(), result_set); + context->GetResult().swap(result_set); + rc.Record("calculate score"); + + //step 6: notify to send result to client context->IndexSearchDone(index_id_); } diff --git a/cpp/src/server/MegasearchTask.cpp b/cpp/src/server/MegasearchTask.cpp index 2980deb6fa..7c78b10046 100644 --- a/cpp/src/server/MegasearchTask.cpp +++ b/cpp/src/server/MegasearchTask.cpp @@ -400,9 +400,10 @@ ServerError SearchVectorTask::OnExecute() { const auto& record = record_array_[i]; thrift::TopKQueryResult thrift_topk_result; - for(auto id : result) { + for(auto& pair : result) { thrift::QueryResult thrift_result; - thrift_result.__set_id(id); + thrift_result.__set_id(pair.first); + thrift_result.__set_score(pair.second); thrift_topk_result.query_result_arrays.emplace_back(thrift_result); } diff --git a/cpp/unittest/db/db_tests.cpp b/cpp/unittest/db/db_tests.cpp index 459acf9ab7..c903a7b957 100644 --- a/cpp/unittest/db/db_tests.cpp +++ b/cpp/unittest/db/db_tests.cpp @@ -164,11 +164,11 @@ TEST_F(DBTest, DB_TEST) { ASSERT_STATS(stat); for (auto k=0; k