From e2f35954d4ecfcc392d47932e035e723fef74147 Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Wed, 28 Feb 2024 18:31:00 +0800 Subject: [PATCH] enhance: support pattern matching on json field (#30779) issue: https://github.com/milvus-io/milvus/issues/30714 --------- Signed-off-by: longjiquan --- internal/core/src/common/RegexQuery.h | 42 ++++++++ .../core/src/exec/expression/UnaryExpr.cpp | 15 +++ internal/core/unittest/test_regex_query.cpp | 95 ++++++++++++++++++- .../core/unittest/test_regex_query_util.cpp | 64 +++++++++++++ tests/integration/jsonexpr/json_expr_test.go | 26 +++-- 5 files changed, 233 insertions(+), 9 deletions(-) diff --git a/internal/core/src/common/RegexQuery.h b/internal/core/src/common/RegexQuery.h index a7d50956db..47cdd67f15 100644 --- a/internal/core/src/common/RegexQuery.h +++ b/internal/core/src/common/RegexQuery.h @@ -9,7 +9,12 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License +#pragma once + #include +#include + +#include "common/EasyAssert.h" namespace milvus { std::string @@ -19,4 +24,41 @@ ReplaceUnescapedChars(const std::string& input, std::string TranslatePatternMatchToRegex(const std::string& pattern); + +struct PatternMatchTranslator { + template + inline std::string + operator()(const T& pattern) { + PanicInfo(OpTypeInvalid, + "pattern matching is only supported on string type"); + } +}; + +template <> +inline std::string +PatternMatchTranslator::operator()(const std::string& pattern) { + return TranslatePatternMatchToRegex(pattern); +} + +struct RegexMatcher { + template + inline bool + operator()(const std::regex& reg, const T& operand) { + return false; + } +}; + +template <> +inline bool +RegexMatcher::operator()(const std::regex& reg, + const std::string& operand) { + return std::regex_match(operand, reg); +} + +template <> +inline bool +RegexMatcher::operator()(const std::regex& reg, + const std::string_view& operand) { + return std::regex_match(operand.begin(), operand.end(), reg); +} } // namespace milvus diff --git a/internal/core/src/exec/expression/UnaryExpr.cpp b/internal/core/src/exec/expression/UnaryExpr.cpp index 5effb174fa..e577bacdfb 100644 --- a/internal/core/src/exec/expression/UnaryExpr.cpp +++ b/internal/core/src/exec/expression/UnaryExpr.cpp @@ -333,6 +333,21 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } break; } + case proto::plan::Match: { + PatternMatchTranslator translator; + RegexMatcher matcher; + auto regex_pattern = translator(val); + std::regex reg(regex_pattern); + for (size_t i = 0; i < size; ++i) { + if constexpr (std::is_same_v) { + res[i] = false; + } else { + UnaryRangeJSONCompare( + matcher(reg, ExprValueType(x.value()))); + } + } + break; + } default: PanicInfo( OpTypeInvalid, diff --git a/internal/core/unittest/test_regex_query.cpp b/internal/core/unittest/test_regex_query.cpp index 17e753f925..485eb13888 100644 --- a/internal/core/unittest/test_regex_query.cpp +++ b/internal/core/unittest/test_regex_query.cpp @@ -38,6 +38,7 @@ GenTestSchema() { auto schema = std::make_shared(); schema->AddDebugField("str", DataType::VARCHAR); schema->AddDebugField("another_str", DataType::VARCHAR); + schema->AddDebugField("json", DataType::JSON); schema->AddDebugField( "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); auto pk = schema->AddDebugField("int64", DataType::INT64); @@ -59,6 +60,13 @@ class GrowingSegmentRegexQueryTest : public ::testing::Test { "abbb", "abcabcabc", }; + raw_json = { + R"({"int":1})", + R"({"float":1.0})", + R"({"str":"aaa"})", + R"({"str":"bbb"})", + R"({"str":"abcabcabc"})", + }; N = 5; uint64_t seed = 19190504; @@ -71,6 +79,16 @@ class GrowingSegmentRegexQueryTest : public ::testing::Test { for (int64_t i = 0; i < N; i++) { str_col->at(i) = raw_str[i]; } + + auto json_col = raw_data.raw_->mutable_fields_data() + ->at(2) + .mutable_scalars() + ->mutable_json_data() + ->mutable_data(); + for (int64_t i = 0; i < N; i++) { + json_col->at(i) = raw_json[i]; + } + seg->PreInsert(N); seg->Insert(0, N, @@ -88,6 +106,7 @@ class GrowingSegmentRegexQueryTest : public ::testing::Test { SegmentGrowingPtr seg; int64_t N; std::vector raw_str; + std::vector raw_json; }; TEST_F(GrowingSegmentRegexQueryTest, RegexQueryOnNonStringField) { @@ -141,6 +160,33 @@ TEST_F(GrowingSegmentRegexQueryTest, RegexQueryOnStringField) { ASSERT_TRUE(final[4]); } +TEST_F(GrowingSegmentRegexQueryTest, RegexQueryOnJsonField) { + std::string operand = "a%"; + const auto& str_meta = schema->operator[](FieldName("json")); + auto column_info = test::GenColumnInfo( + str_meta.get_id().get(), proto::schema::DataType::JSON, false, false); + column_info->add_nested_path("str"); + auto unary_range_expr = test::GenUnaryRangeExpr(OpType::Match, operand); + unary_range_expr->set_allocated_column_info(column_info); + auto expr = test::GenExpr(); + expr->set_allocated_unary_range_expr(unary_range_expr); + + auto parser = ProtoParser(*schema); + auto typed_expr = parser.ParseExprs(*expr); + auto parsed = + std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); + + auto segpromote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode(parsed, segpromote, N, final); + ASSERT_FALSE(final[0]); + ASSERT_FALSE(final[1]); + ASSERT_TRUE(final[2]); + ASSERT_FALSE(final[3]); + ASSERT_TRUE(final[4]); +} + struct MockStringIndex : index::StringIndexSort { const bool HasRawData() const override { @@ -166,6 +212,13 @@ class SealedSegmentRegexQueryTest : public ::testing::Test { "abbb", "abcabcabc", }; + raw_json = { + R"({"int":1})", + R"({"float":1.0})", + R"({"str":"aaa"})", + R"({"str":"bbb"})", + R"({"str":"abcabcabc"})", + }; N = 5; uint64_t seed = 19190504; auto raw_data = DataGen(schema, N, seed); @@ -180,6 +233,16 @@ class SealedSegmentRegexQueryTest : public ::testing::Test { for (int64_t i = 0; i < N; i++) { str_col->at(i) = raw_str[i]; } + + auto json_col = raw_data.raw_->mutable_fields_data() + ->at(2) + .mutable_scalars() + ->mutable_json_data() + ->mutable_data(); + for (int64_t i = 0; i < N; i++) { + json_col->at(i) = raw_json[i]; + } + SealedLoadFieldData(raw_data, *seg); } @@ -251,6 +314,7 @@ class SealedSegmentRegexQueryTest : public ::testing::Test { int64_t N; std::vector raw_str; std::vector raw_int; + std::vector raw_json; }; TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnNonStringField) { @@ -271,9 +335,7 @@ TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnNonStringField) { auto segpromote = dynamic_cast(seg.get()); query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); BitsetType final; - ASSERT_ANY_THROW( - - visitor.ExecuteExprNode(parsed, segpromote, N, final)); + ASSERT_ANY_THROW(visitor.ExecuteExprNode(parsed, segpromote, N, final)); } TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnStringField) { @@ -304,6 +366,33 @@ TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnStringField) { ASSERT_TRUE(final[4]); } +TEST_F(SealedSegmentRegexQueryTest, BFRegexQueryOnJsonField) { + std::string operand = "a%"; + const auto& str_meta = schema->operator[](FieldName("json")); + auto column_info = test::GenColumnInfo( + str_meta.get_id().get(), proto::schema::DataType::JSON, false, false); + column_info->add_nested_path("str"); + auto unary_range_expr = test::GenUnaryRangeExpr(OpType::Match, operand); + unary_range_expr->set_allocated_column_info(column_info); + auto expr = test::GenExpr(); + expr->set_allocated_unary_range_expr(unary_range_expr); + + auto parser = ProtoParser(*schema); + auto typed_expr = parser.ParseExprs(*expr); + auto parsed = + std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); + + auto segpromote = dynamic_cast(seg.get()); + query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode(parsed, segpromote, N, final); + ASSERT_FALSE(final[0]); + ASSERT_FALSE(final[1]); + ASSERT_TRUE(final[2]); + ASSERT_FALSE(final[3]); + ASSERT_TRUE(final[4]); +} + TEST_F(SealedSegmentRegexQueryTest, RegexQueryOnIndexedNonStringField) { int64_t operand = 120; const auto& int_meta = schema->operator[](FieldName("another_int64")); diff --git a/internal/core/unittest/test_regex_query_util.cpp b/internal/core/unittest/test_regex_query_util.cpp index f508b003c4..0ba999fec9 100644 --- a/internal/core/unittest/test_regex_query_util.cpp +++ b/internal/core/unittest/test_regex_query_util.cpp @@ -43,3 +43,67 @@ TEST(TranslatePatternMatchToRegexTest, PatternWithRegexChar) { std::string result = milvus::TranslatePatternMatchToRegex(pattern); EXPECT_EQ(result, "abc\\*def\\.ghi\\+"); } + +TEST(PatternMatchTranslatorTest, InvalidTypeTest) { + using namespace milvus; + PatternMatchTranslator translator; + + ASSERT_ANY_THROW(translator(123)); + ASSERT_ANY_THROW(translator(3.14)); + ASSERT_ANY_THROW(translator(true)); +} + +TEST(PatternMatchTranslatorTest, StringTypeTest) { + using namespace milvus; + PatternMatchTranslator translator; + + std::string pattern1 = "abc"; + std::string pattern2 = "xyz"; + std::string pattern3 = "%a_b%"; + + EXPECT_EQ(translator(pattern1), "abc"); + EXPECT_EQ(translator(pattern2), "xyz"); + EXPECT_EQ(translator(pattern3), ".*a.b.*"); +} + +TEST(RegexMatcherTest, DefaultBehaviorTest) { + using namespace milvus; + RegexMatcher matcher; + std::regex pattern("Hello.*"); + + int operand1 = 123; + double operand2 = 3.14; + bool operand3 = true; + + EXPECT_FALSE(matcher(pattern, operand1)); + EXPECT_FALSE(matcher(pattern, operand2)); + EXPECT_FALSE(matcher(pattern, operand3)); +} + +TEST(RegexMatcherTest, StringMatchTest) { + using namespace milvus; + RegexMatcher matcher; + std::regex pattern("Hello.*"); + + std::string str1 = "Hello, World!"; + std::string str2 = "Hi there!"; + std::string str3 = "Hello, OpenAI!"; + + EXPECT_TRUE(matcher(pattern, str1)); + EXPECT_FALSE(matcher(pattern, str2)); + EXPECT_TRUE(matcher(pattern, str3)); +} + +TEST(RegexMatcherTest, StringViewMatchTest) { + using namespace milvus; + RegexMatcher matcher; + std::regex pattern("Hello.*"); + + std::string_view str1 = "Hello, World!"; + std::string_view str2 = "Hi there!"; + std::string_view str3 = "Hello, OpenAI!"; + + EXPECT_TRUE(matcher(pattern, str1)); + EXPECT_FALSE(matcher(pattern, str2)); + EXPECT_TRUE(matcher(pattern, str3)); +} diff --git a/tests/integration/jsonexpr/json_expr_test.go b/tests/integration/jsonexpr/json_expr_test.go index 7e90829900..25d8dcb931 100644 --- a/tests/integration/jsonexpr/json_expr_test.go +++ b/tests/integration/jsonexpr/json_expr_test.go @@ -659,6 +659,26 @@ func (s *JSONExprSuite) checkSearch(collectionName, fieldName string, dim int) { s.doSearch(collectionName, []string{fieldName}, expr, dim, checkFunc) log.Info("like expression run successfully") + expr = `D like "%name-%"` + checkFunc = func(result *milvuspb.SearchResults) { + s.Equal(1, len(result.Results.FieldsData)) + s.Equal(fieldName, result.Results.FieldsData[0].GetFieldName()) + s.Equal(schemapb.DataType_JSON, result.Results.FieldsData[0].GetType()) + s.Equal(10, len(result.Results.FieldsData[0].GetScalars().GetJsonData().GetData())) + } + s.doSearch(collectionName, []string{fieldName}, expr, dim, checkFunc) + log.Info("like expression run successfully") + + expr = `D like "na%me"` + checkFunc = func(result *milvuspb.SearchResults) { + s.Equal(1, len(result.Results.FieldsData)) + s.Equal(fieldName, result.Results.FieldsData[0].GetFieldName()) + s.Equal(schemapb.DataType_JSON, result.Results.FieldsData[0].GetType()) + s.Equal(0, len(result.Results.FieldsData[0].GetScalars().GetJsonData().GetData())) + } + s.doSearch(collectionName, []string{fieldName}, expr, dim, checkFunc) + log.Info("like expression run successfully") + expr = `A in []` checkFunc = func(result *milvuspb.SearchResults) { for _, topk := range result.GetResults().GetTopks() { @@ -700,12 +720,6 @@ func (s *JSONExprSuite) checkSearch(collectionName, fieldName string, dim int) { expr = `A like abc` s.doSearchWithInvalidExpr(collectionName, []string{fieldName}, expr, dim) - expr = `D like "%name-%"` - s.doSearchWithInvalidExpr(collectionName, []string{fieldName}, expr, dim) - - expr = `D like "na%me"` - s.doSearchWithInvalidExpr(collectionName, []string{fieldName}, expr, dim) - expr = `1+5 <= A+1 < 5+10` s.doSearchWithInvalidExpr(collectionName, []string{fieldName}, expr, dim)