From 3bf205d9a8db3c5430ac686deefb1fb19201dc77 Mon Sep 17 00:00:00 2001 From: FluorineDog Date: Thu, 7 Jan 2021 09:32:17 +0800 Subject: [PATCH] Fix inner product Signed-off-by: FluorineDog --- internal/core/src/common/FieldMeta.h | 7 ++ internal/core/src/common/Types.cpp | 19 +++-- internal/core/src/query/CMakeLists.txt | 1 + internal/core/src/query/Search.cpp | 48 +++++------ internal/core/src/query/SearchBruteForce.cpp | 35 +++++--- internal/core/src/query/SearchBruteForce.h | 21 ++--- internal/core/src/query/SearchOnIndex.cpp | 40 +++++++++ internal/core/src/query/SearchOnIndex.h | 27 ++++++ internal/core/src/query/helper.h | 35 ++++++++ internal/core/src/segcore/ConcurrentVector.h | 6 ++ internal/core/src/segcore/IndexingEntry.cpp | 4 +- internal/core/unittest/test_indexing.cpp | 7 +- internal/core/unittest/test_query.cpp | 88 +++++++++++++++++++- internal/core/unittest/test_utils/DataGen.h | 29 +++++-- 14 files changed, 294 insertions(+), 73 deletions(-) create mode 100644 internal/core/src/query/SearchOnIndex.cpp create mode 100644 internal/core/src/query/SearchOnIndex.h create mode 100644 internal/core/src/query/helper.h diff --git a/internal/core/src/common/FieldMeta.h b/internal/core/src/common/FieldMeta.h index 9486ed2524..00db216cf1 100644 --- a/internal/core/src/common/FieldMeta.h +++ b/internal/core/src/common/FieldMeta.h @@ -106,6 +106,13 @@ struct FieldMeta { return vector_info_->dim_; } + MetricType + get_metric_type() const { + Assert(is_vector()); + Assert(vector_info_.has_value()); + return vector_info_->metric_type_; + } + const std::string& get_name() const { return name_; diff --git a/internal/core/src/common/Types.cpp b/internal/core/src/common/Types.cpp index 17299590da..73b704aeee 100644 --- a/internal/core/src/common/Types.cpp +++ b/internal/core/src/common/Types.cpp @@ -20,24 +20,25 @@ namespace milvus { -using boost::algorithm::to_lower_copy; +using boost::algorithm::to_upper_copy; namespace Metric = knowhere::Metric; static const auto metric_bimap = [] { boost::bimap mapping; using pos = boost::bimap::value_type; - mapping.insert(pos(to_lower_copy(std::string(Metric::L2)), MetricType::METRIC_L2)); - mapping.insert(pos(to_lower_copy(std::string(Metric::IP)), MetricType::METRIC_INNER_PRODUCT)); - mapping.insert(pos(to_lower_copy(std::string(Metric::JACCARD)), MetricType::METRIC_Jaccard)); - mapping.insert(pos(to_lower_copy(std::string(Metric::TANIMOTO)), MetricType::METRIC_Tanimoto)); - mapping.insert(pos(to_lower_copy(std::string(Metric::HAMMING)), MetricType::METRIC_Hamming)); - mapping.insert(pos(to_lower_copy(std::string(Metric::SUBSTRUCTURE)), MetricType::METRIC_Substructure)); - mapping.insert(pos(to_lower_copy(std::string(Metric::SUPERSTRUCTURE)), MetricType::METRIC_Superstructure)); + mapping.insert(pos(std::string(Metric::L2), MetricType::METRIC_L2)); + mapping.insert(pos(std::string(Metric::IP), MetricType::METRIC_INNER_PRODUCT)); + mapping.insert(pos(std::string(Metric::JACCARD), MetricType::METRIC_Jaccard)); + mapping.insert(pos(std::string(Metric::TANIMOTO), MetricType::METRIC_Tanimoto)); + mapping.insert(pos(std::string(Metric::HAMMING), MetricType::METRIC_Hamming)); + mapping.insert(pos(std::string(Metric::SUBSTRUCTURE), MetricType::METRIC_Substructure)); + mapping.insert(pos(std::string(Metric::SUPERSTRUCTURE), MetricType::METRIC_Superstructure)); return mapping; }(); MetricType GetMetricType(const std::string& type_name) { - auto real_name = to_lower_copy(type_name); + // Assume Metric is all upper at knowhere + auto real_name = to_upper_copy(type_name); AssertInfo(metric_bimap.left.count(real_name), "metric type not found: (" + type_name + ")"); return metric_bimap.left.at(real_name); } diff --git a/internal/core/src/query/CMakeLists.txt b/internal/core/src/query/CMakeLists.txt index 7488270fee..d26b2a3a9d 100644 --- a/internal/core/src/query/CMakeLists.txt +++ b/internal/core/src/query/CMakeLists.txt @@ -12,6 +12,7 @@ set(MILVUS_QUERY_SRCS Plan.cpp Search.cpp SearchOnSealed.cpp + SearchOnIndex.cpp SearchBruteForce.cpp SubQueryResult.cpp ) diff --git a/internal/core/src/query/Search.cpp b/internal/core/src/query/Search.cpp index 35c9ec0c51..e27863d258 100644 --- a/internal/core/src/query/Search.cpp +++ b/internal/core/src/query/Search.cpp @@ -17,6 +17,7 @@ #include #include "utils/tools.h" #include "query/SearchBruteForce.h" +#include "query/SearchOnIndex.h" namespace milvus::query { @@ -65,11 +66,14 @@ FloatSearch(const segcore::SegmentSmallIndex& segment, auto dim = field.get_dim(); auto topK = info.topK_; auto total_count = topK * num_queries; + auto metric_type = GetMetricType(info.metric_type_); // TODO: optimize // step 3: small indexing search - std::vector final_uids(total_count, -1); - std::vector final_dis(total_count, std::numeric_limits::max()); + // std::vector final_uids(total_count, -1); + // std::vector final_dis(total_count, std::numeric_limits::max()); + SubQueryResult final_qr(num_queries, topK, metric_type); + dataset::FloatQueryDataset query_dataset{metric_type, num_queries, topK, dim, query_data}; auto max_indexed_id = indexing_record.get_finished_ack(); const auto& indexing_entry = indexing_record.get_vec_entry(vecfield_offset); @@ -77,20 +81,18 @@ FloatSearch(const segcore::SegmentSmallIndex& segment, // TODO: use sub_qr for (int chunk_id = 0; chunk_id < max_indexed_id; ++chunk_id) { + auto bitset = create_bitmap_view(bitmaps_opt, chunk_id); auto indexing = indexing_entry.get_vec_indexing(chunk_id); - auto dataset = knowhere::GenDataset(num_queries, dim, query_data); - auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id); - auto ans = indexing->Query(dataset, search_conf, bitmap_view); - auto dis = ans->Get(milvus::knowhere::meta::DISTANCE); - auto uids = ans->Get(milvus::knowhere::meta::IDS); + auto sub_qr = SearchOnIndex(query_dataset, *indexing, search_conf, bitset); + // convert chunk uid to segment uid - for (int64_t i = 0; i < total_count; ++i) { - auto& x = uids[i]; + for (auto& x : sub_qr.mutable_labels()) { if (x != -1) { x += chunk_id * indexing_entry.get_chunk_size(); } } - segcore::merge_into(num_queries, topK, final_dis.data(), final_uids.data(), dis, uids); + + final_qr.merge(sub_qr); } using segcore::FloatVector; auto vec_ptr = record.get_entity(vecfield_offset); @@ -100,37 +102,28 @@ FloatSearch(const segcore::SegmentSmallIndex& segment, Assert(vec_chunk_size == indexing_entry.get_chunk_size()); auto max_chunk = upper_div(ins_barrier, vec_chunk_size); - // TODO: use sub_qr for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) { - std::vector buf_uids(total_count, -1); - std::vector buf_dis(total_count, std::numeric_limits::max()); + auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id); - // should be not visitable - faiss::float_maxheap_array_t buf = {(size_t)num_queries, (size_t)topK, buf_uids.data(), buf_dis.data()}; auto& chunk = vec_ptr->get_chunk(chunk_id); auto element_begin = chunk_id * vec_chunk_size; auto element_end = std::min(ins_barrier, (chunk_id + 1) * vec_chunk_size); + auto chunk_size = element_end - element_begin; - auto nsize = element_end - element_begin; - - auto bitmap_view = create_bitmap_view(bitmaps_opt, chunk_id); - // TODO: make it wrapped - faiss::knn_L2sqr(query_data, chunk.data(), dim, num_queries, nsize, &buf, bitmap_view); - - Assert(buf_uids.size() == total_count); + auto sub_qr = FloatSearchBruteForce(query_dataset, chunk.data(), chunk_size, bitmap_view); // convert chunk uid to segment uid - for (auto& x : buf_uids) { + for (auto& x : sub_qr.mutable_labels()) { if (x != -1) { x += chunk_id * vec_chunk_size; } } - segcore::merge_into(num_queries, topK, final_dis.data(), final_uids.data(), buf_dis.data(), buf_uids.data()); + final_qr.merge(sub_qr); } - results.result_distances_ = std::move(final_dis); - results.internal_seg_offsets_ = std::move(final_uids); + results.result_distances_ = std::move(final_qr.mutable_values()); + results.internal_seg_offsets_ = std::move(final_qr.mutable_labels()); results.topK_ = topK; results.num_queries_ = num_queries; @@ -168,14 +161,13 @@ BinarySearch(const segcore::SegmentSmallIndex& segment, Assert(field.get_data_type() == DataType::VECTOR_BINARY); auto dim = field.get_dim(); - auto code_size = dim / 8; auto topK = info.topK_; auto total_count = topK * num_queries; // step 3: small indexing search // TODO: this is too intrusive // TODO: use QuerySubResult instead - query::dataset::BinaryQueryDataset query_dataset{metric_type, num_queries, topK, code_size, query_data}; + query::dataset::BinaryQueryDataset query_dataset{metric_type, num_queries, topK, dim, query_data}; using segcore::BinaryVector; auto vec_ptr = record.get_entity(vecfield_offset); diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index 4d46a1ee12..13ef843679 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -16,11 +16,13 @@ #include #include "SubQueryResult.h" +#include + namespace milvus::query { SubQueryResult BinarySearchBruteForceFast(MetricType metric_type, - int64_t code_size, + int64_t dim, const uint8_t* binary_chunk, int64_t chunk_size, int64_t topk, @@ -31,6 +33,7 @@ BinarySearchBruteForceFast(MetricType metric_type, float* result_distances = sub_result.get_values(); idx_t* result_labels = sub_result.get_labels(); + int64_t code_size = dim / 8; const idx_t block_size = chunk_size; bool use_heap = true; @@ -95,14 +98,26 @@ BinarySearchBruteForceFast(MetricType metric_type, return sub_result; } -void -FloatSearchBruteForceFast(MetricType metric_type, - const float* chunk_data, - int64_t chunk_size, - float* result_distances, - idx_t* result_labels, - const faiss::BitsetView& bitset) { - // TODO +SubQueryResult +FloatSearchBruteForce(const dataset::FloatQueryDataset& query_dataset, + const float* chunk_data, + int64_t chunk_size, + const faiss::BitsetView& bitset) { + auto metric_type = query_dataset.metric_type; + auto num_queries = query_dataset.num_queries; + auto topk = query_dataset.topk; + auto dim = query_dataset.dim; + SubQueryResult sub_qr(num_queries, topk, metric_type); + + if (metric_type == MetricType::METRIC_L2) { + faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_labels(), sub_qr.get_values()}; + faiss::knn_L2sqr(query_dataset.query_data, chunk_data, dim, num_queries, chunk_size, &buf, bitset); + return sub_qr; + } else { + faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_labels(), sub_qr.get_values()}; + faiss::knn_inner_product(query_dataset.query_data, chunk_data, dim, num_queries, chunk_size, &buf, bitset); + return sub_qr; + } } SubQueryResult @@ -111,7 +126,7 @@ BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset, int64_t chunk_size, const faiss::BitsetView& bitset) { // TODO: refactor the internal function - return BinarySearchBruteForceFast(query_dataset.metric_type, query_dataset.code_size, binary_chunk, chunk_size, + return BinarySearchBruteForceFast(query_dataset.metric_type, query_dataset.dim, binary_chunk, chunk_size, query_dataset.topk, query_dataset.num_queries, query_dataset.query_data, bitset); } } // namespace milvus::query diff --git a/internal/core/src/query/SearchBruteForce.h b/internal/core/src/query/SearchBruteForce.h index d8114e19d6..c7e166fbb1 100644 --- a/internal/core/src/query/SearchBruteForce.h +++ b/internal/core/src/query/SearchBruteForce.h @@ -14,25 +14,20 @@ #include "segcore/ConcurrentVector.h" #include "common/Schema.h" #include "query/SubQueryResult.h" +#include "query/helper.h" namespace milvus::query { -using MetricType = faiss::MetricType; - -namespace dataset { -struct BinaryQueryDataset { - MetricType metric_type; - int64_t num_queries; - int64_t topk; - int64_t code_size; - const uint8_t* query_data; -}; - -} // namespace dataset SubQueryResult BinarySearchBruteForce(const dataset::BinaryQueryDataset& query_dataset, const uint8_t* binary_chunk, int64_t chunk_size, - const faiss::BitsetView& bitset = nullptr); + const faiss::BitsetView& bitset); + +SubQueryResult +FloatSearchBruteForce(const dataset::FloatQueryDataset& query_dataset, + const float* chunk_data, + int64_t chunk_size, + const faiss::BitsetView& bitset); } // namespace milvus::query diff --git a/internal/core/src/query/SearchOnIndex.cpp b/internal/core/src/query/SearchOnIndex.cpp new file mode 100644 index 0000000000..db3a2a49c1 --- /dev/null +++ b/internal/core/src/query/SearchOnIndex.cpp @@ -0,0 +1,40 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include "SearchOnIndex.h" +namespace milvus::query { +SubQueryResult +SearchOnIndex(const dataset::FloatQueryDataset& query_dataset, + const knowhere::VecIndex& indexing, + const knowhere::Config& search_conf, + const faiss::BitsetView& bitset) { + auto num_queries = query_dataset.num_queries; + auto topK = query_dataset.topk; + auto dim = query_dataset.dim; + auto metric_type = query_dataset.metric_type; + + auto dataset = knowhere::GenDataset(num_queries, dim, query_dataset.query_data); + + // NOTE: VecIndex Query API forget to add const qualifier + // NOTE: use const_cast as a workaround + auto& indexing_nonconst = const_cast(indexing); + auto ans = indexing_nonconst.Query(dataset, search_conf, bitset); + + auto dis = ans->Get(milvus::knowhere::meta::DISTANCE); + auto uids = ans->Get(milvus::knowhere::meta::IDS); + + SubQueryResult sub_qr(num_queries, topK, metric_type); + std::copy_n(dis, num_queries * topK, sub_qr.get_values()); + std::copy_n(uids, num_queries * topK, sub_qr.get_labels()); + return sub_qr; +} + +} // namespace milvus::query diff --git a/internal/core/src/query/SearchOnIndex.h b/internal/core/src/query/SearchOnIndex.h new file mode 100644 index 0000000000..d46c60c4ee --- /dev/null +++ b/internal/core/src/query/SearchOnIndex.h @@ -0,0 +1,27 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include "query/SubQueryResult.h" +#include "query/helper.h" +#include "knowhere/index/vector_index/VecIndex.h" +#include +#include "utils/Json.h" + +namespace milvus::query { +SubQueryResult +SearchOnIndex(const dataset::FloatQueryDataset& query_dataset, + const knowhere::VecIndex& indexing, + const knowhere::Config& search_conf, + const faiss::BitsetView& bitset); + +} // namespace milvus::query diff --git a/internal/core/src/query/helper.h b/internal/core/src/query/helper.h new file mode 100644 index 0000000000..6df3efe572 --- /dev/null +++ b/internal/core/src/query/helper.h @@ -0,0 +1,35 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once +#include "common/Types.h" + +namespace milvus::query { +namespace dataset { + +struct FloatQueryDataset { + MetricType metric_type; + int64_t num_queries; + int64_t topk; + int64_t dim; + const float* query_data; +}; + +struct BinaryQueryDataset { + MetricType metric_type; + int64_t num_queries; + int64_t topk; + int64_t dim; + const uint8_t* query_data; +}; + +} // namespace dataset +} // namespace milvus::query diff --git a/internal/core/src/segcore/ConcurrentVector.h b/internal/core/src/segcore/ConcurrentVector.h index 1293bb843b..8bc6e9cbac 100644 --- a/internal/core/src/segcore/ConcurrentVector.h +++ b/internal/core/src/segcore/ConcurrentVector.h @@ -22,6 +22,7 @@ #include "utils/EasyAssert.h" #include "utils/tools.h" #include +#include "common/Types.h" namespace milvus::segcore { @@ -213,10 +214,15 @@ class ConcurrentVector : public ConcurrentVectorImpl { class VectorTrait {}; class FloatVector : public VectorTrait { + public: using embedded_type = float; + static constexpr auto metric_type = DataType::VECTOR_FLOAT; }; + class BinaryVector : public VectorTrait { + public: using embedded_type = uint8_t; + static constexpr auto metric_type = DataType::VECTOR_BINARY; }; template <> diff --git a/internal/core/src/segcore/IndexingEntry.cpp b/internal/core/src/segcore/IndexingEntry.cpp index f6d0aeaf94..dea4541992 100644 --- a/internal/core/src/segcore/IndexingEntry.cpp +++ b/internal/core/src/segcore/IndexingEntry.cpp @@ -45,7 +45,7 @@ VecIndexingEntry::get_build_conf() const { return knowhere::Config{{knowhere::meta::DIM, field_meta_.get_dim()}, {knowhere::IndexParams::nlist, 100}, {knowhere::IndexParams::nprobe, 4}, - {knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {knowhere::Metric::TYPE, MetricTypeToName(field_meta_.get_metric_type())}, {knowhere::meta::DEVICEID, 0}}; } @@ -55,7 +55,7 @@ VecIndexingEntry::get_search_conf(int top_K) const { {knowhere::meta::TOPK, top_K}, {knowhere::IndexParams::nlist, 100}, {knowhere::IndexParams::nprobe, 4}, - {knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + {knowhere::Metric::TYPE, MetricTypeToName(field_meta_.get_metric_type())}, {knowhere::meta::DEVICEID, 0}}; } diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index d5712c1ae6..f937e5012d 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -246,17 +246,16 @@ TEST(Indexing, BinaryBruteForce) { schema->AddField("age", DataType::INT64); auto dataset = DataGen(schema, N, 10); auto bin_vec = dataset.get_col(0); - auto line_sizeof = schema->operator[](0).get_sizeof(); - auto query_data = 1024 * line_sizeof + bin_vec.data(); + auto query_data = 1024 * dim / 8 + bin_vec.data(); query::dataset::BinaryQueryDataset query_dataset{ faiss::MetricType::METRIC_Jaccard, // num_queries, // topk, // - line_sizeof, // + dim, // query_data // }; - auto sub_result = query::BinarySearchBruteForce(query_dataset, bin_vec.data(), N); + auto sub_result = query::BinarySearchBruteForce(query_dataset, bin_vec.data(), N, nullptr); QueryResult qr; qr.num_queries_ = num_queries; diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index 121f5932ae..8a35806669 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -312,6 +312,51 @@ TEST(Query, ExecTerm) { // for(auto x: ) } +TEST(Query, ExecEmpty) { + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = std::make_shared(); + schema->AddField("age", DataType::FLOAT); + schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); + std::string dsl = R"({ + "bool": { + "must": [ + { + "vector": { + "fakevec": { + "metric_type": "L2", + "params": { + "nprobe": 10 + }, + "query": "$0", + "topk": 5 + } + } + } + ] + } + })"; + int64_t N = 1000 * 1000; + auto segment = CreateSegment(schema); + auto plan = CreatePlan(*schema, dsl); + auto num_queries = 5; + auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024); + auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + QueryResult qr; + Timestamp time = 1000000; + std::vector ph_group_arr = {ph_group.get()}; + segment->Search(plan.get(), ph_group_arr.data(), &time, 1, qr); + std::cout << QueryResultToJson(qr); + + for (auto i : qr.internal_seg_offsets_) { + ASSERT_EQ(i, -1); + } + + for (auto v : qr.result_distances_) { + ASSERT_EQ(v, std::numeric_limits::max()); + } +} + TEST(Query, ExecWithoutPredicate) { using namespace milvus::query; using namespace milvus::segcore; @@ -336,13 +381,13 @@ TEST(Query, ExecWithoutPredicate) { ] } })"; + auto plan = CreatePlan(*schema, dsl); int64_t N = 1000 * 1000; auto dataset = DataGen(schema, N); auto segment = CreateSegment(schema); segment->PreInsert(N); segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_); - auto plan = CreatePlan(*schema, dsl); auto num_queries = 5; auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); @@ -397,6 +442,47 @@ TEST(Query, ExecWithoutPredicate) { ASSERT_EQ(json.dump(2), ref.dump(2)); } +TEST(Indexing, InnerProduct) { + int64_t N = 100000; + constexpr auto dim = 16; + constexpr auto topk = 10; + auto num_queries = 5; + auto schema = std::make_shared(); + std::string dsl = R"({ + "bool": { + "must": [ + { + "vector": { + "normalized": { + "metric_type": "IP", + "params": { + "nprobe": 10 + }, + "query": "$0", + "topk": 5 + } + } + } + ] + } + })"; + schema->AddField("normalized", DataType::VECTOR_FLOAT, dim, MetricType::METRIC_INNER_PRODUCT); + auto dataset = DataGen(schema, N); + auto segment = CreateSegment(schema); + auto plan = CreatePlan(*schema, dsl); + segment->PreInsert(N); + segment->Insert(0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_); + auto col = dataset.get_col(0); + + auto ph_group_raw = CreatePlaceholderGroupFromBlob(num_queries, 16, col.data()); + auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + std::vector ts{(Timestamp)N * 2}; + const auto* ptr = ph_group.get(); + QueryResult qr; + segment->Search(plan.get(), &ptr, ts.data(), 1, qr); + std::cout << QueryResultToJson(qr).dump(2); +} + TEST(Query, FillSegment) { namespace pb = milvus::proto; pb::schema::CollectionSchema proto; diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index 039274bed4..e26374ab5e 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -16,6 +16,9 @@ #include #include "segcore/SegmentBase.h" #include "Constants.h" +#include +using boost::algorithm::starts_with; + namespace milvus::segcore { struct GeneratedData { @@ -92,11 +95,25 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) { switch (field.get_data_type()) { case engine::DataType::VECTOR_FLOAT: { auto dim = field.get_dim(); - vector data(dim * N); - for (auto& x : data) { - x = distr(er) + offset; + vector final; + bool is_ip = starts_with(field.get_name(), "normalized"); + for (int n = 0; n < N; ++n) { + vector data(dim); + float sum = 0; + for (auto& x : data) { + x = distr(er) + offset; + sum += x * x; + } + if (is_ip) { + sum = sqrt(sum); + for (auto& x : data) { + x /= sum; + } + } + + final.insert(final.end(), data.begin(), data.end()); } - insert_cols(data); + insert_cols(final); break; } case engine::DataType::VECTOR_BINARY: { @@ -111,9 +128,9 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42) { } case engine::DataType::INT64: { vector data(N); - int64_t index = 0; // begin with counter - if (field.get_name().rfind("counter", 0) == 0) { + if (starts_with(field.get_name(), "counter")) { + int64_t index = 0; for (auto& x : data) { x = index++; }