From 759a951e4bb4f86b78d46e2c76744fdef05f759f Mon Sep 17 00:00:00 2001 From: "xiaojun.lin" Date: Mon, 25 Nov 2019 15:23:50 +0800 Subject: [PATCH] fix --- core/build.sh | 2 +- .../knowhere/index/vector_index/nsg/NSG.cpp | 60 +++++++++++-------- 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/core/build.sh b/core/build.sh index 3afb5d1b37..badcc5a032 100755 --- a/core/build.sh +++ b/core/build.sh @@ -56,7 +56,7 @@ while getopts "p:d:t:f:ulrcgjhxzme" arg; do USE_JFROG_CACHE="ON" ;; x) - CUSTOMIZATION="OFF" # force use ori faiss + CUSTOMIZATION="ON" ;; g) GPU_VERSION="ON" diff --git a/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSG.cpp b/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSG.cpp index bdf538c204..002e160562 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSG.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/nsg/NSG.cpp @@ -718,45 +718,55 @@ NsgIndex::Search(const float* query, const unsigned& nq, const unsigned& dim, co int64_t* ids, SearchParams& params) { std::vector> resset(nq); + params.search_length = k; TimeRecorder rc("search"); - if (nq == 1) { + // TODO(linxj): when to use openmp + if (nq <= 4) { GetNeighbors(query, resset[0], nsg, ¶ms); } else { -//#pragma omp parallel for schedule(dynamic, 50) #pragma omp parallel for for (unsigned int i = 0; i < nq; ++i) { - // TODO(linxj): when to use openmp auto single_query = query + i * dim; GetNeighbors(single_query, resset[i], nsg, ¶ms); } } - rc.ElapseFromBegin("cost"); - + rc.RecordSection("cost"); for (unsigned int i = 0; i < nq; ++i) { - for (unsigned int j = 0; j < k; ++j) { - // ids[i * k + j] = resset[i][j].id; - - // Fix(linxj): bug, reset[i][j] out of range - ids[i * k + j] = ids_[resset[i][j].id]; - dist[i * k + j] = resset[i][j].distance; + int64_t var = resset[i].size() - k; + if (var >= 0) { + for (unsigned int j = 0; j < k; ++j) { + ids[i * k + j] = ids_[resset[i][j].id]; + dist[i * k + j] = resset[i][j].distance; + } + } + else { + for (unsigned int j = 0; j < resset[i].size(); ++j) { + ids[i * k + j] = ids_[resset[i][j].id]; + dist[i * k + j] = resset[i][j].distance; + } + for (unsigned int j = resset[i].size(); j < k; ++j) { + ids[i * k + j] = -1; + dist[i * k + j] = -1; + } } } + rc.RecordSection("merge"); - //>> Debug: test single insert - // int x_0 = resset[0].size(); - // for (int l = 0; l < resset[0].size(); ++l) { - // resset[0].pop_back(); - //} - // resset.clear(); +//>> Debug: test single insert +// int x_0 = resset[0].size(); +// for (int l = 0; l < resset[0].size(); ++l) { +// resset[0].pop_back(); +//} +// resset.clear(); - // ProfilerStart("xx.prof"); - // std::vector resset; - // GetNeighbors(query, resset, nsg, ¶ms); - // for (int i = 0; i < k; ++i) { - // ids[i] = resset[i].id; - // dist[i] = resset[i].distance; - //} - // ProfilerStop(); +// ProfilerStart("xx.prof"); +// std::vector resset; +// GetNeighbors(query, resset, nsg, ¶ms); +// for (int i = 0; i < k; ++i) { +// ids[i] = resset[i].id; +// dist[i] = resset[i].distance; +//} +// ProfilerStop(); } void