diff --git a/core/src/index/thirdparty/faiss/utils/distances.cpp b/core/src/index/thirdparty/faiss/utils/distances.cpp index db7c222b2c..e97e873614 100644 --- a/core/src/index/thirdparty/faiss/utils/distances.cpp +++ b/core/src/index/thirdparty/faiss/utils/distances.cpp @@ -148,7 +148,7 @@ static void knn_inner_product_sse (const float * x, size_t k = res->k; size_t thread_max_num = omp_get_max_threads(); - if (ny > parallel_policy_threshold) { + if (ny > parallel_policy_threshold || (nx < thread_max_num / 2 && ny >= thread_max_num * 32)) { size_t block_x = std::min( get_L3_Size() / (d * sizeof(float) + thread_max_num * k * (sizeof(float) + sizeof(int64_t))), nx); @@ -173,24 +173,24 @@ static void knn_inner_product_sse (const float * x, if(!bitset || !bitset->test(j)) { size_t thread_no = omp_get_thread_num(); const float *y_j = y + j * d; - for (size_t i = x_from; i < x_to; i++) { - const float *x_i = x + i * d; + const float *x_i = x + x_from * d; + for (size_t i = 0; i < size; i++) { float disij = fvec_inner_product (x_i, y_j, d); - - float * val_ = value + thread_no * thread_heap_size + (i - x_from) * k; - int64_t * ids_ = labels + thread_no * thread_heap_size + (i - x_from) * k; + float * val_ = value + thread_no * thread_heap_size + i * k; + int64_t * ids_ = labels + thread_no * thread_heap_size + i * k; if (disij > val_[0]) { minheap_swap_top (k, val_, ids_, disij, j); } + x_i += d; } } } + // merge heap for (size_t t = 1; t < thread_max_num; t++) { - // merge heap - for (size_t i = x_from; i < x_to; i++) { - float * __restrict value_x = value + (i - x_from) * k; - int64_t * __restrict labels_x = labels + (i - x_from) * k; + for (size_t i = 0; i < size; i++) { + float * __restrict value_x = value + i * k; + int64_t * __restrict labels_x = labels + i * k; float *value_x_t = value_x + t * thread_heap_size; int64_t *labels_x_t = labels_x + t * thread_heap_size; for (size_t j = 0; j < k; j++) { @@ -201,14 +201,16 @@ static void knn_inner_product_sse (const float * x, } } - for (size_t i = x_from; i < x_to; i++) { - float * value_x = value + (i - x_from) * k; - int64_t * labels_x = labels + (i - x_from) * k; + // sort + for (size_t i = 0; i < size; i++) { + float * value_x = value + i * k; + int64_t * labels_x = labels + i * k; minheap_reorder (k, value_x, labels_x); } - memcpy(res->val+ x_from * k, value, thread_heap_size * sizeof(float)); - memcpy(res->ids+ x_from * k, labels, thread_heap_size * sizeof(int64_t)); + // copy result + memcpy(res->val + x_from * k, value, thread_heap_size * sizeof(float)); + memcpy(res->ids + x_from * k, labels, thread_heap_size * sizeof(int64_t)); } delete[] value; delete[] labels; @@ -255,7 +257,7 @@ static void knn_L2sqr_sse ( size_t k = res->k; size_t thread_max_num = omp_get_max_threads(); - if (ny > parallel_policy_threshold) { + if (ny > parallel_policy_threshold || (nx < thread_max_num / 2 && ny >= thread_max_num * 32)) { size_t block_x = std::min( get_L3_Size() / (d * sizeof(float) + thread_max_num * k * (sizeof(float) + sizeof(int64_t))), nx); @@ -280,24 +282,24 @@ static void knn_L2sqr_sse ( if(!bitset || !bitset->test(j)) { size_t thread_no = omp_get_thread_num(); const float *y_j = y + j * d; - for (size_t i = x_from; i < x_to; i++) { - const float *x_i = x + i * d; + const float *x_i = x + x_from * d; + for (size_t i = 0; i < size; i++) { float disij = fvec_L2sqr (x_i, y_j, d); - - float * val_ = value + thread_no * thread_heap_size + (i - x_from) * k; - int64_t * ids_ = labels + thread_no * thread_heap_size + (i - x_from) * k; + float * val_ = value + thread_no * thread_heap_size + i * k; + int64_t * ids_ = labels + thread_no * thread_heap_size + i * k; if (disij < val_[0]) { maxheap_swap_top (k, val_, ids_, disij, j); } + x_i += d; } } } + // merge heap for (size_t t = 1; t < thread_max_num; t++) { - // merge heap - for (size_t i = x_from; i < x_to; i++) { - float * __restrict value_x = value + (i - x_from) * k; - int64_t * __restrict labels_x = labels + (i - x_from) * k; + for (size_t i = 0; i < size; i++) { + float * __restrict value_x = value + i * k; + int64_t * __restrict labels_x = labels + i * k; float *value_x_t = value_x + t * thread_heap_size; int64_t *labels_x_t = labels_x + t * thread_heap_size; for (size_t j = 0; j < k; j++) { @@ -308,19 +310,20 @@ static void knn_L2sqr_sse ( } } - for (size_t i = x_from; i < x_to; i++) { - float * value_x = value + (i - x_from) * k; - int64_t * labels_x = labels + (i - x_from) * k; + // sort + for (size_t i = 0; i < size; i++) { + float * value_x = value + i * k; + int64_t * labels_x = labels + i * k; maxheap_reorder (k, value_x, labels_x); } - memcpy(res->val+ x_from * k, value, thread_heap_size * sizeof(float)); - memcpy(res->ids+ x_from * k, labels, thread_heap_size * sizeof(int64_t)); + // copy result + memcpy(res->val + x_from * k, value, thread_heap_size * sizeof(float)); + memcpy(res->ids + x_from * k, labels, thread_heap_size * sizeof(int64_t)); } delete[] value; delete[] labels; - } else { float * value = res->val;