From c6a6998ba75afb9251b718bed665a35f24bc419d Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Wed, 14 Jun 2023 11:44:38 +0800 Subject: [PATCH] Fix term expression on interger overflow case (#24867) Signed-off-by: longjiquan --- internal/core/src/query/PlanProto.cpp | 16 +++++++++++----- internal/core/src/query/Utils.h | 9 +++++++++ .../core/unittest/test_range_search_sort.cpp | 6 ++++-- internal/core/unittest/test_utils.cpp | 14 ++++++++++++++ 4 files changed, 38 insertions(+), 7 deletions(-) diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index 247699b9a0..bba0e1c735 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -22,6 +22,7 @@ #include "generated/ExtractInfoExprVisitor.h" #include "generated/ExtractInfoPlanNodeVisitor.h" #include "pb/plan.pb.h" +#include "query/Utils.h" namespace milvus::query { namespace planpb = milvus::proto::plan; @@ -33,25 +34,30 @@ ExtractTermExprImpl(FieldId field_id, const planpb::TermExpr& expr_proto) { static_assert(IsScalar); auto size = expr_proto.values_size(); - std::vector terms(size); + std::vector terms; + terms.reserve(size); auto val_case = proto::plan::GenericValue::ValCase::VAL_NOT_SET; for (int i = 0; i < size; ++i) { auto& value_proto = expr_proto.values(i); if constexpr (std::is_same_v) { Assert(value_proto.val_case() == planpb::GenericValue::kBoolVal); - terms[i] = static_cast(value_proto.bool_val()); + terms.push_back(static_cast(value_proto.bool_val())); val_case = proto::plan::GenericValue::ValCase::kBoolVal; } else if constexpr (std::is_integral_v) { Assert(value_proto.val_case() == planpb::GenericValue::kInt64Val); - terms[i] = static_cast(value_proto.int64_val()); + auto value = value_proto.int64_val(); + if (out_of_range(value)) { + continue; + } + terms.push_back(static_cast(value)); val_case = proto::plan::GenericValue::ValCase::kInt64Val; } else if constexpr (std::is_floating_point_v) { Assert(value_proto.val_case() == planpb::GenericValue::kFloatVal); - terms[i] = static_cast(value_proto.float_val()); + terms.push_back(static_cast(value_proto.float_val())); val_case = proto::plan::GenericValue::ValCase::kFloatVal; } else if constexpr (std::is_same_v) { Assert(value_proto.val_case() == planpb::GenericValue::kStringVal); - terms[i] = static_cast(value_proto.string_val()); + terms.push_back(static_cast(value_proto.string_val())); val_case = proto::plan::GenericValue::ValCase::kStringVal; } else { static_assert(always_false); diff --git a/internal/core/src/query/Utils.h b/internal/core/src/query/Utils.h index 5a3a4ba060..7063a7d11e 100644 --- a/internal/core/src/query/Utils.h +++ b/internal/core/src/query/Utils.h @@ -11,7 +11,9 @@ #pragma once +#include #include + #include "query/Expr.h" #include "common/Utils.h" @@ -50,4 +52,11 @@ Match(const std::string_view& str, PanicInfo("not supported"); } } + +template >> +inline bool +out_of_range(int64_t t) { + return t > std::numeric_limits::max() || + t < std::numeric_limits::min(); +} } // namespace milvus::query diff --git a/internal/core/unittest/test_range_search_sort.cpp b/internal/core/unittest/test_range_search_sort.cpp index c81221af80..f9d1fea36c 100644 --- a/internal/core/unittest/test_range_search_sort.cpp +++ b/internal/core/unittest/test_range_search_sort.cpp @@ -88,7 +88,8 @@ CheckRangeSearchSortResult(int64_t* p_id, auto dist = milvus::GetDatasetDistance(dataset); for (int i = 0; i < n; i++) { AssertInfo(id[i] == p_id[i], "id of range search result not same"); - AssertInfo(dist[i] == p_dist[i], "distance of range search result not same"); + AssertInfo(dist[i] == p_dist[i], + "distance of range search result not same"); } } @@ -166,7 +167,8 @@ INSTANTIATE_TEST_CASE_P(RangeSearchSortParameters, TEST_P(RangeSearchSortTest, CheckRangeSearchSort) { auto res = milvus::ReGenRangeSearchResult(dataset, TOPK, N, metric_type); - auto [p_id, p_dist] = RangeSearchSortResultBF(dataset, TOPK, N, metric_type); + auto [p_id, p_dist] = + RangeSearchSortResultBF(dataset, TOPK, N, metric_type); CheckRangeSearchSortResult(p_id, p_dist, res, N * TOPK); delete[] p_id; delete[] p_dist; diff --git a/internal/core/unittest/test_utils.cpp b/internal/core/unittest/test_utils.cpp index 6615eb170f..e82c1da3ce 100644 --- a/internal/core/unittest/test_utils.cpp +++ b/internal/core/unittest/test_utils.cpp @@ -109,3 +109,17 @@ TEST(Util, GetDeleteBitmap) { del_barrier, N, delete_record, insert_record, query_timestamp); ASSERT_EQ(res_bitmap->bitmap_ptr->count(), 0); } + +TEST(Util, OutOfRange) { + using milvus::query::out_of_range; + + ASSERT_FALSE(out_of_range( + static_cast(std::numeric_limits::max()) - 1)); + ASSERT_FALSE(out_of_range( + static_cast(std::numeric_limits::min()) + 1)); + + ASSERT_TRUE(out_of_range( + static_cast(std::numeric_limits::max()) + 1)); + ASSERT_TRUE(out_of_range( + static_cast(std::numeric_limits::min()) - 1)); +}