milvus/internal/core/src/query/SearchBruteForce.cpp
xige-16 2ca53fa668 Fix msgstream deadlock when loadCollection
Signed-off-by: xige-16 <xi.ge@zilliz.com>
2021-02-07 15:47:10 +08:00

137 lines
6.2 KiB
C++

// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "SearchBruteForce.h"
#include <vector>
#include <common/Types.h>
#include <boost/dynamic_bitset.hpp>
#include <queue>
#include "SubQueryResult.h"
#include <faiss/utils/distances.h>
namespace milvus::query {
/**
 * Brute-force top-k search of binary query vectors against one chunk of binary vectors.
 *
 * @param metric_type     One of Jaccard, Tanimoto, Substructure, Superstructure or Hamming;
 *                        any other value ends in PanicInfo.
 * @param dim             Vector dimension in bits; the per-vector byte size is computed as dim / 8
 *                        (assumes dim is a multiple of 8 -- TODO confirm against callers).
 * @param binary_chunk    Base vectors of the chunk.
 * @param size_per_chunk  Number of base vectors passed to the distance kernels. It is also used as
 *                        the number of queries handled per kernel call.
 * @param topk            Number of results kept per query.
 * @param num_queries     Number of query vectors in query_data.
 * @param query_data      Query vectors, code_size bytes each.
 * @param bitset          Forwarded unchanged to the distance kernels.
 * @return SubQueryResult whose value/label buffers are indexed as [query * topk + rank].
 *
 * NOTE(review): the query loops advance by size_per_chunk, so a non-positive size_per_chunk with
 * num_queries > 0 would never terminate -- confirm that callers never pass an empty chunk.
 */
SubQueryResult
BinarySearchBruteForceFast(MetricType metric_type,
                           int64_t dim,
                           const uint8_t* binary_chunk,
                           int64_t size_per_chunk,
                           int64_t topk,
                           int64_t num_queries,
                           const uint8_t* query_data,
                           const faiss::BitsetView& bitset) {
    SubQueryResult sub_result(num_queries, topk, metric_type);
    float* result_distances = sub_result.get_values();
    idx_t* result_labels = sub_result.get_labels();

    int64_t code_size = dim / 8;
    const idx_t block_size = size_per_chunk;
    // Total number of (query, rank) slots in the result buffers.
    const int64_t total_results = topk * num_queries;
    // Selects the heap-based Hamming kernel; the non-heap kernel is kept as an alternative.
    bool use_heap = true;

    if (metric_type == faiss::METRIC_Jaccard || metric_type == faiss::METRIC_Tanimoto) {
        float* D = result_distances;
        for (idx_t query_base_index = 0; query_base_index < num_queries; query_base_index += block_size) {
            idx_t query_size = block_size;
            if (query_base_index + block_size > num_queries) {
                query_size = num_queries - query_base_index;
            }

            // We see the distances and labels as heaps.
            faiss::float_maxheap_array_t res = {size_t(query_size), size_t(topk),
                                                result_labels + query_base_index * topk, D + query_base_index * topk};

            binary_distence_knn_hc(metric_type, &res, query_data + query_base_index * code_size, binary_chunk,
                                   size_per_chunk, code_size,
                                   /* ordered = */ true, bitset);
        }

        if (metric_type == faiss::METRIC_Tanimoto) {
            // Post-process every returned distance d into -log2(1 - d).
            // A 64-bit index is used because topk * num_queries is a 64-bit product.
            for (int64_t i = 0; i < total_results; ++i) {
                D[i] = -log2(1 - D[i]);
            }
        }
    } else if (metric_type == faiss::METRIC_Substructure || metric_type == faiss::METRIC_Superstructure) {
        float* D = result_distances;
        for (idx_t s = 0; s < num_queries; s += block_size) {
            idx_t nn = block_size;
            if (s + block_size > num_queries) {
                nn = num_queries - s;
            }

            // Only matching ids are chosen, so no heap is used.
            binary_distence_knn_mc(metric_type, query_data + s * code_size, binary_chunk, nn, size_per_chunk, topk,
                                   code_size, D + s * topk, result_labels + s * topk, bitset);
        }
    } else if (metric_type == faiss::METRIC_Hamming) {
        // The Hamming kernels write integer distances; collect them here and convert afterwards.
        std::vector<int> int_distances(total_results);
        for (idx_t s = 0; s < num_queries; s += block_size) {
            idx_t nn = block_size;
            if (s + block_size > num_queries) {
                nn = num_queries - s;
            }
            if (use_heap) {
                // We see the distances and labels as heaps.
                faiss::int_maxheap_array_t res = {size_t(nn), size_t(topk), result_labels + s * topk,
                                                  int_distances.data() + s * topk};

                hammings_knn_hc(&res, query_data + s * code_size, binary_chunk, size_per_chunk, code_size,
                                /* ordered = */ true, bitset);
            } else {
                hammings_knn_mc(query_data + s * code_size, binary_chunk, nn, size_per_chunk, topk, code_size,
                                int_distances.data() + s * topk, result_labels + s * topk, bitset);
            }
        }
        // Convert every (query, rank) slot. Stopping at num_queries would leave all entries past
        // the first num_queries unconverted whenever topk > 1.
        for (int64_t i = 0; i < total_results; ++i) {
            result_distances[i] = static_cast<float>(int_distances[i]);
        }
    } else {
        PanicInfo("Unsupported metric type");
    }
    return sub_result;
}
// Brute-force top-k search of float query vectors against one chunk of float vectors.
//
// METRIC_L2 is served by faiss::knn_L2sqr over a max-heap array; every other metric type
// is served by faiss::knn_inner_product over a min-heap array. Both kernels write straight
// into the label/value buffers of the returned SubQueryResult. `bitset` is forwarded to the
// kernel unchanged.
SubQueryResult
FloatSearchBruteForce(const dataset::QueryDataset& query_dataset,
                      const void* chunk_data_raw,
                      int64_t size_per_chunk,
                      const faiss::BitsetView& bitset) {
    const auto metric = query_dataset.metric_type;
    const auto nq = query_dataset.num_queries;
    const auto k = query_dataset.topk;
    const auto vec_dim = query_dataset.dim;

    SubQueryResult result(nq, k, metric);

    // Both inputs arrive type-erased; view them as float arrays.
    const auto* queries = reinterpret_cast<const float*>(query_dataset.query_data);
    const auto* base_vectors = reinterpret_cast<const float*>(chunk_data_raw);

    // Output buffers owned by `result`, shared by both heap layouts below.
    auto* labels = result.get_labels();
    auto* values = result.get_values();
    const auto heap_count = static_cast<size_t>(nq);
    const auto heap_size = static_cast<size_t>(k);

    if (metric != MetricType::METRIC_L2) {
        faiss::float_minheap_array_t heaps{heap_count, heap_size, labels, values};
        faiss::knn_inner_product(queries, base_vectors, vec_dim, nq, size_per_chunk, &heaps, bitset);
    } else {
        faiss::float_maxheap_array_t heaps{heap_count, heap_size, labels, values};
        faiss::knn_L2sqr(queries, base_vectors, vec_dim, nq, size_per_chunk, &heaps, bitset);
    }
    return result;
}
// Adapter for binary vectors: unpacks the fields of `query_dataset`, views both raw buffers
// as byte arrays, and delegates to BinarySearchBruteForceFast.
SubQueryResult
BinarySearchBruteForce(const dataset::QueryDataset& query_dataset,
                       const void* chunk_data_raw,
                       int64_t size_per_chunk,
                       const faiss::BitsetView& bitset) {
    // TODO: refactor the internal function
    const auto* chunk_bytes = reinterpret_cast<const uint8_t*>(chunk_data_raw);
    const auto* query_bytes = reinterpret_cast<const uint8_t*>(query_dataset.query_data);

    return BinarySearchBruteForceFast(query_dataset.metric_type,
                                      query_dataset.dim,
                                      chunk_bytes,
                                      size_per_chunk,
                                      query_dataset.topk,
                                      query_dataset.num_queries,
                                      query_bytes,
                                      bitset);
}
} // namespace milvus::query