From 6954a5ba3e91bd39d42da4ff5e0483c091eb283f Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Fri, 1 Jul 2022 22:28:23 +0800 Subject: [PATCH] Fix search successfully with invalid metric type (#17977) Signed-off-by: longjiquan --- internal/core/src/query/SearchBruteForce.cpp | 5 +- internal/core/unittest/CMakeLists.txt | 1 + internal/core/unittest/test_bf.cpp | 139 +++++++++++++++++++ internal/core/unittest/test_utils/Distance.h | 35 +++++ 4 files changed, 179 insertions(+), 1 deletion(-) create mode 100644 internal/core/unittest/test_bf.cpp create mode 100644 internal/core/unittest/test_utils/Distance.h diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index 6f0e3cc6b0..446246ce9d 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -101,10 +101,13 @@ FloatSearchBruteForce(const dataset::SearchDataset& dataset, faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_seg_offsets(), sub_qr.get_distances()}; faiss::knn_L2sqr(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, nullptr, bitset); - } else { + } else if (metric_type == knowhere::metric::IP) { faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_seg_offsets(), sub_qr.get_distances()}; faiss::knn_inner_product(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset); + } else { + std::string msg = "search not support metric type: " + metric_type; + PanicInfo(msg); } sub_qr.round_values(); return sub_qr; diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index 21682877a9..d6fbd6b971 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -17,6 +17,7 @@ add_definitions(-DMILVUS_TEST_SEGCORE_YAML_PATH="${CMAKE_SOURCE_DIR}/unittest/te # TODO: better to use ls/find pattern set(MILVUS_TEST_FILES init_gtest.cpp + test_bf.cpp test_binary.cpp test_bitmap.cpp test_bool_index.cpp diff --git a/internal/core/unittest/test_bf.cpp b/internal/core/unittest/test_bf.cpp new file mode 100644 index 0000000000..d9a5c419b3 --- /dev/null +++ b/internal/core/unittest/test_bf.cpp @@ -0,0 +1,139 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include + +#include "query/SearchBruteForce.h" +#include "test_utils/Distance.h" +#include "test_utils/DataGen.h" + +using namespace milvus; +using namespace milvus::segcore; +using namespace milvus::query; + +namespace { + +auto +GenFloatVecs(int dim, int n, const knowhere::MetricType& metric, int seed = 42) { + auto schema = std::make_shared(); + auto fvec = schema->AddDebugField("fvec", DataType::VECTOR_FLOAT, dim, metric); + auto dataset = DataGen(schema, n, seed); + return dataset.get_col(fvec); +} + +// (offset, distance) +std::vector> +Distances(const float* base, + const float* query, // one query. + int nb, + int dim, + const knowhere::MetricType& metric) { + if (metric == knowhere::metric::L2) { + std::vector> res; + for (int i = 0; i < nb; i++) { + res.emplace_back(i, L2(base + i * dim, query, dim)); + } + return res; + } else if (metric == knowhere::metric::IP) { + std::vector> res; + for (int i = 0; i < nb; i++) { + res.emplace_back(i, IP(base + i * dim, query, dim)); + } + return res; + } else { + PanicInfo("invalid metric type"); + } +} + +std::vector +GetOffsets(const std::vector>& tuples, int k) { + std::vector offsets; + for (int i = 0; i < k; i++) { + auto [offset, distance] = tuples[i]; + offsets.push_back(offset); + } + return offsets; +} + +// offsets +std::vector +Ref(const float* base, + const float* query, // one query. + int nb, + int dim, + int topk, + const knowhere::MetricType& metric) { + auto res = Distances(base, query, nb, dim, metric); + std::sort(res.begin(), res.end()); + if (metric == knowhere::metric::L2) { + } else if (metric == knowhere::metric::IP) { + std::reverse(res.begin(), res.end()); + } else { + PanicInfo("invalid metric type"); + } + return GetOffsets(res, topk); +} + +bool +AssertMatch(const std::vector& ref, const int64_t* ans) { + for (int i = 0; i < ref.size(); i++) { + if (ref[i] != ans[i]) { + return false; + } + } + return true; +} + +bool +is_supported_float_metric(const knowhere::MetricType& metric) { + return metric == knowhere::metric::L2 || metric == knowhere::metric::IP; +} + +} // namespace + +class TestFloatSearchBruteForce : public ::testing::Test { + public: + void + Run(int nb, int nq, int topk, int dim, const knowhere::MetricType& metric_type) { + auto bitset = std::make_shared(); + bitset->resize(nb); + auto bitset_view = BitsetView(*bitset); + + auto base = GenFloatVecs(dim, nb, metric_type); + auto query = GenFloatVecs(dim, nq, metric_type); + + dataset::SearchDataset dataset{metric_type, nq, topk, -1, dim, query.data()}; + if (!is_supported_float_metric(metric_type)) { + ASSERT_ANY_THROW(FloatSearchBruteForce(dataset, base.data(), nb, bitset_view)); + return; + } + auto result = FloatSearchBruteForce(dataset, base.data(), nb, bitset_view); + for (int i = 0; i < nq; i++) { + auto ref = Ref(base.data(), query.data() + i * dim, nb, dim, topk, metric_type); + auto ans = result.get_seg_offsets() + i * topk; + AssertMatch(ref, ans); + } + } +}; + +TEST_F(TestFloatSearchBruteForce, L2) { + Run(100, 10, 5, 128, knowhere::metric::L2); +} + +TEST_F(TestFloatSearchBruteForce, IP) { + Run(100, 10, 5, 128, knowhere::metric::IP); +} + +TEST_F(TestFloatSearchBruteForce, NotSupported) { + Run(100, 10, 5, 128, "aaaaaaaaaaaa"); +} diff --git a/internal/core/unittest/test_utils/Distance.h b/internal/core/unittest/test_utils/Distance.h new file mode 100644 index 0000000000..2e339c41cb --- /dev/null +++ b/internal/core/unittest/test_utils/Distance.h @@ -0,0 +1,35 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +namespace { +float +L2(const float* point_a, const float* point_b, int dim) { + float dis = 0; + for (auto i = 0; i < dim; i++) { + auto c_a = point_a[i]; + auto c_b = point_b[i]; + dis += pow(c_b - c_a, 2); + } + return dis; +} + +float +IP(const float* point_a, const float* point_b, int dim) { + float dis = 0; + for (auto i = 0; i < dim; i++) { + auto c_a = point_a[i]; + auto c_b = point_b[i]; + dis += c_a * c_b; + } + return dis; +} + +} // namespace