diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp index 14790e86d8..0d76f0da30 100644 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp @@ -21,6 +21,7 @@ #include #include +#include "arrow/type_fwd.h" #include "common/Json.h" #include "common/Types.h" #include "exceptions/EasyAssert.h" @@ -429,108 +430,83 @@ ExecExprVisitor::ExecUnaryRangeVisitorDispatcherJson(UnaryRangeExpr& expr_raw) auto& nested_path = expr.column_.nested_path; auto field_id = expr.column_.field_id; auto index_func = [=](Index* index) { return TargetBitmap{}; }; + using GetType = + std::conditional_t, + std::string_view, + ExprValueType>; + +#define UnaryRangeJSONCompare(cmp) \ + do { \ + auto x = json.template at(nested_path); \ + if (x.error()) { \ + if constexpr (std::is_same_v) { \ + auto x = json.template at(nested_path); \ + return !x.error() && (cmp); \ + } \ + return false; \ + } \ + return (cmp); \ + } while (false) + +#define UnaryRangeJSONCompareNotEqual(cmp) \ + do { \ + auto x = json.template at(nested_path); \ + if (x.error()) { \ + if constexpr (std::is_same_v) { \ + auto x = json.template at(nested_path); \ + return x.error() || (cmp); \ + } \ + return true; \ + } \ + return (cmp); \ + } while (false) + switch (op) { case OpType::Equal: { auto elem_func = [val, nested_path](const milvus::Json& json) { - using GetType = std::conditional_t< - std::is_same_v, - std::string_view, - ExprValueType>; - auto x = json.template at(nested_path); - if (x.error()) { - return false; - } - return ExprValueType(x.value()) == val; + UnaryRangeJSONCompare(x.value() == val); }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); } case OpType::NotEqual: { auto elem_func = [val, nested_path](const milvus::Json& json) { - using GetType = std::conditional_t< - std::is_same_v, - std::string_view, - ExprValueType>; - auto x = json.template at(nested_path); - if (x.error()) { - return true; - } - return ExprValueType(x.value()) != val; + UnaryRangeJSONCompareNotEqual(x.value() != val); }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); } case OpType::GreaterEqual: { auto elem_func = [val, nested_path](const milvus::Json& json) { - using GetType = std::conditional_t< - std::is_same_v, - std::string_view, - ExprValueType>; - auto x = json.template at(nested_path); - if (x.error()) { - return false; - } - return ExprValueType(x.value()) >= val; + UnaryRangeJSONCompare(x.value() >= val); }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); } case OpType::GreaterThan: { auto elem_func = [val, nested_path](const milvus::Json& json) { - using GetType = std::conditional_t< - std::is_same_v, - std::string_view, - ExprValueType>; - auto x = json.template at(nested_path); - if (x.error()) { - return false; - } - return ExprValueType(x.value()) > val; + UnaryRangeJSONCompare(x.value() > val); }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); } case OpType::LessEqual: { auto elem_func = [val, nested_path](const milvus::Json& json) { - using GetType = std::conditional_t< - std::is_same_v, - std::string_view, - ExprValueType>; - auto x = json.template at(nested_path); - if (x.error()) { - return false; - } - return ExprValueType(x.value()) <= val; + UnaryRangeJSONCompare(x.value() <= val); }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); } case OpType::LessThan: { auto elem_func = [val, nested_path](const milvus::Json& json) { - using GetType = std::conditional_t< - std::is_same_v, - std::string_view, - ExprValueType>; - auto x = json.template at(nested_path); - if (x.error()) { - return false; - } - return ExprValueType(x.value()) < val; + UnaryRangeJSONCompare(x.value() < val); }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); } case OpType::PrefixMatch: { auto elem_func = [val, op, nested_path](const milvus::Json& json) { - using GetType = std::conditional_t< - std::is_same_v, - std::string_view, - ExprValueType>; - auto x = json.template at(nested_path); - if (x.error()) { - return false; - } - return Match(ExprValueType(x.value()), val, op); + UnaryRangeJSONCompare(Match(ExprValueType(x.value()), val, op)); }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); @@ -718,6 +694,32 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( auto val = expr.value_; auto& nested_path = expr.column_.nested_path; +#define BinaryArithRangeJSONCompare(cmp) \ + do { \ + auto x = json.template at(nested_path); \ + if (x.error()) { \ + if constexpr (std::is_same_v) { \ + auto x = json.template at(nested_path); \ + return !x.error() && (cmp); \ + } \ + return false; \ + } \ + return (cmp); \ + } while (false) + +#define BinaryArithRangeJSONCompareNotEqual(cmp) \ + do { \ + auto x = json.template at(nested_path); \ + if (x.error()) { \ + if constexpr (std::is_same_v) { \ + auto x = json.template at(nested_path); \ + return x.error() || (cmp); \ + } \ + return true; \ + } \ + return (cmp); \ + } while (false) + switch (op) { case OpType::Equal: { switch (arith_op) { @@ -727,9 +729,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( return false; }; auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(nested_path); - return !x.error() && - ((x.value() + right_operand) == val); + BinaryArithRangeJSONCompare(x.value() + right_operand == + val); }; return ExecDataRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); @@ -740,9 +741,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( return false; }; auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(nested_path); - return !x.error() && - ((x.value() - right_operand) == val); + BinaryArithRangeJSONCompare(x.value() - right_operand == + val); }; return ExecDataRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); @@ -753,9 +753,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( return false; }; auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(nested_path); - return !x.error() && - ((x.value() * right_operand) == val); + BinaryArithRangeJSONCompare(x.value() * right_operand == + val); }; return ExecDataRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); @@ -766,9 +765,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( return false; }; auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(nested_path); - return !x.error() && - ((x.value() / right_operand) == val); + BinaryArithRangeJSONCompare(x.value() / right_operand == + val); }; return ExecDataRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); @@ -779,10 +777,9 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( return false; }; auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(nested_path); - return !x.error() && - (static_cast( - fmod(x.value(), right_operand)) == val); + BinaryArithRangeJSONCompare( + static_cast( + fmod(x.value(), right_operand)) == val); }; return ExecDataRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); @@ -800,9 +797,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( return false; }; auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(nested_path); - return x.error() || - ((x.value() + right_operand) != val); + BinaryArithRangeJSONCompareNotEqual( + x.value() + right_operand != val); }; return ExecDataRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); @@ -813,9 +809,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( return false; }; auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(nested_path); - return x.error() || - ((x.value() - right_operand) != val); + BinaryArithRangeJSONCompareNotEqual( + x.value() - right_operand != val); }; return ExecDataRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); @@ -826,9 +821,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( return false; }; auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(nested_path); - return x.error() || - ((x.value() * right_operand) != val); + BinaryArithRangeJSONCompareNotEqual( + x.value() * right_operand != val); }; return ExecDataRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); @@ -839,9 +833,8 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( return false; }; auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(nested_path); - return x.error() || - ((x.value() / right_operand) != val); + BinaryArithRangeJSONCompareNotEqual( + x.value() / right_operand != val); }; return ExecDataRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); @@ -852,10 +845,9 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( return false; }; auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(nested_path); - return x.error() || - (static_cast( - fmod(x.value(), right_operand)) != val); + BinaryArithRangeJSONCompareNotEqual( + static_cast( + fmod(x.value(), right_operand)) != val); }; return ExecDataRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); @@ -869,7 +861,7 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( PanicInfo("unsupported range node with arithmetic operation"); } } -} +} // namespace milvus::query #pragma clang diagnostic push #pragma ide diagnostic ignored "Simplify" @@ -931,44 +923,44 @@ ExecExprVisitor::ExecBinaryRangeVisitorDispatcherJson(BinaryRangeExpr& expr_raw) // no json index now auto index_func = [=](Index* index) { return TargetBitmap{}; }; +#define BinaryRangeJSONCompare(cmp) \ + do { \ + auto x = json.template at(expr.column_.nested_path); \ + if (x.error()) { \ + if constexpr (std::is_same_v) { \ + auto x = json.template at(expr.column_.nested_path); \ + if (!x.error()) { \ + auto value = x.value(); \ + return (cmp); \ + } \ + } \ + return false; \ + } \ + auto value = x.value(); \ + return (cmp); \ + } while (false) + if (lower_inclusive && upper_inclusive) { auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(expr.column_.nested_path); - auto value = x.value(); - return !x.error() && (val1 <= value && value <= val2); + BinaryRangeJSONCompare(val1 <= value && value <= val2); }; return ExecRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); } else if (lower_inclusive && !upper_inclusive) { auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(nested_path); - if (x.error()) { - return false; - } - auto value = x.value(); - return val1 <= value && value < val2; + BinaryRangeJSONCompare(val1 <= value && value < val2); }; return ExecRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); } else if (!lower_inclusive && upper_inclusive) { auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(nested_path); - if (x.error()) { - return false; - } - auto value = x.value(); - return !x.error() && (val1 < value && value <= val2); + BinaryRangeJSONCompare(val1 < value && value <= val2); }; return ExecRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); } else { auto elem_func = [&](const milvus::Json& json) { - auto x = json.template at(nested_path); - if (x.error()) { - return false; - } - auto value = x.value(); - return !x.error() && (val1 < value && value < val2); + BinaryRangeJSONCompare(val1 < value && value < val2); }; return ExecRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 3090ebc0a2..9145578137 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -369,6 +369,10 @@ TEST(Expr, TestBinaryRangeJSON) { {true, true, 20, 30, {"int"}}, {false, true, 30, 40, {"int"}}, {false, false, 40, 50, {"int"}}, + {true, false, 10, 20, {"double"}}, + {true, true, 20, 30, {"double"}}, + {false, true, 30, 40, {"double"}}, + {false, false, 40, 50, {"double"}}, }; auto schema = std::make_shared(); @@ -422,13 +426,23 @@ TEST(Expr, TestBinaryRangeJSON) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at(testcase.nested_path) - .value(); - auto ref = check(val); - ASSERT_EQ(ans, ref) - << val << testcase.lower_inclusive << testcase.lower - << testcase.upper_inclusive << testcase.upper; + if (testcase.nested_path[0] == "int") { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(testcase.nested_path) + .value(); + auto ref = check(val); + ASSERT_EQ(ans, ref) + << val << testcase.lower_inclusive << testcase.lower + << testcase.upper_inclusive << testcase.upper; + } else { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(testcase.nested_path) + .value(); + auto ref = check(val); + ASSERT_EQ(ans, ref) + << val << testcase.lower_inclusive << testcase.lower + << testcase.upper_inclusive << testcase.upper; + } } } } @@ -504,6 +518,10 @@ TEST(Expr, TestUnaryRangeJson) { {20, {"int"}}, {30, {"int"}}, {40, {"int"}}, + {10, {"double"}}, + {20, {"double"}}, + {30, {"double"}}, + {40, {"double"}}, }; auto schema = std::make_shared(); @@ -585,11 +603,21 @@ TEST(Expr, TestUnaryRangeJson) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at(testcase.nested_path) - .value(); - auto ref = f(val); - ASSERT_EQ(ans, ref); + if (testcase.nested_path[0] == "int") { + auto val = + milvus::Json(simdjson::padded_string(json_col[i])) + .template at(testcase.nested_path) + .value(); + auto ref = f(val); + ASSERT_EQ(ans, ref); + } else { + auto val = + milvus::Json(simdjson::padded_string(json_col[i])) + .template at(testcase.nested_path) + .value(); + auto ref = f(val); + ASSERT_EQ(ans, ref); + } } } } @@ -1738,6 +1766,10 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSON) { {20, 30, OpType::Equal, {"int"}}, {30, 40, OpType::NotEqual, {"int"}}, {40, 50, OpType::NotEqual, {"int"}}, + {10, 20, OpType::Equal, {"double"}}, + {20, 30, OpType::Equal, {"double"}}, + {30, 40, OpType::NotEqual, {"double"}}, + {40, 50, OpType::NotEqual, {"double"}}, }; auto schema = std::make_shared(); @@ -1788,11 +1820,19 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSON) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto val = milvus::Json(simdjson::padded_string(json_col[i])) - .template at(testcase.nested_path) - .value(); - auto ref = check(val); - ASSERT_EQ(ans, ref) << testcase.value << " " << val; + if (testcase.nested_path[0] == "int") { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(testcase.nested_path) + .value(); + auto ref = check(val); + ASSERT_EQ(ans, ref) << testcase.value << " " << val; + } else { + auto val = milvus::Json(simdjson::padded_string(json_col[i])) + .template at(testcase.nested_path) + .value(); + auto ref = check(val); + ASSERT_EQ(ans, ref) << testcase.value << " " << val; + } } } } @@ -1812,6 +1852,10 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSONFloat) { {20, 30, OpType::Equal, {"double"}}, {30, 40, OpType::NotEqual, {"double"}}, {40, 50, OpType::NotEqual, {"double"}}, + {10, 20, OpType::Equal, {"int"}}, + {20, 30, OpType::Equal, {"int"}}, + {30, 40, OpType::NotEqual, {"int"}}, + {40, 50, OpType::NotEqual, {"int"}}, }; auto schema = std::make_shared();