diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp index 2b987efb71..3b6887ad89 100644 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp @@ -2067,10 +2067,21 @@ ExecExprVisitor::visit(AlwaysTrueExpr& expr) { bitset_opt_ = std::move(res); } +template bool -compareTwoJsonArray(simdjson::simdjson_result arr1, - const proto::plan::Array& arr2) { - if (arr2.array_size() != arr1.count_elements()) { +compareTwoJsonArray(T arr1, const proto::plan::Array& arr2) { + int json_array_length = 0; + if constexpr (std::is_same_v< + T, + simdjson::simdjson_result>) { + json_array_length = arr1.count_elements(); + } + if constexpr (std::is_same_v>>) { + json_array_length = arr1.size(); + } + if (arr2.array_size() != json_array_length) { return false; } int i = 0; @@ -2165,13 +2176,19 @@ ExecExprVisitor::ExecJsonContainsArray(JsonContainsExpr& expr_raw) if (array.error()) { return false; } - for (auto const& element : elements) { - for (auto&& it : array) { - auto val = it.get_array(); - if (val.error()) { - continue; - } - if (compareTwoJsonArray(val, element)) { + for (auto&& it : array) { + auto val = it.get_array(); + if (val.error()) { + continue; + } + std::vector> + json_array; + json_array.reserve(val.count_elements()); + for (auto&& e : val) { + json_array.emplace_back(e); + } + for (auto const& element : elements) { + if (compareTwoJsonArray(json_array, element)) { return true; } } @@ -2322,33 +2339,39 @@ ExecExprVisitor::ExecJsonContainsAllArray(JsonContainsExpr& expr_raw) elements_index.insert(i); i++; } - auto elem_func = - [&elements, &elements_index, &pointer](const milvus::Json& json) { - auto doc = json.doc(); - auto array = doc.at_pointer(pointer).get_array(); - if (array.error()) { - return false; + auto elem_func = [&elements, &elements_index, &pointer]( + const milvus::Json& json) { + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (array.error()) { + return false; + } + std::unordered_set tmp_elements_index(elements_index); + for (auto&& it : array) { + auto val = it.get_array(); + if (val.error()) { + continue; } - std::unordered_set tmp_elements_index(elements_index); - for (auto&& it : array) { - auto val = it.get_array(); - if (val.error()) { - continue; - } - int i = -1; - for (auto const& element : elements) { - i++; - if (compareTwoJsonArray(val, element)) { - tmp_elements_index.erase(i); - break; - } - } - if (tmp_elements_index.size() == 0) { - return true; + std::vector> + json_array; + json_array.reserve(val.count_elements()); + for (auto&& e : val) { + json_array.emplace_back(e); + } + for (auto index : tmp_elements_index) { + if (compareTwoJsonArray(json_array, elements[index])) { + tmp_elements_index.erase(index); + // TODO: construct array set. + // prevent expression json_contains_all(json_array, [[1,2], [3,4], [1,2]]) being unsuccessful + // break; } } - return tmp_elements_index.size() == 0; - }; + if (tmp_elements_index.size() == 0) { + return true; + } + } + return tmp_elements_index.size() == 0; + }; return ExecRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 15fae9afef..1998d32afd 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -3997,6 +3997,164 @@ TEST(Expr, TestJsonContainsArray) { ASSERT_EQ(ans, check(res, i)); } } + + proto::plan::Array sub_arr1; + sub_arr1.set_same_type(true); + proto::plan::GenericValue int_val11; + int_val11.set_int64_val(int64_t(1)); + sub_arr1.add_array()->CopyFrom(int_val11); + + proto::plan::GenericValue int_val12; + int_val12.set_int64_val(int64_t(2)); + sub_arr1.add_array()->CopyFrom(int_val12); + + proto::plan::Array sub_arr2; + sub_arr2.set_same_type(true); + proto::plan::GenericValue int_val21; + int_val21.set_int64_val(int64_t(3)); + sub_arr2.add_array()->CopyFrom(int_val21); + + proto::plan::GenericValue int_val22; + int_val22.set_int64_val(int64_t(4)); + sub_arr2.add_array()->CopyFrom(int_val22); + std::vector> diff_testcases2{{{sub_arr1, sub_arr2}, {"array2"}}}; + + for (auto& testcase : diff_testcases2) { + auto check = [&](const std::vector& values, int i) { + return true; + }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + plan.predicate_ = + std::make_unique>( + ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), + testcase.term, + true, + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + proto::plan::GenericValue::ValCase::kArrayVal); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + std::vector res; + ASSERT_EQ(ans, check(res, i)); + } + } + + for (auto& testcase : diff_testcases2) { + auto check = [&](const std::vector& values, int i) { + return true; + }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + plan.predicate_ = + std::make_unique>( + ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), + testcase.term, + true, + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + proto::plan::GenericValue::ValCase::kArrayVal); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + std::vector res; + ASSERT_EQ(ans, check(res, i)); + } + } + + proto::plan::Array sub_arr3; + sub_arr3.set_same_type(true); + proto::plan::GenericValue int_val31; + int_val31.set_int64_val(int64_t(5)); + sub_arr3.add_array()->CopyFrom(int_val31); + + proto::plan::GenericValue int_val32; + int_val32.set_int64_val(int64_t(6)); + sub_arr3.add_array()->CopyFrom(int_val32); + + proto::plan::Array sub_arr4; + sub_arr4.set_same_type(true); + proto::plan::GenericValue int_val41; + int_val41.set_int64_val(int64_t(7)); + sub_arr4.add_array()->CopyFrom(int_val41); + + proto::plan::GenericValue int_val42; + int_val42.set_int64_val(int64_t(8)); + sub_arr4.add_array()->CopyFrom(int_val42); + std::vector> diff_testcases3{{{sub_arr3, sub_arr4}, {"array2"}}}; + + for (auto& testcase : diff_testcases2) { + auto check = [&](const std::vector& values, int i) { + return true; + }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + plan.predicate_ = + std::make_unique>( + ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), + testcase.term, + true, + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + proto::plan::GenericValue::ValCase::kArrayVal); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + std::vector res; + ASSERT_EQ(ans, check(res, i)); + } + } + + for (auto& testcase : diff_testcases2) { + auto check = [&](const std::vector& values, int i) { + return true; + }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + plan.predicate_ = + std::make_unique>( + ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), + testcase.term, + true, + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + proto::plan::GenericValue::ValCase::kArrayVal); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + std::vector res; + ASSERT_EQ(ans, check(res, i)); + } + } } TEST(Expr, TestJsonContainsDiffType) { diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index 349594fe39..df2c3bdd84 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -420,7 +420,8 @@ DataGenForJsonArray(SchemaPtr schema, R"(],"double":[)" + join(doubleVec, ",") + R"(],"string":[)" + join(stringVec, ",") + R"(],"bool": [)" + join(boolVec, ",") + - R"(],"array": [)" + join(arrayVec, ",") + "]}"; + R"(],"array": [)" + join(arrayVec, ",") + + R"(],"array2": [[1,2], [3,4]])" + "}"; //std::cout << str << std::endl; data[i] = str; }