From 354f68a96feaa2a31a63f136eef47c9bf5eb3425 Mon Sep 17 00:00:00 2001 From: "xj.lin" Date: Sat, 11 May 2019 18:49:30 +0800 Subject: [PATCH] fix TopK bug Former-commit-id: 76e24617b83d49b2a4808a0cd7406edb849e767e --- cpp/src/db/DBImpl.cpp | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index 4218ffa953..3df27ba063 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -147,15 +147,25 @@ Status DBImpl::search(const std::string& group_id, size_t k, size_t nq, const int &k, float *output_distence, long *output_ids) -> void { - std::map inverted_table; + std::map> inverted_table; for (int i = 0; i < input_data.size(); ++i) { - inverted_table[input_data[i]] = i; + if (inverted_table.count(input_data[i]) == 1) { + auto& ori_vec = inverted_table[input_data[i]]; + ori_vec.push_back(i); + } + else { + inverted_table[input_data[i]] = std::vector{i}; + } } int count = 0; - for (auto it = inverted_table.begin(); it != inverted_table.end() && count < k; ++it, ++count) { - output_distence[count] = it->first; - output_ids[count] = it->second; + for (auto &item : inverted_table){ + if (count == k) break; + for (auto &id : item.second){ + if (++count == k) break; + output_distence[count] = item.first; + output_ids[count] = id; + } } }; auto cluster_topk = [&]() -> void {