mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
improve rhnsw (#5059)
Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
This commit is contained in:
parent
5a89b2668e
commit
a3e4339027
@ -266,8 +266,6 @@ void IndexRHNSW::search (idx_t n, const float *x, idx_t k,
|
||||
float * simi = distances + i * k;
|
||||
dis->set_query(x + i * d);
|
||||
|
||||
maxheap_heapify (k, simi, idxi);
|
||||
|
||||
if (STATISTICS_LEVEL == 3)
|
||||
hnsw.searchKnn(*dis, k, idxi, simi, query_stats[i], bitset);
|
||||
else {
|
||||
@ -275,8 +273,6 @@ void IndexRHNSW::search (idx_t n, const float *x, idx_t k,
|
||||
hnsw.searchKnn(*dis, k, idxi, simi, dummy_stat, bitset);
|
||||
}
|
||||
|
||||
maxheap_reorder (k, simi, idxi);
|
||||
|
||||
if (reconstruct_from_neighbors &&
|
||||
reconstruct_from_neighbors->k_reorder != 0) {
|
||||
int k_reorder = reconstruct_from_neighbors->k_reorder;
|
||||
@ -634,8 +630,8 @@ IndexRHNSWSQ::IndexRHNSWSQ(int d, QuantizerType qtype, int M,
|
||||
MetricType metric):
|
||||
IndexRHNSW (new IndexScalarQuantizer (d, qtype, metric), M)
|
||||
{
|
||||
is_trained = false;
|
||||
own_fields = true;
|
||||
is_trained = false;
|
||||
}
|
||||
|
||||
// Default constructor: performs no work of its own; the base class and all
// members are default-initialised.
IndexRHNSWSQ::IndexRHNSWSQ() = default;
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
#include <faiss/impl/RHNSW.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <faiss/impl/AuxIndexStructures.h>
|
||||
|
||||
@ -71,63 +72,46 @@ void RHNSW::reset() {
|
||||
int RHNSW::prepare_level_tab(size_t n, bool preset_levels)
|
||||
{
|
||||
size_t n0 = levels.size();
|
||||
size_t n1 = n0 + n;
|
||||
|
||||
std::vector<int> level_stats(n);
|
||||
if (preset_levels) {
|
||||
FAISS_ASSERT (n0 + n == levels.size());
|
||||
FAISS_ASSERT (n1 == levels.size());
|
||||
} else {
|
||||
FAISS_ASSERT (n0 == levels.size());
|
||||
levels.resize(n1);
|
||||
for (int i = 0; i < n; i++) {
|
||||
int pt_level = random_level(level_constant);
|
||||
levels.push_back(pt_level);
|
||||
levels[n0 + i] = pt_level;
|
||||
}
|
||||
}
|
||||
|
||||
char *level0_links_new = (char*)malloc((n0 + n) * level0_link_size);
|
||||
if (level0_links_new == nullptr) {
|
||||
throw std::runtime_error("No enough memory 4 level0_links!");
|
||||
}
|
||||
memset(level0_links_new, 0, (n0 + n) * level0_link_size);
|
||||
if (level0_links) {
|
||||
memcpy(level0_links_new, level0_links, n0 * level0_link_size);
|
||||
free(level0_links);
|
||||
}
|
||||
char *level0_links_new = (char *) realloc(level0_links, level0_link_size * n1);
|
||||
if (level0_links_new == nullptr)
|
||||
throw std::runtime_error("No enough memory 4 level0_links!");
|
||||
level0_links = level0_links_new;
|
||||
memset(level0_links + n0 * level0_link_size, 0, n * level0_link_size);
|
||||
|
||||
char **linkLists_new = (char **)malloc(sizeof(void*) * (n0 + n));
|
||||
if (linkLists_new == nullptr) {
|
||||
throw std::runtime_error("No enough memory 4 level0_links_new!");
|
||||
}
|
||||
if (linkLists) {
|
||||
memcpy(linkLists_new, linkLists, n0 * sizeof(void*));
|
||||
free(linkLists);
|
||||
}
|
||||
char **linkLists_new = (char **) realloc(linkLists, sizeof(void *) * n1);
|
||||
if (linkLists_new == nullptr)
|
||||
throw std::runtime_error("No enough memory 4 level0_links_new!");
|
||||
linkLists = linkLists_new;
|
||||
memset(linkLists + n0 * sizeof(void *), 0, n * sizeof(void *));
|
||||
|
||||
int debug_space = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
int pt_level = levels[i + n0];
|
||||
if (pt_level > max_level) max_level = pt_level;
|
||||
if (pt_level) {
|
||||
linkLists[n0 + i] = (char*) malloc(link_size * pt_level + 1);
|
||||
linkLists[n0 + i] = (char*) malloc(link_size * pt_level);
|
||||
if (linkLists[n0 + i] == nullptr) {
|
||||
throw std::runtime_error("No enough memory 4 linkLists!");
|
||||
}
|
||||
memset(linkLists[n0 + i], 0, link_size * pt_level + 1);
|
||||
memset(linkLists[n0 + i], 0, link_size * pt_level);
|
||||
}
|
||||
if (max_level >= level_stats.size()) {
|
||||
level_stats.resize(max_level + 1);
|
||||
}
|
||||
level_stats[pt_level] ++;
|
||||
}
|
||||
|
||||
// printf("level stats:\n");
|
||||
// for (int i = 0; i <= max_level; ++ i)
|
||||
// printf("level %d: %d points\n", i, level_stats[i]);
|
||||
// printf("\n");
|
||||
std::vector<std::mutex>(n0 + n).swap(link_list_locks);
|
||||
if (visited_list_pool) delete visited_list_pool;
|
||||
visited_list_pool = new VisitedListPool(1, n0 + n);
|
||||
visited_list_pool = new VisitedListPool(1, n1);
|
||||
|
||||
return max_level;
|
||||
}
|
||||
@ -184,9 +168,8 @@ void RHNSW::addPoint(DistanceComputer& ptdis, int pt_level, int pt_id) {
|
||||
if (lev > maxlevel_copy || lev < 0)
|
||||
throw std::runtime_error("Level error");
|
||||
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates = search_layer(ptdis, pt_id, currObj, lev);
|
||||
currObj = top_candidates.top().second;
|
||||
make_connection(ptdis, pt_id, top_candidates, lev);
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates = search_layer(ptdis, currObj, lev);
|
||||
currObj = make_connection(ptdis, pt_id, top_candidates, lev);
|
||||
}
|
||||
} else {
|
||||
entry_point = 0;
|
||||
@ -202,7 +185,6 @@ void RHNSW::addPoint(DistanceComputer& ptdis, int pt_level, int pt_id) {
|
||||
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst>
|
||||
RHNSW::search_layer(DistanceComputer& ptdis,
|
||||
storage_idx_t pt_id,
|
||||
storage_idx_t nearest,
|
||||
int level) {
|
||||
VisitedList *vl = visited_list_pool->getFreeVisitedList();
|
||||
@ -300,7 +282,7 @@ RHNSW::search_base_layer(DistanceComputer& ptdis,
|
||||
return top_candidates;
|
||||
}
|
||||
|
||||
void
|
||||
int
|
||||
RHNSW::make_connection(DistanceComputer& ptdis,
|
||||
storage_idx_t pt_id,
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
|
||||
@ -312,6 +294,8 @@ RHNSW::make_connection(DistanceComputer& ptdis,
|
||||
if (selectedNeighborsNum > maxM)
|
||||
throw std::runtime_error("Wrong size of candidates returned by prune_neighbors!");
|
||||
|
||||
int next_closest_entry_point = selectedNeighbors[0];
|
||||
|
||||
int *cur_link = get_neighbor_link(pt_id, level);
|
||||
if (*cur_link)
|
||||
throw std::runtime_error("The newly inserted element should have blank link");
|
||||
@ -352,41 +336,43 @@ RHNSW::make_connection(DistanceComputer& ptdis,
|
||||
}
|
||||
|
||||
free(selectedNeighbors);
|
||||
return next_closest_entry_point;
|
||||
}
|
||||
|
||||
/**
 * Select at most `maxM` neighbours out of the candidate queue.
 *
 * When there are fewer than `maxM` candidates they are all kept. Otherwise
 * the candidates are scanned in reverse pop order and one is kept only if no
 * already-kept neighbour is closer to it (by `symmetric_dis`) than its own
 * stored distance (`Node::first`).
 *
 * @param ptdis    distance computer used for candidate-to-candidate distances
 * @param cand     candidate queue; emptied by this call when any candidate
 *                 is consumed
 * @param maxM     upper bound on the number of neighbours written to `ret`
 * @param ret      output array of selected ids; presumably has room for
 *                 `maxM` entries -- verify against the caller
 * @param ret_len  output: number of ids written to `ret`
 */
void RHNSW::prune_neighbors(DistanceComputer& ptdis,
                            std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
                            const int maxM, int *ret, int &ret_len) {
    if (cand.size() < maxM) {
        // Few enough candidates: keep them all. Filling from the back makes
        // ret[] the reverse of the queue's pop order (presumably nearest
        // first -- depends on CompareByFirst, which is defined elsewhere).
        ret_len = cand.size();
        for (int i = static_cast<int>(cand.size()) - 1; i >= 0; i--) {
            ret[i] = cand.top().second;
            cand.pop();
        }
        return;
    }

    // From here on the result is built incrementally, so always start from a
    // defined length -- including when maxM <= 0 and nothing can be selected.
    ret_len = 0;
    if (maxM <= 0) {
        return;
    }

    // Drain the queue into a vector holding the reverse of the pop order.
    std::vector<Node> queue_closest;
    queue_closest.resize(cand.size());
    for (int i = static_cast<int>(cand.size()) - 1; i >= 0; i--) {
        queue_closest[i] = cand.top();
        cand.pop();
    }

    for (auto &curr: queue_closest) {
        // Reject the candidate if some already-selected neighbour is closer
        // to it than its own stored distance.
        bool good = true;
        for (auto i = 0; i < ret_len; ++ i) {
            float cur_dist = ptdis.symmetric_dis(curr.second, ret[i]);
            if (cur_dist < curr.first) {
                good = false;
                break;
            }
        }
        if (good) {
            ret[ret_len++] = (curr.second);
            // Stop at the requested limit `maxM`, not at the member `M`.
            if (ret_len >= maxM) {
                break;
            }
        }
    }
}
|
||||
@ -424,13 +410,18 @@ void RHNSW::searchKnn(DistanceComputer& qdis, int k,
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates = search_base_layer(qdis, ep, std::max(efSearch, k), dist, bitset);
|
||||
while (top_candidates.size() > k)
|
||||
top_candidates.pop();
|
||||
int i = 0;
|
||||
int rst_num = top_candidates.size();
|
||||
int i = rst_num - 1;
|
||||
while (!top_candidates.empty()) {
|
||||
I[i] = top_candidates.top().second;
|
||||
D[i] = top_candidates.top().first;
|
||||
i ++;
|
||||
i--;
|
||||
top_candidates.pop();
|
||||
}
|
||||
for (;rst_num < k; rst_num++) {
|
||||
I[rst_num] = -1;
|
||||
D[rst_num] = 1.0/0.0;
|
||||
}
|
||||
}
|
||||
|
||||
size_t RHNSW::cal_size() {
|
||||
|
||||
@ -222,7 +222,6 @@ struct RHNSW {
|
||||
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst>
|
||||
search_layer (DistanceComputer& ptdis,
|
||||
storage_idx_t pt_id,
|
||||
storage_idx_t nearest,
|
||||
int level);
|
||||
|
||||
@ -233,10 +232,10 @@ struct RHNSW {
|
||||
float d_nearest,
|
||||
const BitsetView bitset = nullptr) const;
|
||||
|
||||
void make_connection(DistanceComputer& ptdis,
|
||||
storage_idx_t pt_id,
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
|
||||
int level);
|
||||
int make_connection(DistanceComputer& ptdis,
|
||||
storage_idx_t pt_id,
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
|
||||
int level);
|
||||
|
||||
void prune_neighbors(DistanceComputer& ptdis,
|
||||
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
|
||||
|
||||
@ -477,8 +477,8 @@ static void read_RHNSW (RHNSW *rhnsw, IOReader *f) {
|
||||
rhnsw->linkLists = (char**) malloc(ntotal * sizeof(void*));
|
||||
for (auto i = 0; i < ntotal; ++ i) {
|
||||
if (rhnsw->levels[i]) {
|
||||
rhnsw->linkLists[i] = (char*)malloc(rhnsw->link_size * rhnsw->levels[i] + 1);
|
||||
READANDCHECK( rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i] + 1);
|
||||
rhnsw->linkLists[i] = (char*)malloc(rhnsw->link_size * rhnsw->levels[i]);
|
||||
READANDCHECK( rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i]);
|
||||
rhnsw->level_stats[rhnsw->levels[i]] ++;
|
||||
} else {
|
||||
rhnsw->level_stats[0] ++;
|
||||
|
||||
@ -379,7 +379,7 @@ static void write_RHNSW (const RHNSW *rhnsw, IOWriter *f) {
|
||||
WRITEANDCHECK (rhnsw->level0_links, rhnsw->level0_link_size * rhnsw->levels.size());
|
||||
for (auto i = 0; i < rhnsw->levels.size(); ++ i) {
|
||||
if (rhnsw->levels[i])
|
||||
WRITEANDCHECK (rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i] + 1);
|
||||
WRITEANDCHECK (rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user