Fix term expression on interger overflow case (#24867)

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
Jiquan Long 2023-06-14 11:44:38 +08:00 committed by GitHub
parent 893c3c0409
commit c6a6998ba7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 7 deletions

View File

@ -22,6 +22,7 @@
#include "generated/ExtractInfoExprVisitor.h"
#include "generated/ExtractInfoPlanNodeVisitor.h"
#include "pb/plan.pb.h"
#include "query/Utils.h"
namespace milvus::query {
namespace planpb = milvus::proto::plan;
@ -33,25 +34,30 @@ ExtractTermExprImpl(FieldId field_id,
const planpb::TermExpr& expr_proto) {
static_assert(IsScalar<T>);
auto size = expr_proto.values_size();
std::vector<T> terms(size);
std::vector<T> terms;
terms.reserve(size);
auto val_case = proto::plan::GenericValue::ValCase::VAL_NOT_SET;
for (int i = 0; i < size; ++i) {
auto& value_proto = expr_proto.values(i);
if constexpr (std::is_same_v<T, bool>) {
Assert(value_proto.val_case() == planpb::GenericValue::kBoolVal);
terms[i] = static_cast<T>(value_proto.bool_val());
terms.push_back(static_cast<T>(value_proto.bool_val()));
val_case = proto::plan::GenericValue::ValCase::kBoolVal;
} else if constexpr (std::is_integral_v<T>) {
Assert(value_proto.val_case() == planpb::GenericValue::kInt64Val);
terms[i] = static_cast<T>(value_proto.int64_val());
auto value = value_proto.int64_val();
if (out_of_range<T>(value)) {
continue;
}
terms.push_back(static_cast<T>(value));
val_case = proto::plan::GenericValue::ValCase::kInt64Val;
} else if constexpr (std::is_floating_point_v<T>) {
Assert(value_proto.val_case() == planpb::GenericValue::kFloatVal);
terms[i] = static_cast<T>(value_proto.float_val());
terms.push_back(static_cast<T>(value_proto.float_val()));
val_case = proto::plan::GenericValue::ValCase::kFloatVal;
} else if constexpr (std::is_same_v<T, std::string>) {
Assert(value_proto.val_case() == planpb::GenericValue::kStringVal);
terms[i] = static_cast<T>(value_proto.string_val());
terms.push_back(static_cast<T>(value_proto.string_val()));
val_case = proto::plan::GenericValue::ValCase::kStringVal;
} else {
static_assert(always_false<T>);

View File

@ -11,7 +11,9 @@
#pragma once
#include <limits>
#include <string>
#include "query/Expr.h"
#include "common/Utils.h"
@ -50,4 +52,11 @@ Match<std::string_view>(const std::string_view& str,
PanicInfo("not supported");
}
}
template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
inline bool
out_of_range(int64_t t) {
return t > std::numeric_limits<T>::max() ||
t < std::numeric_limits<T>::min();
}
} // namespace milvus::query

View File

@ -88,7 +88,8 @@ CheckRangeSearchSortResult(int64_t* p_id,
auto dist = milvus::GetDatasetDistance(dataset);
for (int i = 0; i < n; i++) {
AssertInfo(id[i] == p_id[i], "id of range search result not same");
AssertInfo(dist[i] == p_dist[i], "distance of range search result not same");
AssertInfo(dist[i] == p_dist[i],
"distance of range search result not same");
}
}
@ -166,7 +167,8 @@ INSTANTIATE_TEST_CASE_P(RangeSearchSortParameters,
TEST_P(RangeSearchSortTest, CheckRangeSearchSort) {
auto res = milvus::ReGenRangeSearchResult(dataset, TOPK, N, metric_type);
auto [p_id, p_dist] = RangeSearchSortResultBF(dataset, TOPK, N, metric_type);
auto [p_id, p_dist] =
RangeSearchSortResultBF(dataset, TOPK, N, metric_type);
CheckRangeSearchSortResult(p_id, p_dist, res, N * TOPK);
delete[] p_id;
delete[] p_dist;

View File

@ -109,3 +109,17 @@ TEST(Util, GetDeleteBitmap) {
del_barrier, N, delete_record, insert_record, query_timestamp);
ASSERT_EQ(res_bitmap->bitmap_ptr->count(), 0);
}
TEST(Util, OutOfRange) {
using milvus::query::out_of_range;
ASSERT_FALSE(out_of_range<int32_t>(
static_cast<int64_t>(std::numeric_limits<int32_t>::max()) - 1));
ASSERT_FALSE(out_of_range<int32_t>(
static_cast<int64_t>(std::numeric_limits<int32_t>::min()) + 1));
ASSERT_TRUE(out_of_range<int32_t>(
static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1));
ASSERT_TRUE(out_of_range<int32_t>(
static_cast<int64_t>(std::numeric_limits<int32_t>::min()) - 1));
}