improve rhnsw (#5059)

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
This commit is contained in:
shengjun.li 2021-04-27 14:17:52 +08:00 committed by GitHub
parent 5a89b2668e
commit a3e4339027
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 77 deletions

View File

@ -266,8 +266,6 @@ void IndexRHNSW::search (idx_t n, const float *x, idx_t k,
float * simi = distances + i * k;
dis->set_query(x + i * d);
maxheap_heapify (k, simi, idxi);
if (STATISTICS_LEVEL == 3)
hnsw.searchKnn(*dis, k, idxi, simi, query_stats[i], bitset);
else {
@ -275,8 +273,6 @@ void IndexRHNSW::search (idx_t n, const float *x, idx_t k,
hnsw.searchKnn(*dis, k, idxi, simi, dummy_stat, bitset);
}
maxheap_reorder (k, simi, idxi);
if (reconstruct_from_neighbors &&
reconstruct_from_neighbors->k_reorder != 0) {
int k_reorder = reconstruct_from_neighbors->k_reorder;
@ -634,8 +630,8 @@ IndexRHNSWSQ::IndexRHNSWSQ(int d, QuantizerType qtype, int M,
MetricType metric):
IndexRHNSW (new IndexScalarQuantizer (d, qtype, metric), M)
{
is_trained = false;
own_fields = true;
is_trained = false;
}
IndexRHNSWSQ::IndexRHNSWSQ() {}

View File

@ -10,6 +10,7 @@
#include <faiss/impl/RHNSW.h>
#include <string>
#include <vector>
#include <faiss/impl/AuxIndexStructures.h>
@ -71,63 +72,46 @@ void RHNSW::reset() {
int RHNSW::prepare_level_tab(size_t n, bool preset_levels)
{
size_t n0 = levels.size();
size_t n1 = n0 + n;
std::vector<int> level_stats(n);
if (preset_levels) {
FAISS_ASSERT (n0 + n == levels.size());
FAISS_ASSERT (n1 == levels.size());
} else {
FAISS_ASSERT (n0 == levels.size());
levels.resize(n1);
for (int i = 0; i < n; i++) {
int pt_level = random_level(level_constant);
levels.push_back(pt_level);
levels[n0 + i] = pt_level;
}
}
char *level0_links_new = (char*)malloc((n0 + n) * level0_link_size);
if (level0_links_new == nullptr) {
throw std::runtime_error("No enough memory 4 level0_links!");
}
memset(level0_links_new, 0, (n0 + n) * level0_link_size);
if (level0_links) {
memcpy(level0_links_new, level0_links, n0 * level0_link_size);
free(level0_links);
}
char *level0_links_new = (char *) realloc(level0_links, level0_link_size * n1);
if (level0_links_new == nullptr)
throw std::runtime_error("No enough memory 4 level0_links!");
level0_links = level0_links_new;
memset(level0_links + n0 * level0_link_size, 0, n * level0_link_size);
char **linkLists_new = (char **)malloc(sizeof(void*) * (n0 + n));
if (linkLists_new == nullptr) {
throw std::runtime_error("No enough memory 4 level0_links_new!");
}
if (linkLists) {
memcpy(linkLists_new, linkLists, n0 * sizeof(void*));
free(linkLists);
}
char **linkLists_new = (char **) realloc(linkLists, sizeof(void *) * n1);
if (linkLists_new == nullptr)
throw std::runtime_error("No enough memory 4 level0_links_new!");
linkLists = linkLists_new;
memset(linkLists + n0 * sizeof(void *), 0, n * sizeof(void *));
int debug_space = 0;
for (int i = 0; i < n; i++) {
int pt_level = levels[i + n0];
if (pt_level > max_level) max_level = pt_level;
if (pt_level) {
linkLists[n0 + i] = (char*) malloc(link_size * pt_level + 1);
linkLists[n0 + i] = (char*) malloc(link_size * pt_level);
if (linkLists[n0 + i] == nullptr) {
throw std::runtime_error("No enough memory 4 linkLists!");
}
memset(linkLists[n0 + i], 0, link_size * pt_level + 1);
memset(linkLists[n0 + i], 0, link_size * pt_level);
}
if (max_level >= level_stats.size()) {
level_stats.resize(max_level + 1);
}
level_stats[pt_level] ++;
}
// printf("level stats:\n");
// for (int i = 0; i <= max_level; ++ i)
// printf("level %d: %d points\n", i, level_stats[i]);
// printf("\n");
std::vector<std::mutex>(n0 + n).swap(link_list_locks);
if (visited_list_pool) delete visited_list_pool;
visited_list_pool = new VisitedListPool(1, n0 + n);
visited_list_pool = new VisitedListPool(1, n1);
return max_level;
}
@ -184,9 +168,8 @@ void RHNSW::addPoint(DistanceComputer& ptdis, int pt_level, int pt_id) {
if (lev > maxlevel_copy || lev < 0)
throw std::runtime_error("Level error");
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates = search_layer(ptdis, pt_id, currObj, lev);
currObj = top_candidates.top().second;
make_connection(ptdis, pt_id, top_candidates, lev);
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates = search_layer(ptdis, currObj, lev);
currObj = make_connection(ptdis, pt_id, top_candidates, lev);
}
} else {
entry_point = 0;
@ -202,7 +185,6 @@ void RHNSW::addPoint(DistanceComputer& ptdis, int pt_level, int pt_id) {
std::priority_queue<Node, std::vector<Node>, CompareByFirst>
RHNSW::search_layer(DistanceComputer& ptdis,
storage_idx_t pt_id,
storage_idx_t nearest,
int level) {
VisitedList *vl = visited_list_pool->getFreeVisitedList();
@ -300,7 +282,7 @@ RHNSW::search_base_layer(DistanceComputer& ptdis,
return top_candidates;
}
void
int
RHNSW::make_connection(DistanceComputer& ptdis,
storage_idx_t pt_id,
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
@ -312,6 +294,8 @@ RHNSW::make_connection(DistanceComputer& ptdis,
if (selectedNeighborsNum > maxM)
throw std::runtime_error("Wrong size of candidates returned by prune_neighbors!");
int next_closest_entry_point = selectedNeighbors[0];
int *cur_link = get_neighbor_link(pt_id, level);
if (*cur_link)
throw std::runtime_error("The newly inserted element should have blank link");
@ -352,41 +336,43 @@ RHNSW::make_connection(DistanceComputer& ptdis,
}
free(selectedNeighbors);
return next_closest_entry_point;
}
void RHNSW::prune_neighbors(DistanceComputer& ptdis,
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
const int maxM, int *ret, int &ret_len) {
if (cand.size() < maxM) {
while (!cand.empty()) {
ret[ret_len ++] = cand.top().second;
ret_len = cand.size();
for (int i = static_cast<int>(cand.size()) - 1; i >= 0; i--) {
ret[i] = cand.top().second;
cand.pop();
}
return;
}
std::priority_queue<Node> closest;
} else if (maxM > 0) {
ret_len = 0;
while (!cand.empty()) {
closest.emplace(-cand.top().first, cand.top().second);
cand.pop();
}
while (closest.size()) {
if (ret_len >= maxM)
break;
Node curr = closest.top();
float dist_to_query = -curr.first;
closest.pop();
bool good = true;
for (auto i = 0; i < ret_len; ++ i) {
float cur_dist = ptdis.symmetric_dis(curr.second, ret[i]);
if (cur_dist < dist_to_query) {
good = false;
break;
}
std::vector<Node> queue_closest;
queue_closest.resize(cand.size());
for (int i = static_cast<int>(cand.size()) - 1; i >= 0; i--) {
queue_closest[i] = cand.top();
cand.pop();
}
if (good) {
ret[ret_len ++] = curr.second;
for (auto &curr: queue_closest) {
bool good = true;
for (auto i = 0; i < ret_len; ++ i) {
float cur_dist = ptdis.symmetric_dis(curr.second, ret[i]);
if (cur_dist < curr.first) {
good = false;
break;
}
}
if (good) {
ret[ret_len++] = (curr.second);
if (ret_len >= M) {
break;
}
}
}
}
}
@ -424,13 +410,18 @@ void RHNSW::searchKnn(DistanceComputer& qdis, int k,
std::priority_queue<Node, std::vector<Node>, CompareByFirst> top_candidates = search_base_layer(qdis, ep, std::max(efSearch, k), dist, bitset);
while (top_candidates.size() > k)
top_candidates.pop();
int i = 0;
int rst_num = top_candidates.size();
int i = rst_num - 1;
while (!top_candidates.empty()) {
I[i] = top_candidates.top().second;
D[i] = top_candidates.top().first;
i ++;
i--;
top_candidates.pop();
}
for (;rst_num < k; rst_num++) {
I[rst_num] = -1;
D[rst_num] = 1.0/0.0;
}
}
size_t RHNSW::cal_size() {

View File

@ -222,7 +222,6 @@ struct RHNSW {
std::priority_queue<Node, std::vector<Node>, CompareByFirst>
search_layer (DistanceComputer& ptdis,
storage_idx_t pt_id,
storage_idx_t nearest,
int level);
@ -233,10 +232,10 @@ struct RHNSW {
float d_nearest,
const BitsetView bitset = nullptr) const;
void make_connection(DistanceComputer& ptdis,
storage_idx_t pt_id,
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
int level);
int make_connection(DistanceComputer& ptdis,
storage_idx_t pt_id,
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,
int level);
void prune_neighbors(DistanceComputer& ptdis,
std::priority_queue<Node, std::vector<Node>, CompareByFirst> &cand,

View File

@ -477,8 +477,8 @@ static void read_RHNSW (RHNSW *rhnsw, IOReader *f) {
rhnsw->linkLists = (char**) malloc(ntotal * sizeof(void*));
for (auto i = 0; i < ntotal; ++ i) {
if (rhnsw->levels[i]) {
rhnsw->linkLists[i] = (char*)malloc(rhnsw->link_size * rhnsw->levels[i] + 1);
READANDCHECK( rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i] + 1);
rhnsw->linkLists[i] = (char*)malloc(rhnsw->link_size * rhnsw->levels[i]);
READANDCHECK( rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i]);
rhnsw->level_stats[rhnsw->levels[i]] ++;
} else {
rhnsw->level_stats[0] ++;

View File

@ -379,7 +379,7 @@ static void write_RHNSW (const RHNSW *rhnsw, IOWriter *f) {
WRITEANDCHECK (rhnsw->level0_links, rhnsw->level0_link_size * rhnsw->levels.size());
for (auto i = 0; i < rhnsw->levels.size(); ++ i) {
if (rhnsw->levels[i])
WRITEANDCHECK (rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i] + 1);
WRITEANDCHECK (rhnsw->linkLists[i], rhnsw->link_size * rhnsw->levels[i]);
}
}