From 4c5ffc832c3dbdaa846e36fdbb33b438bce80cb9 Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Wed, 30 Nov 2022 14:53:15 +0800 Subject: [PATCH] Add param check for BruteForceSearch in segcore (#20838) Signed-off-by: yudong.cai Signed-off-by: yudong.cai --- internal/core/src/query/SearchBruteForce.cpp | 12 ++++++++++++ internal/core/src/query/SearchBruteForce.h | 5 +++++ internal/core/src/query/SearchOnGrowing.cpp | 1 + internal/core/src/query/SearchOnSealed.cpp | 4 +++- 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index 374d02594b..df9fdca28e 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -19,6 +19,18 @@ namespace milvus::query { +void +CheckBruteForceSearchParam(const FieldMeta& field, const SearchInfo& search_info) { + auto data_type = field.get_data_type(); + auto& metric_type = search_info.metric_type_; + + AssertInfo(datatype_is_vector(data_type), "[BruteForceSearch] Data type isn't vector type"); + bool is_float_data_type = (data_type == DataType::VECTOR_FLOAT); + bool is_float_metric_type = + IsMetricType(metric_type, knowhere::metric::IP) || IsMetricType(metric_type, knowhere::metric::L2); + AssertInfo(is_float_data_type == is_float_metric_type, "[BruteForceSearch] Data type and metric type mis-match"); +} + SubSearchResult BruteForceSearch(const dataset::SearchDataset& dataset, const void* chunk_data_raw, diff --git a/internal/core/src/query/SearchBruteForce.h b/internal/core/src/query/SearchBruteForce.h index d4cbaaccfc..43e75d59c6 100644 --- a/internal/core/src/query/SearchBruteForce.h +++ b/internal/core/src/query/SearchBruteForce.h @@ -12,11 +12,16 @@ #pragma once #include "common/BitsetView.h" +#include "common/FieldMeta.h" +#include "common/QueryInfo.h" #include "query/SubSearchResult.h" #include "query/helper.h" namespace milvus::query { +void +CheckBruteForceSearchParam(const FieldMeta& field, const SearchInfo& search_info); + SubSearchResult BruteForceSearch(const dataset::SearchDataset& dataset, const void* chunk_data_raw, diff --git a/internal/core/src/query/SearchOnGrowing.cpp b/internal/core/src/query/SearchOnGrowing.cpp index b5f3520afe..d8a3218130 100644 --- a/internal/core/src/query/SearchOnGrowing.cpp +++ b/internal/core/src/query/SearchOnGrowing.cpp @@ -123,6 +123,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, auto size_per_chunk = element_end - element_begin; auto sub_view = bitset.subview(element_begin, size_per_chunk); + CheckBruteForceSearchParam(field, info); auto sub_qr = BruteForceSearch(search_dataset, chunk_data, size_per_chunk, sub_view); // convert chunk uid to segment uid diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp index 5772cb30a1..2da53250bd 100644 --- a/internal/core/src/query/SearchOnSealed.cpp +++ b/internal/core/src/query/SearchOnSealed.cpp @@ -86,7 +86,9 @@ SearchOnSealed(const Schema& schema, auto vec_data = record.get_field_data_base(field_id); AssertInfo(vec_data->num_chunk() == 1, "num chunk not equal to 1 for sealed segment"); auto chunk_data = vec_data->get_chunk_data(0); - auto sub_qr = query::BruteForceSearch(dataset, chunk_data, row_count, bitset); + + CheckBruteForceSearchParam(field, search_info); + auto sub_qr = BruteForceSearch(dataset, chunk_data, row_count, bitset); result.distances_ = std::move(sub_qr.mutable_distances()); result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());