milvus/internal/core/unittest/test_bf_sparse.cpp
Buqian Zheng 96cfae55a5
feat: [Sparse Float Vector] segcore to support sparse vector search and get raw vector by id (#30629)
This PR adds the ability to search/get sparse float vectors in segcore,
and adds unit tests by converting many existing tests into
parameterized ones.

https://github.com/milvus-io/milvus/issues/29419

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
2024-03-12 09:16:30 -07:00

116 lines
3.8 KiB
C++

// Copyright (C) 2019-2024 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <tuple>
#include <vector>
#include "common/Utils.h"
#include "query/SearchBruteForce.h"
#include "test_utils/Constants.h"
#include "test_utils/Distance.h"
#include "test_utils/DataGen.h"
using namespace milvus;
using namespace milvus::segcore;
using namespace milvus::query;
namespace {
/**
 * Reference top-k for a single sparse query, computed by exhaustive scan.
 *
 * Scores each of the `nb` rows of `base` against `query` with
 * SparseRow::dot and returns the row indices of the `topk` largest scores,
 * best first. Equal scores are ordered by ascending row index.
 *
 * A slot is reported as -1 when it holds no real match: either the row's
 * dot product with the query is exactly 0, or `topk` exceeds `nb` and there
 * is no row left to fill the slot.
 *
 * @param base   pointer to `nb` contiguous sparse rows
 * @param query  the query row
 * @param nb     number of rows readable through `base`
 * @param topk   number of result slots to produce
 * @param metric not consulted; scoring is always the dot product
 * @return `topk` row offsets (empty when `topk <= 0`)
 */
std::vector<int>
Ref(const knowhere::sparse::SparseRow<float>* base,
    const knowhere::sparse::SparseRow<float>& query,
    int nb,
    int topk,
    const knowhere::MetricType& metric) {
    // Store the negated score so that ascending tuple order puts the
    // largest dot product first and breaks ties by the smaller row index.
    std::vector<std::tuple<float, int>> scored;
    scored.reserve(nb > 0 ? static_cast<std::size_t>(nb) : 0);
    for (int i = 0; i < nb; i++) {
        scored.emplace_back(-base[i].dot(query), i);
    }

    // Never look past the rows that actually exist: the previous
    // implementation indexed res[i] for every i < topk, which is out of
    // bounds whenever topk > nb.
    const int valid = std::clamp(topk, 0, static_cast<int>(scored.size()));
    std::partial_sort(scored.begin(), scored.begin() + valid, scored.end());

    // Every slot starts as "no match" (-1); only real hits overwrite it.
    // NOTE(review): padding with -1 when topk > nb assumes the search under
    // test reports missing results the same way - confirm.
    const std::size_t slots = topk > 0 ? static_cast<std::size_t>(topk) : 0;
    std::vector<int> offsets(slots, -1);
    for (int i = 0; i < valid; i++) {
        const auto& [neg_score, offset] = scored[i];
        if (neg_score != 0) {
            offsets[i] = offset;
        }
    }
    return offsets;
}
/**
 * Asserts that the first `expected.size()` entries of `actual` are equal,
 * element by element, to `expected`.
 *
 * The caller must make at least `expected.size()` entries readable through
 * `actual`. On the first mismatch ASSERT_EQ records a fatal failure and
 * returns from this helper (not from the caller), so later entries are not
 * compared.
 *
 * @param expected reference offsets
 * @param actual   offsets produced by the code under test
 */
void
AssertMatch(const std::vector<int>& expected, const int64_t* actual) {
    // size_t index avoids the signed/unsigned comparison against size().
    for (std::size_t i = 0; i < expected.size(); ++i) {
        ASSERT_EQ(expected[i], actual[i]) << "offset mismatch at rank " << i;
    }
}
/**
 * Tells whether `metric` is one this test expects sparse float brute-force
 * search to accept. Only the inner-product metric (knowhere::metric::IP)
 * qualifies; the comparison itself is delegated to milvus::IsMetricType.
 *
 * @param metric metric name as passed to the search
 * @return true when milvus::IsMetricType matches `metric` against IP
 */
bool
is_supported_sparse_float_metric(const std::string& metric) {
    const bool is_inner_product =
        milvus::IsMetricType(metric, knowhere::metric::IP);
    return is_inner_product;
}
} // namespace
/**
 * Fixture that runs BruteForceSearch on randomly generated sparse float
 * vectors and compares the returned segment offsets with the exhaustive
 * reference computed by Ref().
 */
class TestSparseFloatSearchBruteForce : public ::testing::Test {
 public:
    /**
     * Executes one search scenario.
     *
     * For a metric accepted by is_supported_sparse_float_metric() the
     * offsets returned for every query must match Ref(); for any other
     * metric the search call itself is expected to throw.
     *
     * @param nb          number of base vectors to generate
     * @param nq          number of query vectors to generate
     * @param topk        results requested per query
     * @param metric_type metric name handed to the search
     */
    void
    Run(int nb, int nq, int topk, const knowhere::MetricType& metric_type) {
        // Bitset resized to nb entries; this test never sets a bit itself.
        auto filter = std::make_shared<BitsetType>();
        filter->resize(nb);
        auto filter_view = BitsetView(*filter);

        // Base rows are generated before the queries, as in every run.
        auto base = milvus::segcore::GenerateRandomSparseFloatVector(nb);
        auto query = milvus::segcore::GenerateRandomSparseFloatVector(nq);

        SearchInfo search_info;
        search_info.topk_ = topk;
        search_info.metric_type_ = metric_type;

        dataset::SearchDataset search_dataset{
            metric_type, nq, topk, -1, kTestSparseDim, query.get()};

        // Single definition of the call under test, shared by both paths.
        const auto search = [&] {
            return BruteForceSearch(search_dataset,
                                    base.get(),
                                    nb,
                                    search_info,
                                    filter_view,
                                    DataType::VECTOR_SPARSE_FLOAT);
        };

        if (is_supported_sparse_float_metric(metric_type)) {
            auto result = search();
            // Results are laid out query by query, topk offsets each.
            for (int q = 0; q < nq; q++) {
                auto expected =
                    Ref(base.get(), query.get()[q], nb, topk, metric_type);
                AssertMatch(expected, result.get_seg_offsets() + q * topk);
            }
            return;
        }

        ASSERT_ANY_THROW(search());
    }
};
// Metrics other than inner product must be rejected: Run() expects the
// search to throw for each of these names.
TEST_F(TestSparseFloatSearchBruteForce, NotSupported) {
    constexpr int kNumBase = 100;
    constexpr int kNumQueries = 10;
    constexpr int kTopK = 5;

    Run(kNumBase, kNumQueries, kTopK, "L2");
    Run(kNumBase, kNumQueries, kTopK, "l2");
    Run(kNumBase, kNumQueries, kTopK, "lxxx");
}
// Inner-product search, requested with upper- and lower-case metric names;
// Run() compares the results with the reference when the name is accepted.
TEST_F(TestSparseFloatSearchBruteForce, IP) {
    constexpr int kNumBase = 100;
    constexpr int kNumQueries = 10;
    constexpr int kTopK = 5;

    Run(kNumBase, kNumQueries, kTopK, "IP");
    Run(kNumBase, kNumQueries, kTopK, "ip");
}