From 59dab6cb846abb042c09ace33e693fe1109ab8cc Mon Sep 17 00:00:00 2001 From: Xiaohai Xu Date: Mon, 16 Mar 2020 21:32:05 +0800 Subject: [PATCH] #1653 IndexFlat performance improvement for NQ < thread_number (#1674) * Optimize index flat L2/IP for SSE Signed-off-by: sahuang * parallel optimization Signed-off-by: sahuang * fix threshold Signed-off-by: sahuang * add changelog Signed-off-by: sahuang * add changelog Signed-off-by: sahuang Co-authored-by: sahuang --- CHANGELOG.md | 3 +- .../thirdparty/faiss/utils/BinaryDistance.cpp | 2 +- .../thirdparty/faiss/utils/distances.cpp | 181 +++++++++++++----- .../index/thirdparty/faiss/utils/hamming.cpp | 4 +- 4 files changed, 143 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1272ed2847..bc5ce9115a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,8 +19,9 @@ Please mark all change in change log and use the issue from GitHub - \#1546 Move Config.cpp to config directory - \#1547 Rename storage/file to storage/disk and rename classes - \#1548 Move store/Directory to storage/Operation and add FSHandler -- \#1649 Fix Milvus crash on old CPU - \#1619 Improve compact performance +- \#1649 Fix Milvus crash on old CPU +- \#1653 IndexFlat performance improvement for NQ < thread_number ## Task diff --git a/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp b/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp index 8421e6ec37..e89d6aa7ee 100644 --- a/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp +++ b/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp @@ -33,7 +33,7 @@ namespace faiss { if (init_heap) ha->heapify (); int thread_max_num = omp_get_max_threads(); - if (ha->nh < thread_max_num) { + if (ha->nh < 4) { // omp for n2 int all_hash_size = thread_max_num * k; float *value = new float[all_hash_size]; diff --git a/core/src/index/thirdparty/faiss/utils/distances.cpp b/core/src/index/thirdparty/faiss/utils/distances.cpp index b970bc1a89..68bc6d9b82 100644 --- a/core/src/index/thirdparty/faiss/utils/distances.cpp +++ b/core/src/index/thirdparty/faiss/utils/distances.cpp @@ -152,39 +152,84 @@ static void knn_inner_product_sse (const float * x, ConcurrentBitsetPtr bitset = nullptr) { size_t k = res->k; - size_t check_period = InterruptCallback::get_period_hint (ny * d); - check_period *= omp_get_max_threads(); - - for (size_t i0 = 0; i0 < nx; i0 += check_period) { - size_t i1 = std::min(i0 + check_period, nx); + size_t thread_max_num = omp_get_max_threads(); + if (nx < 4) { + // omp for ny + size_t all_hash_size = thread_max_num * k; + float *value = new float[all_hash_size]; + int64_t *labels = new int64_t[all_hash_size]; + for (size_t i = 0; i < nx; i++) { + // init hash + for (size_t i = 0; i < all_hash_size; i++) { + value[i] = -1.0 / 0.0; + } + const float *x_i = x + i * d; #pragma omp parallel for - for (size_t i = i0; i < i1; i++) { - const float * x_i = x + i * d; - const float * y_j = y; - - float * __restrict simi = res->get_val(i); - int64_t * __restrict idxi = res->get_ids (i); - - minheap_heapify (k, simi, idxi); - for (size_t j = 0; j < ny; j++) { - if(!bitset || !bitset->test(j)){ + if(!bitset || !bitset->test(j)) { + const float *y_j = y + j * d; float ip = fvec_inner_product (x_i, y_j, d); - if (ip > simi[0]) { - minheap_pop (k, simi, idxi); - minheap_push (k, simi, idxi, ip, j); + size_t thread_no = omp_get_thread_num(); + float * __restrict val_ = value + thread_no * k; + int64_t * __restrict ids_ = labels + thread_no * k; + if (ip > val_[0]) { + minheap_pop (k, val_, ids_); + minheap_push (k, val_, ids_, ip, j); } } - y_j += d; + } + + // merge hash + float * __restrict simi = res->get_val(i); + int64_t * __restrict idxi = res->get_ids (i); + minheap_heapify (k, simi, idxi); + for (size_t i = 0; i < all_hash_size; i++) { + if (value[i] > simi[0]) { + minheap_pop (k, simi, idxi); + minheap_push (k, simi, idxi, value[i], labels[i]); + } } minheap_reorder (k, simi, idxi); } - InterruptCallback::check (); - } + delete[] value; + delete[] labels; + } else { + size_t check_period = InterruptCallback::get_period_hint (ny * d); + check_period *= thread_max_num; + + for (size_t i0 = 0; i0 < nx; i0 += check_period) { + size_t i1 = std::min(i0 + check_period, nx); + +#pragma omp parallel for + for (size_t i = i0; i < i1; i++) { + const float * x_i = x + i * d; + const float * y_j = y; + + float * __restrict simi = res->get_val(i); + int64_t * __restrict idxi = res->get_ids (i); + + minheap_heapify (k, simi, idxi); + + for (size_t j = 0; j < ny; j++) { + if(!bitset || !bitset->test(j)){ + float ip = fvec_inner_product (x_i, y_j, d); + + if (ip > simi[0]) { + minheap_pop (k, simi, idxi); + minheap_push (k, simi, idxi, ip, j); + } + } + y_j += d; + } + minheap_reorder (k, simi, idxi); + } + InterruptCallback::check (); + } + } } static void knn_L2sqr_sse ( @@ -196,37 +241,87 @@ static void knn_L2sqr_sse ( { size_t k = res->k; - size_t check_period = InterruptCallback::get_period_hint (ny * d); - check_period *= omp_get_max_threads(); - - for (size_t i0 = 0; i0 < nx; i0 += check_period) { - size_t i1 = std::min(i0 + check_period, nx); + size_t thread_max_num = omp_get_max_threads(); + if (nx < 4) { + // omp for ny + size_t all_hash_size = thread_max_num * k; + float *value = new float[all_hash_size]; + int64_t *labels = new int64_t[all_hash_size]; + for (size_t i = 0; i < nx; i++) { + // init hash + for (size_t i = 0; i < all_hash_size; i++) { + value[i] = 1.0 / 0.0; + } + for (size_t i = 0; i < k; i++) { + labels[i] = -1; + } + const float *x_i = x + i * d; #pragma omp parallel for - for (size_t i = i0; i < i1; i++) { - const float * x_i = x + i * d; - const float * y_j = y; - size_t j; - float * simi = res->get_val(i); - int64_t * idxi = res->get_ids (i); - - maxheap_heapify (k, simi, idxi); - for (j = 0; j < ny; j++) { - if(!bitset || !bitset->test(j)){ + for (size_t j = 0; j < ny; j++) { + if(!bitset || !bitset->test(j)) { + const float *y_j = y + j * d; float disij = fvec_L2sqr (x_i, y_j, d); - if (disij < simi[0]) { - maxheap_pop (k, simi, idxi); - maxheap_push (k, simi, idxi, disij, j); + size_t thread_no = omp_get_thread_num(); + float * __restrict val_ = value + thread_no * k; + int64_t * __restrict ids_ = labels + thread_no * k; + if (disij < val_[0]) { + maxheap_pop (k, val_, ids_); + maxheap_push (k, val_, ids_, disij, j); } } - y_j += d; + } + + // merge hash + float * __restrict simi = res->get_val(i); + int64_t * __restrict idxi = res->get_ids (i); + memcpy(simi, value, k * sizeof(float)); + memcpy(idxi, labels, k * sizeof(int64_t)); + maxheap_heapify (k, simi, idxi, value, labels, k); + for (size_t i = k; i < all_hash_size; i++) { + if (value[i] < simi[0]) { + maxheap_pop (k, simi, idxi); + maxheap_push (k, simi, idxi, value[i], labels[i]); + } } maxheap_reorder (k, simi, idxi); } - InterruptCallback::check (); - } + delete[] value; + delete[] labels; + } else { + size_t check_period = InterruptCallback::get_period_hint (ny * d); + check_period *= thread_max_num; + + for (size_t i0 = 0; i0 < nx; i0 += check_period) { + size_t i1 = std::min(i0 + check_period, nx); + +#pragma omp parallel for + for (size_t i = i0; i < i1; i++) { + const float * x_i = x + i * d; + const float * y_j = y; + float * simi = res->get_val(i); + int64_t * idxi = res->get_ids (i); + + maxheap_heapify (k, simi, idxi); + + for (size_t j = 0; j < ny; j++) { + if(!bitset || !bitset->test(j)){ + float disij = fvec_L2sqr (x_i, y_j, d); + + if (disij < simi[0]) { + maxheap_pop (k, simi, idxi); + maxheap_push (k, simi, idxi, disij, j); + } + } + y_j += d; + } + maxheap_reorder (k, simi, idxi); + } + InterruptCallback::check (); + } + } } @@ -899,4 +994,4 @@ void pairwise_L2sqr (int64_t d, } -} // namespace faiss +} // namespace faiss \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/utils/hamming.cpp b/core/src/index/thirdparty/faiss/utils/hamming.cpp index cf49b0d3ef..6fb4c95d56 100644 --- a/core/src/index/thirdparty/faiss/utils/hamming.cpp +++ b/core/src/index/thirdparty/faiss/utils/hamming.cpp @@ -281,7 +281,7 @@ void hammings_knn_hc ( if (init_heap) ha->heapify (); int thread_max_num = omp_get_max_threads(); - if (ha->nh < thread_max_num) { + if (ha->nh < 4) { // omp for n2 int all_hash_size = thread_max_num * k; hamdis_t *value = new hamdis_t[all_hash_size]; @@ -432,7 +432,7 @@ void hammings_knn_hc_1 ( } int thread_max_num = omp_get_max_threads(); - if (ha->nh < thread_max_num) { + if (ha->nh < 4) { // omp for n2 int all_hash_size = thread_max_num * k; hamdis_t *value = new hamdis_t[all_hash_size];