Fix failed to compare int value with double value (#24229)

Signed-off-by: yah01 <yah2er0ne@outlook.com>
This commit is contained in:
yah01 2023-05-19 12:57:23 +08:00 committed by GitHub
parent bd343550a5
commit c75e7a5d05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 170 additions and 134 deletions

View File

@ -21,6 +21,7 @@
#include <unordered_set>
#include <utility>
#include "arrow/type_fwd.h"
#include "common/Json.h"
#include "common/Types.h"
#include "exceptions/EasyAssert.h"
@ -429,108 +430,83 @@ ExecExprVisitor::ExecUnaryRangeVisitorDispatcherJson(UnaryRangeExpr& expr_raw)
auto& nested_path = expr.column_.nested_path;
auto field_id = expr.column_.field_id;
auto index_func = [=](Index* index) { return TargetBitmap{}; };
using GetType =
std::conditional_t<std::is_same_v<ExprValueType, std::string>,
std::string_view,
ExprValueType>;
#define UnaryRangeJSONCompare(cmp) \
do { \
auto x = json.template at<GetType>(nested_path); \
if (x.error()) { \
if constexpr (std::is_same_v<GetType, int64_t>) { \
auto x = json.template at<double>(nested_path); \
return !x.error() && (cmp); \
} \
return false; \
} \
return (cmp); \
} while (false)
#define UnaryRangeJSONCompareNotEqual(cmp) \
do { \
auto x = json.template at<GetType>(nested_path); \
if (x.error()) { \
if constexpr (std::is_same_v<GetType, int64_t>) { \
auto x = json.template at<double>(nested_path); \
return x.error() || (cmp); \
} \
return true; \
} \
return (cmp); \
} while (false)
switch (op) {
case OpType::Equal: {
auto elem_func = [val, nested_path](const milvus::Json& json) {
using GetType = std::conditional_t<
std::is_same_v<ExprValueType, std::string>,
std::string_view,
ExprValueType>;
auto x = json.template at<GetType>(nested_path);
if (x.error()) {
return false;
}
return ExprValueType(x.value()) == val;
UnaryRangeJSONCompare(x.value() == val);
};
return ExecRangeVisitorImpl<milvus::Json>(
field_id, index_func, elem_func);
}
case OpType::NotEqual: {
auto elem_func = [val, nested_path](const milvus::Json& json) {
using GetType = std::conditional_t<
std::is_same_v<ExprValueType, std::string>,
std::string_view,
ExprValueType>;
auto x = json.template at<GetType>(nested_path);
if (x.error()) {
return true;
}
return ExprValueType(x.value()) != val;
UnaryRangeJSONCompareNotEqual(x.value() != val);
};
return ExecRangeVisitorImpl<milvus::Json>(
field_id, index_func, elem_func);
}
case OpType::GreaterEqual: {
auto elem_func = [val, nested_path](const milvus::Json& json) {
using GetType = std::conditional_t<
std::is_same_v<ExprValueType, std::string>,
std::string_view,
ExprValueType>;
auto x = json.template at<GetType>(nested_path);
if (x.error()) {
return false;
}
return ExprValueType(x.value()) >= val;
UnaryRangeJSONCompare(x.value() >= val);
};
return ExecRangeVisitorImpl<milvus::Json>(
field_id, index_func, elem_func);
}
case OpType::GreaterThan: {
auto elem_func = [val, nested_path](const milvus::Json& json) {
using GetType = std::conditional_t<
std::is_same_v<ExprValueType, std::string>,
std::string_view,
ExprValueType>;
auto x = json.template at<GetType>(nested_path);
if (x.error()) {
return false;
}
return ExprValueType(x.value()) > val;
UnaryRangeJSONCompare(x.value() > val);
};
return ExecRangeVisitorImpl<milvus::Json>(
field_id, index_func, elem_func);
}
case OpType::LessEqual: {
auto elem_func = [val, nested_path](const milvus::Json& json) {
using GetType = std::conditional_t<
std::is_same_v<ExprValueType, std::string>,
std::string_view,
ExprValueType>;
auto x = json.template at<GetType>(nested_path);
if (x.error()) {
return false;
}
return ExprValueType(x.value()) <= val;
UnaryRangeJSONCompare(x.value() <= val);
};
return ExecRangeVisitorImpl<milvus::Json>(
field_id, index_func, elem_func);
}
case OpType::LessThan: {
auto elem_func = [val, nested_path](const milvus::Json& json) {
using GetType = std::conditional_t<
std::is_same_v<ExprValueType, std::string>,
std::string_view,
ExprValueType>;
auto x = json.template at<GetType>(nested_path);
if (x.error()) {
return false;
}
return ExprValueType(x.value()) < val;
UnaryRangeJSONCompare(x.value() < val);
};
return ExecRangeVisitorImpl<milvus::Json>(
field_id, index_func, elem_func);
}
case OpType::PrefixMatch: {
auto elem_func = [val, op, nested_path](const milvus::Json& json) {
using GetType = std::conditional_t<
std::is_same_v<ExprValueType, std::string>,
std::string_view,
ExprValueType>;
auto x = json.template at<GetType>(nested_path);
if (x.error()) {
return false;
}
return Match(ExprValueType(x.value()), val, op);
UnaryRangeJSONCompare(Match(ExprValueType(x.value()), val, op));
};
return ExecRangeVisitorImpl<milvus::Json>(
field_id, index_func, elem_func);
@ -718,6 +694,32 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
auto val = expr.value_;
auto& nested_path = expr.column_.nested_path;
#define BinaryArithRangeJSONCompare(cmp) \
do { \
auto x = json.template at<GetType>(nested_path); \
if (x.error()) { \
if constexpr (std::is_same_v<GetType, int64_t>) { \
auto x = json.template at<double>(nested_path); \
return !x.error() && (cmp); \
} \
return false; \
} \
return (cmp); \
} while (false)
#define BinaryArithRangeJSONCompareNotEqual(cmp) \
do { \
auto x = json.template at<GetType>(nested_path); \
if (x.error()) { \
if constexpr (std::is_same_v<GetType, int64_t>) { \
auto x = json.template at<double>(nested_path); \
return x.error() || (cmp); \
} \
return true; \
} \
return (cmp); \
} while (false)
switch (op) {
case OpType::Equal: {
switch (arith_op) {
@ -727,9 +729,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
return false;
};
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(nested_path);
return !x.error() &&
((x.value() + right_operand) == val);
BinaryArithRangeJSONCompare(x.value() + right_operand ==
val);
};
return ExecDataRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
@ -740,9 +741,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
return false;
};
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(nested_path);
return !x.error() &&
((x.value() - right_operand) == val);
BinaryArithRangeJSONCompare(x.value() - right_operand ==
val);
};
return ExecDataRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
@ -753,9 +753,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
return false;
};
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(nested_path);
return !x.error() &&
((x.value() * right_operand) == val);
BinaryArithRangeJSONCompare(x.value() * right_operand ==
val);
};
return ExecDataRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
@ -766,9 +765,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
return false;
};
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(nested_path);
return !x.error() &&
((x.value() / right_operand) == val);
BinaryArithRangeJSONCompare(x.value() / right_operand ==
val);
};
return ExecDataRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
@ -779,10 +777,9 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
return false;
};
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(nested_path);
return !x.error() &&
(static_cast<ExprValueType>(
fmod(x.value(), right_operand)) == val);
BinaryArithRangeJSONCompare(
static_cast<ExprValueType>(
fmod(x.value(), right_operand)) == val);
};
return ExecDataRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
@ -800,9 +797,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
return false;
};
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(nested_path);
return x.error() ||
((x.value() + right_operand) != val);
BinaryArithRangeJSONCompareNotEqual(
x.value() + right_operand != val);
};
return ExecDataRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
@ -813,9 +809,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
return false;
};
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(nested_path);
return x.error() ||
((x.value() - right_operand) != val);
BinaryArithRangeJSONCompareNotEqual(
x.value() - right_operand != val);
};
return ExecDataRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
@ -826,9 +821,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
return false;
};
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(nested_path);
return x.error() ||
((x.value() * right_operand) != val);
BinaryArithRangeJSONCompareNotEqual(
x.value() * right_operand != val);
};
return ExecDataRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
@ -839,9 +833,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
return false;
};
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(nested_path);
return x.error() ||
((x.value() / right_operand) != val);
BinaryArithRangeJSONCompareNotEqual(
x.value() / right_operand != val);
};
return ExecDataRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
@ -852,10 +845,9 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
return false;
};
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(nested_path);
return x.error() ||
(static_cast<ExprValueType>(
fmod(x.value(), right_operand)) != val);
BinaryArithRangeJSONCompareNotEqual(
static_cast<ExprValueType>(
fmod(x.value(), right_operand)) != val);
};
return ExecDataRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
@ -869,7 +861,7 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson(
PanicInfo("unsupported range node with arithmetic operation");
}
}
}
} // namespace milvus::query
#pragma clang diagnostic push
#pragma ide diagnostic ignored "Simplify"
@ -931,44 +923,44 @@ ExecExprVisitor::ExecBinaryRangeVisitorDispatcherJson(BinaryRangeExpr& expr_raw)
// no json index now
auto index_func = [=](Index* index) { return TargetBitmap{}; };
#define BinaryRangeJSONCompare(cmp) \
do { \
auto x = json.template at<GetType>(expr.column_.nested_path); \
if (x.error()) { \
if constexpr (std::is_same_v<GetType, int64_t>) { \
auto x = json.template at<double>(expr.column_.nested_path); \
if (!x.error()) { \
auto value = x.value(); \
return (cmp); \
} \
} \
return false; \
} \
auto value = x.value(); \
return (cmp); \
} while (false)
if (lower_inclusive && upper_inclusive) {
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(expr.column_.nested_path);
auto value = x.value();
return !x.error() && (val1 <= value && value <= val2);
BinaryRangeJSONCompare(val1 <= value && value <= val2);
};
return ExecRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
} else if (lower_inclusive && !upper_inclusive) {
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(nested_path);
if (x.error()) {
return false;
}
auto value = x.value();
return val1 <= value && value < val2;
BinaryRangeJSONCompare(val1 <= value && value < val2);
};
return ExecRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
} else if (!lower_inclusive && upper_inclusive) {
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(nested_path);
if (x.error()) {
return false;
}
auto value = x.value();
return !x.error() && (val1 < value && value <= val2);
BinaryRangeJSONCompare(val1 < value && value <= val2);
};
return ExecRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);
} else {
auto elem_func = [&](const milvus::Json& json) {
auto x = json.template at<GetType>(nested_path);
if (x.error()) {
return false;
}
auto value = x.value();
return !x.error() && (val1 < value && value < val2);
BinaryRangeJSONCompare(val1 < value && value < val2);
};
return ExecRangeVisitorImpl<milvus::Json>(
expr.column_.field_id, index_func, elem_func);

View File

@ -369,6 +369,10 @@ TEST(Expr, TestBinaryRangeJSON) {
{true, true, 20, 30, {"int"}},
{false, true, 30, 40, {"int"}},
{false, false, 40, 50, {"int"}},
{true, false, 10, 20, {"double"}},
{true, true, 20, 30, {"double"}},
{false, true, 30, 40, {"double"}},
{false, false, 40, 50, {"double"}},
};
auto schema = std::make_shared<Schema>();
@ -422,13 +426,23 @@ TEST(Expr, TestBinaryRangeJSON) {
for (int i = 0; i < N * num_iters; ++i) {
auto ans = final[i];
auto val = milvus::Json(simdjson::padded_string(json_col[i]))
.template at<int64_t>(testcase.nested_path)
.value();
auto ref = check(val);
ASSERT_EQ(ans, ref)
<< val << testcase.lower_inclusive << testcase.lower
<< testcase.upper_inclusive << testcase.upper;
if (testcase.nested_path[0] == "int") {
auto val = milvus::Json(simdjson::padded_string(json_col[i]))
.template at<int64_t>(testcase.nested_path)
.value();
auto ref = check(val);
ASSERT_EQ(ans, ref)
<< val << testcase.lower_inclusive << testcase.lower
<< testcase.upper_inclusive << testcase.upper;
} else {
auto val = milvus::Json(simdjson::padded_string(json_col[i]))
.template at<double>(testcase.nested_path)
.value();
auto ref = check(val);
ASSERT_EQ(ans, ref)
<< val << testcase.lower_inclusive << testcase.lower
<< testcase.upper_inclusive << testcase.upper;
}
}
}
}
@ -504,6 +518,10 @@ TEST(Expr, TestUnaryRangeJson) {
{20, {"int"}},
{30, {"int"}},
{40, {"int"}},
{10, {"double"}},
{20, {"double"}},
{30, {"double"}},
{40, {"double"}},
};
auto schema = std::make_shared<Schema>();
@ -585,11 +603,21 @@ TEST(Expr, TestUnaryRangeJson) {
for (int i = 0; i < N * num_iters; ++i) {
auto ans = final[i];
auto val = milvus::Json(simdjson::padded_string(json_col[i]))
.template at<int64_t>(testcase.nested_path)
.value();
auto ref = f(val);
ASSERT_EQ(ans, ref);
if (testcase.nested_path[0] == "int") {
auto val =
milvus::Json(simdjson::padded_string(json_col[i]))
.template at<int64_t>(testcase.nested_path)
.value();
auto ref = f(val);
ASSERT_EQ(ans, ref);
} else {
auto val =
milvus::Json(simdjson::padded_string(json_col[i]))
.template at<double>(testcase.nested_path)
.value();
auto ref = f(val);
ASSERT_EQ(ans, ref);
}
}
}
}
@ -1738,6 +1766,10 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSON) {
{20, 30, OpType::Equal, {"int"}},
{30, 40, OpType::NotEqual, {"int"}},
{40, 50, OpType::NotEqual, {"int"}},
{10, 20, OpType::Equal, {"double"}},
{20, 30, OpType::Equal, {"double"}},
{30, 40, OpType::NotEqual, {"double"}},
{40, 50, OpType::NotEqual, {"double"}},
};
auto schema = std::make_shared<Schema>();
@ -1788,11 +1820,19 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSON) {
for (int i = 0; i < N * num_iters; ++i) {
auto ans = final[i];
auto val = milvus::Json(simdjson::padded_string(json_col[i]))
.template at<int64_t>(testcase.nested_path)
.value();
auto ref = check(val);
ASSERT_EQ(ans, ref) << testcase.value << " " << val;
if (testcase.nested_path[0] == "int") {
auto val = milvus::Json(simdjson::padded_string(json_col[i]))
.template at<int64_t>(testcase.nested_path)
.value();
auto ref = check(val);
ASSERT_EQ(ans, ref) << testcase.value << " " << val;
} else {
auto val = milvus::Json(simdjson::padded_string(json_col[i]))
.template at<double>(testcase.nested_path)
.value();
auto ref = check(val);
ASSERT_EQ(ans, ref) << testcase.value << " " << val;
}
}
}
}
@ -1812,6 +1852,10 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSONFloat) {
{20, 30, OpType::Equal, {"double"}},
{30, 40, OpType::NotEqual, {"double"}},
{40, 50, OpType::NotEqual, {"double"}},
{10, 20, OpType::Equal, {"int"}},
{20, 30, OpType::Equal, {"int"}},
{30, 40, OpType::NotEqual, {"int"}},
{40, 50, OpType::NotEqual, {"int"}},
};
auto schema = std::make_shared<Schema>();