mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
Update NSG (#1744)
* enable IP and fix crash Signed-off-by: Nicky <nicky.xj.lin@gmail.com> * update. Signed-off-by: xiaojun.lin <xiaojun.lin@zilliz.com> * lint pass Signed-off-by: xiaojun.lin <xiaojun.lin@zilliz.com>
This commit is contained in:
parent
e865e9c893
commit
cda57cf77f
@ -13,6 +13,8 @@ Please mark all change in change log and use the issue from GitHub
|
||||
- \#1663 PQ index parameter 'm' validation
|
||||
- \#1686 API search_in_files cannot work correctly when vectors is stored in certain non-default partition
|
||||
- \#1689 Fix SQ8H search fail on SIFT-1B dataset
|
||||
- \#1667 Create index failed with type: rnsg if metric_type is IP
|
||||
- \#1708 NSG search crashed
|
||||
- \#1724 Remove unused unittests
|
||||
- \#1734 Opentracing for combined search request
|
||||
|
||||
|
||||
@ -211,7 +211,7 @@ NSGConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
|
||||
static int64_t MAX_OUT_DEGREE = 300;
|
||||
static int64_t MIN_CANDIDATE_POOL_SIZE = 50;
|
||||
static int64_t MAX_CANDIDATE_POOL_SIZE = 1000;
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::L2};
|
||||
static std::vector<std::string> METRICS{knowhere::Metric::L2, knowhere::Metric::IP};
|
||||
|
||||
CheckStrByValues(knowhere::Metric::TYPE, METRICS);
|
||||
CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS);
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <fiu-local.h>
|
||||
#include <string>
|
||||
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Timer.h"
|
||||
@ -139,7 +140,7 @@ NSG::Train(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
auto p_ids = dataset_ptr->Get<const int64_t*>(meta::IDS);
|
||||
|
||||
GETTENSOR(dataset_ptr)
|
||||
index_ = std::make_shared<impl::NsgIndex>(dim, rows);
|
||||
index_ = std::make_shared<impl::NsgIndex>(dim, rows, config[Metric::TYPE].get<std::string>());
|
||||
index_->SetKnnGraph(knng);
|
||||
index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params);
|
||||
}
|
||||
|
||||
@ -9,6 +9,8 @@
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "knowhere/index/vector_index/impl/nsg/NSG.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
@ -20,7 +22,6 @@
|
||||
#include "knowhere/common/Exception.h"
|
||||
#include "knowhere/common/Log.h"
|
||||
#include "knowhere/common/Timer.h"
|
||||
#include "knowhere/index/vector_index/impl/nsg/NSG.h"
|
||||
#include "knowhere/index/vector_index/impl/nsg/NSGHelper.h"
|
||||
|
||||
namespace milvus {
|
||||
@ -31,14 +32,11 @@ unsigned int seed = 100;
|
||||
|
||||
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, std::string metric)
|
||||
: dimension(dimension), ntotal(n), metric_type(metric) {
|
||||
// switch (metric) {
|
||||
// case METRICTYPE::L2:
|
||||
// break;
|
||||
// case METRICTYPE::IP:
|
||||
// distance_ = new DistanceIP;
|
||||
// break;
|
||||
// }
|
||||
distance_ = new DistanceL2;
|
||||
if (metric == knowhere::Metric::L2) {
|
||||
distance_ = new DistanceL2;
|
||||
} else if (metric == knowhere::Metric::IP) {
|
||||
distance_ = new DistanceIP;
|
||||
}
|
||||
}
|
||||
|
||||
NsgIndex::~NsgIndex() {
|
||||
@ -697,139 +695,176 @@ NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<>& has_linked, int64_t& root
|
||||
nsg[root].push_back(id);
|
||||
}
|
||||
|
||||
void
|
||||
NsgIndex::GetNeighbors(const float* query, node_t* I, float* D, SearchParams* params) {
|
||||
size_t buffer_size = params ? params->search_length : search_length;
|
||||
// void
|
||||
// NsgIndex::GetNeighbors(const float* query, node_t* I, float* D, SearchParams* params) {
|
||||
// size_t buffer_size = params ? params->search_length : search_length;
|
||||
|
||||
if (buffer_size > ntotal) {
|
||||
KNOWHERE_THROW_MSG("Search Error, search_length > ntotal");
|
||||
}
|
||||
// if (buffer_size > ntotal) {
|
||||
// KNOWHERE_THROW_MSG("Search Error, search_length > ntotal");
|
||||
// }
|
||||
|
||||
std::vector<Neighbor> resset(buffer_size);
|
||||
std::vector<node_t> init_ids(buffer_size);
|
||||
boost::dynamic_bitset<> has_calculated_dist{ntotal, 0};
|
||||
// std::vector<Neighbor> resset(buffer_size);
|
||||
// std::vector<node_t> init_ids(buffer_size);
|
||||
// boost::dynamic_bitset<> has_calculated_dist{ntotal, 0};
|
||||
|
||||
{
|
||||
/*
|
||||
* copy navigation-point neighbor, pick random node if less than buffer size
|
||||
*/
|
||||
size_t count = 0;
|
||||
// {
|
||||
// /*
|
||||
// * copy navigation-point neighbor, pick random node if less than buffer size
|
||||
// */
|
||||
// size_t count = 0;
|
||||
|
||||
// Get all neighbors
|
||||
for (size_t i = 0; i < init_ids.size() && i < nsg[navigation_point].size(); ++i) {
|
||||
init_ids[i] = nsg[navigation_point][i];
|
||||
has_calculated_dist[init_ids[i]] = true;
|
||||
++count;
|
||||
}
|
||||
while (count < buffer_size) {
|
||||
node_t id = rand_r(&seed) % ntotal;
|
||||
if (has_calculated_dist[id])
|
||||
continue; // duplicate id
|
||||
init_ids[count] = id;
|
||||
++count;
|
||||
has_calculated_dist[id] = true;
|
||||
}
|
||||
}
|
||||
// // Get all neighbors
|
||||
// for (size_t i = 0; i < init_ids.size() && i < nsg[navigation_point].size(); ++i) {
|
||||
// init_ids[i] = nsg[navigation_point][i];
|
||||
// has_calculated_dist[init_ids[i]] = true;
|
||||
// ++count;
|
||||
// }
|
||||
// while (count < buffer_size) {
|
||||
// node_t id = rand_r(&seed) % ntotal;
|
||||
// if (has_calculated_dist[id])
|
||||
// continue; // duplicate id
|
||||
// init_ids[count] = id;
|
||||
// ++count;
|
||||
// has_calculated_dist[id] = true;
|
||||
// }
|
||||
// }
|
||||
|
||||
{
|
||||
// init resset and sort by distance
|
||||
for (size_t i = 0; i < init_ids.size(); ++i) {
|
||||
node_t id = init_ids[i];
|
||||
// {
|
||||
// // init resset and sort by distance
|
||||
// for (size_t i = 0; i < init_ids.size(); ++i) {
|
||||
// node_t id = init_ids[i];
|
||||
|
||||
if (id >= static_cast<node_t>(ntotal)) {
|
||||
KNOWHERE_THROW_MSG("Search Error, id > ntotal");
|
||||
}
|
||||
// if (id >= static_cast<node_t>(ntotal)) {
|
||||
// KNOWHERE_THROW_MSG("Search Error, id > ntotal");
|
||||
// }
|
||||
|
||||
float dist = distance_->Compare(ori_data_ + id * dimension, query, dimension);
|
||||
resset[i] = Neighbor(id, dist, false);
|
||||
}
|
||||
std::sort(resset.begin(), resset.end()); // sort by distance
|
||||
// float dist = distance_->Compare(ori_data_ + id * dimension, query, dimension);
|
||||
// resset[i] = Neighbor(id, dist, false);
|
||||
// }
|
||||
// std::sort(resset.begin(), resset.end()); // sort by distance
|
||||
|
||||
// search nearest neighbor
|
||||
size_t cursor = 0;
|
||||
while (cursor < buffer_size) {
|
||||
size_t nearest_updated_pos = buffer_size;
|
||||
// // search nearest neighbor
|
||||
// size_t cursor = 0;
|
||||
// while (cursor < buffer_size) {
|
||||
// size_t nearest_updated_pos = buffer_size;
|
||||
|
||||
if (!resset[cursor].has_explored) {
|
||||
resset[cursor].has_explored = true;
|
||||
// if (!resset[cursor].has_explored) {
|
||||
// resset[cursor].has_explored = true;
|
||||
|
||||
node_t start_pos = resset[cursor].id;
|
||||
auto& wait_for_search_node_vec = nsg[start_pos];
|
||||
for (size_t i = 0; i < wait_for_search_node_vec.size(); ++i) {
|
||||
node_t id = wait_for_search_node_vec[i];
|
||||
if (has_calculated_dist[id])
|
||||
continue;
|
||||
has_calculated_dist[id] = true;
|
||||
// node_t start_pos = resset[cursor].id;
|
||||
// auto& wait_for_search_node_vec = nsg[start_pos];
|
||||
// for (size_t i = 0; i < wait_for_search_node_vec.size(); ++i) {
|
||||
// node_t id = wait_for_search_node_vec[i];
|
||||
// if (has_calculated_dist[id])
|
||||
// continue;
|
||||
// has_calculated_dist[id] = true;
|
||||
|
||||
float dist = distance_->Compare(query, ori_data_ + dimension * id, dimension);
|
||||
// float dist = distance_->Compare(query, ori_data_ + dimension * id, dimension);
|
||||
|
||||
if (dist >= resset[buffer_size - 1].distance)
|
||||
continue;
|
||||
// if (dist >= resset[buffer_size - 1].distance)
|
||||
// continue;
|
||||
|
||||
//// difference from other GetNeighbors
|
||||
Neighbor nn(id, dist, false);
|
||||
///////////////////////////////////////
|
||||
// //// difference from other GetNeighbors
|
||||
// Neighbor nn(id, dist, false);
|
||||
// ///////////////////////////////////////
|
||||
|
||||
size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node
|
||||
if (pos < nearest_updated_pos)
|
||||
nearest_updated_pos = pos;
|
||||
// size_t pos = InsertIntoPool(resset.data(), buffer_size, nn); // replace with a closer node
|
||||
// if (pos < nearest_updated_pos)
|
||||
// nearest_updated_pos = pos;
|
||||
|
||||
//>> Debug code
|
||||
/////
|
||||
// std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " <<
|
||||
// nearest_updated_pos << std::endl;
|
||||
/////
|
||||
// //>> Debug code
|
||||
// /////
|
||||
// // std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " <<
|
||||
// // nearest_updated_pos << std::endl;
|
||||
// /////
|
||||
|
||||
// trick: avoid search query search_length < init_ids.size() ...
|
||||
if (buffer_size + 1 < resset.size())
|
||||
++buffer_size;
|
||||
}
|
||||
}
|
||||
if (cursor >= nearest_updated_pos) {
|
||||
cursor = nearest_updated_pos; // re-search from new pos
|
||||
} else {
|
||||
++cursor;
|
||||
}
|
||||
}
|
||||
}
|
||||
// // trick: avoid search query search_length < init_ids.size() ...
|
||||
// if (buffer_size + 1 < resset.size())
|
||||
// ++buffer_size;
|
||||
// }
|
||||
// }
|
||||
// if (cursor >= nearest_updated_pos) {
|
||||
// cursor = nearest_updated_pos; // re-search from new pos
|
||||
// } else {
|
||||
// ++cursor;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
if ((resset.size() - params->k) >= 0) {
|
||||
for (size_t i = 0; i < params->k; ++i) {
|
||||
I[i] = resset[i].id;
|
||||
D[i] = resset[i].distance;
|
||||
}
|
||||
} else {
|
||||
size_t i = 0;
|
||||
for (; i < resset.size(); ++i) {
|
||||
I[i] = resset[i].id;
|
||||
D[i] = resset[i].distance;
|
||||
}
|
||||
for (; i < params->k; ++i) {
|
||||
I[i] = -1;
|
||||
D[i] = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
// if ((resset.size() - params->k) >= 0) {
|
||||
// for (size_t i = 0; i < params->k; ++i) {
|
||||
// I[i] = resset[i].id;
|
||||
// D[i] = resset[i].distance;
|
||||
// }
|
||||
// } else {
|
||||
// size_t i = 0;
|
||||
// for (; i < resset.size(); ++i) {
|
||||
// I[i] = resset[i].id;
|
||||
// D[i] = resset[i].distance;
|
||||
// }
|
||||
// for (; i < params->k; ++i) {
|
||||
// I[i] = -1;
|
||||
// D[i] = -1;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// void
|
||||
// NsgIndex::Search(const float* query, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist,
|
||||
// int64_t* ids, SearchParams& params) {
|
||||
// // if (k >= 45) {
|
||||
// // params.search_length = k;
|
||||
// // }
|
||||
|
||||
// TimeRecorder rc("nsgsearch", 1);
|
||||
|
||||
// if (nq == 1) {
|
||||
// GetNeighbors(query, ids, dist, ¶ms);
|
||||
// } else {
|
||||
// #pragma omp parallel for
|
||||
// for (unsigned int i = 0; i < nq; ++i) {
|
||||
// const float* single_query = query + i * dim;
|
||||
// GetNeighbors(single_query, ids + i * k, dist + i * k, ¶ms);
|
||||
// }
|
||||
// }
|
||||
// rc.ElapseFromBegin("seach finish");
|
||||
// }
|
||||
|
||||
void
|
||||
NsgIndex::Search(const float* query, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist,
|
||||
int64_t* ids, SearchParams& params) {
|
||||
// if (k >= 45) {
|
||||
// params.search_length = k;
|
||||
// }
|
||||
|
||||
TimeRecorder rc("nsgsearch", 1);
|
||||
std::vector<std::vector<Neighbor>> resset(nq);
|
||||
|
||||
TimeRecorder rc("NsgIndex::search", 1);
|
||||
if (nq == 1) {
|
||||
GetNeighbors(query, ids, dist, ¶ms);
|
||||
GetNeighbors(query, resset[0], nsg, ¶ms);
|
||||
} else {
|
||||
#pragma omp parallel for
|
||||
for (unsigned int i = 0; i < nq; ++i) {
|
||||
const float* single_query = query + i * dim;
|
||||
GetNeighbors(single_query, ids + i * k, dist + i * k, ¶ms);
|
||||
GetNeighbors(single_query, resset[i], nsg, ¶ms);
|
||||
}
|
||||
}
|
||||
rc.ElapseFromBegin("seach finish");
|
||||
rc.RecordSection("search");
|
||||
for (unsigned int i = 0; i < nq; ++i) {
|
||||
int64_t var = resset[i].size() - k;
|
||||
if (var >= 0) {
|
||||
for (unsigned int j = 0; j < k; ++j) {
|
||||
ids[i * k + j] = ids_[resset[i][j].id];
|
||||
dist[i * k + j] = resset[i][j].distance;
|
||||
}
|
||||
} else {
|
||||
for (unsigned int j = 0; j < resset[i].size(); ++j) {
|
||||
ids[i * k + j] = ids_[resset[i][j].id];
|
||||
dist[i * k + j] = resset[i][j].distance;
|
||||
}
|
||||
for (unsigned int j = resset[i].size(); j < k; ++j) {
|
||||
ids[i * k + j] = -1;
|
||||
dist[i * k + j] = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
rc.RecordSection("merge");
|
||||
}
|
||||
|
||||
void
|
||||
|
||||
@ -11,16 +11,16 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <boost/dynamic_bitset.hpp>
|
||||
#include <cstddef>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <boost/dynamic_bitset.hpp>
|
||||
|
||||
#include "Distance.h"
|
||||
#include "Neighbor.h"
|
||||
#include "knowhere/common/Config.h"
|
||||
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace knowhere {
|
||||
@ -65,7 +65,7 @@ class NsgIndex {
|
||||
size_t out_degree;
|
||||
|
||||
public:
|
||||
explicit NsgIndex(const size_t& dimension, const size_t& n, std::string metric = "L2");
|
||||
explicit NsgIndex(const size_t& dimension, const size_t& n, std::string metric = knowhere::Metric::L2);
|
||||
|
||||
NsgIndex() = default;
|
||||
|
||||
@ -111,9 +111,9 @@ class NsgIndex {
|
||||
void
|
||||
GetNeighbors(const float* query, std::vector<Neighbor>& resset, Graph& graph, SearchParams* param = nullptr);
|
||||
|
||||
// used by search
|
||||
void
|
||||
GetNeighbors(const float* query, node_t* I, float* D, SearchParams* params);
|
||||
// only for search
|
||||
// void
|
||||
// GetNeighbors(const float* query, node_t* I, float* D, SearchParams* params);
|
||||
|
||||
void
|
||||
Link();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user