From c4244fdc6db402f3b0244174ae87779a0a18d50c Mon Sep 17 00:00:00 2001 From: "xj.lin" Date: Sat, 11 May 2019 16:52:16 +0800 Subject: [PATCH] fix search stack overflow Former-commit-id: 39801544686f061a63f4c3f1dec11565164a928d --- cpp/src/db/DBImpl.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index f977222535..4218ffa953 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -107,7 +107,7 @@ Status DBImpl::search(const std::string& group_id, size_t k, size_t nq, using SearchResult = std::pair, std::vector>; std::vector batchresult(nq); // allocate nq cells. - auto cluster = [&](long *nns, float *dis) -> void { + auto cluster = [&](long *nns, float *dis, const int& k) -> void { for (int i = 0; i < nq; ++i) { auto f_begin = batchresult[i].first.cbegin(); auto s_begin = batchresult[i].second.cbegin(); @@ -134,8 +134,10 @@ Status DBImpl::search(const std::string& group_id, size_t k, size_t nq, search_set_size += file_size; LOG(DEBUG) << "Search file_type " << file.file_type << " Of Size: " << file_size << " M"; - index.Search(nq, vectors, k, output_distence, output_ids); - cluster(output_ids, output_distence); // cluster to each query + + int inner_k = index.Count() < k ? index.Count() : k; + index.Search(nq, vectors, inner_k, output_distence, output_ids); + cluster(output_ids, output_distence, inner_k); // cluster to each query memset(output_distence, 0, k * nq * sizeof(float)); memset(output_ids, 0, k * nq * sizeof(long)); } @@ -161,8 +163,11 @@ Status DBImpl::search(const std::string& group_id, size_t k, size_t nq, for (auto &result_pair : batchresult) { auto &dis = result_pair.second; auto &nns = result_pair.first; + topk_cpu(dis, k, output_distence, output_ids); - for (int i = 0; i < k; ++i) { + + int inner_k = dis.size() < k ? dis.size() : k; + for (int i = 0; i < inner_k; ++i) { res.emplace_back(nns[output_ids[i]]); // mapping } results.push_back(res); // append to result list