diff --git a/internal/core/src/common/RegexQuery.cpp b/internal/core/src/common/RegexQuery.cpp index 94b65e681b..9fe99022de 100644 --- a/internal/core/src/common/RegexQuery.cpp +++ b/internal/core/src/common/RegexQuery.cpp @@ -10,49 +10,54 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #include -#include #include "common/RegexQuery.h" namespace milvus { -std::string -ReplaceUnescapedChars(const std::string& input, - char src, - const std::string& replacement) { - std::string result; - bool escapeMode = false; - for (char c : input) { - if (escapeMode) { - result += '\\'; - result += c; - escapeMode = false; +bool +is_special(char c) { + // initial special_bytes_bitmap only once. + static std::once_flag _initialized; + static std::string special_bytes(R"(\.+*?()|[]{}^$)"); + static std::vector special_bytes_bitmap; + std::call_once(_initialized, []() -> void { + special_bytes_bitmap.resize(256); + for (char b : special_bytes) { + special_bytes_bitmap[b + 128] = true; + } + }); + + return special_bytes_bitmap[c + 128]; +} + +std::string +translate_pattern_match_to_regex(const std::string& pattern) { + std::string r; + r.reserve(2 * pattern.size()); + bool escape_mode = false; + for (char c : pattern) { + if (escape_mode) { + if (is_special(c)) { + r += '\\'; + } + r += c; + escape_mode = false; } else { if (c == '\\') { - escapeMode = true; - } else if (c == src) { - result += replacement; + escape_mode = true; + } else if (c == '%') { + r += "[\\s\\S]*"; + } else if (c == '_') { + r += "[\\s\\S]"; } else { - result += c; + if (is_special(c)) { + r += '\\'; + } + r += c; } } } - - return result; -} - -std::string -TranslatePatternMatchToRegex(const std::string& pattern) { - std::string regex_pattern; -#if 0 - regex_pattern = R"([\.\*\+\?\|\(\)\[\]\{\}\\])"; -#else - regex_pattern = R"([\.\*\+\?\|\(\)\[\]\{\}])"; -#endif - std::string regex = - std::regex_replace(pattern, std::regex(regex_pattern), R"(\$&)"); - regex = ReplaceUnescapedChars(regex, '%', ".*"); - regex = ReplaceUnescapedChars(regex, '_', "."); - return regex; + return r; } } // namespace milvus diff --git a/internal/core/src/common/RegexQuery.h b/internal/core/src/common/RegexQuery.h index 47cdd67f15..4cfcde7e14 100644 --- a/internal/core/src/common/RegexQuery.h +++ b/internal/core/src/common/RegexQuery.h @@ -13,17 +13,17 @@ #include #include +#include +#include #include "common/EasyAssert.h" namespace milvus { -std::string -ReplaceUnescapedChars(const std::string& input, - char src, - const std::string& replacement); +bool +is_special(char c); std::string -TranslatePatternMatchToRegex(const std::string& pattern); +translate_pattern_match_to_regex(const std::string& pattern); struct PatternMatchTranslator { template @@ -37,28 +37,40 @@ struct PatternMatchTranslator { template <> inline std::string PatternMatchTranslator::operator()(const std::string& pattern) { - return TranslatePatternMatchToRegex(pattern); + return translate_pattern_match_to_regex(pattern); } struct RegexMatcher { template inline bool - operator()(const std::regex& reg, const T& operand) { + operator()(const T& operand) { return false; } + + explicit RegexMatcher(const std::string& pattern) { + r_ = boost::regex(pattern); + } + + private: + // avoid to construct the regex everytime. + boost::regex r_; }; template <> inline bool -RegexMatcher::operator()(const std::regex& reg, - const std::string& operand) { - return std::regex_match(operand, reg); +RegexMatcher::operator()(const std::string& operand) { + // corner case: + // . don't match \n, but .* match \n. + // For example, + // boost::regex_match("Hello\n", boost::regex("Hello.")) returns false + // but + // boost::regex_match("Hello\n", boost::regex("Hello.*")) returns true + return boost::regex_match(operand, r_); } template <> inline bool -RegexMatcher::operator()(const std::regex& reg, - const std::string_view& operand) { - return std::regex_match(operand.begin(), operand.end(), reg); +RegexMatcher::operator()(const std::string_view& operand) { + return boost::regex_match(operand.begin(), operand.end(), r_); } } // namespace milvus diff --git a/internal/core/src/exec/expression/UnaryExpr.cpp b/internal/core/src/exec/expression/UnaryExpr.cpp index 305fd1caef..f780ec487b 100644 --- a/internal/core/src/exec/expression/UnaryExpr.cpp +++ b/internal/core/src/exec/expression/UnaryExpr.cpp @@ -335,15 +335,14 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { } case proto::plan::Match: { PatternMatchTranslator translator; - RegexMatcher matcher; auto regex_pattern = translator(val); - std::regex reg(regex_pattern); + RegexMatcher matcher(regex_pattern); for (size_t i = 0; i < size; ++i) { if constexpr (std::is_same_v) { res[i] = false; } else { UnaryRangeJSONCompare( - matcher(reg, ExprValueType(x.value()))); + matcher(ExprValueType(x.value()))); } } break; diff --git a/internal/core/src/exec/expression/UnaryExpr.h b/internal/core/src/exec/expression/UnaryExpr.h index cd26018360..e6342eda86 100644 --- a/internal/core/src/exec/expression/UnaryExpr.h +++ b/internal/core/src/exec/expression/UnaryExpr.h @@ -43,21 +43,11 @@ struct UnaryElementFuncForMatch { size_t size, IndexInnerType val, TargetBitmapView res) { - if constexpr (std::is_same_v) { - // translate the pattern match in advance, which avoid computing it every loop. - std::regex reg(TranslatePatternMatchToRegex(val)); - for (int i = 0; i < size; ++i) { - res[i] = - std::regex_match(std::begin(src[i]), std::end(src[i]), reg); - } - } else if constexpr (std::is_same_v) { - // translate the pattern match in advance, which avoid computing it every loop. - std::regex reg(TranslatePatternMatchToRegex(val)); - for (int i = 0; i < size; ++i) { - res[i] = std::regex_match(src[i], reg); - } - } else { - PanicInfo(Unsupported, "regex query is only supported on string"); + PatternMatchTranslator translator; + auto regex_pattern = translator(val); + RegexMatcher matcher(regex_pattern); + for (int i = 0; i < size; ++i) { + res[i] = matcher(src[i]); } } }; @@ -216,9 +206,12 @@ struct UnaryIndexFuncForMatch { !std::is_same_v) { PanicInfo(Unsupported, "regex query is only supported on string"); } else { - auto reg = TranslatePatternMatchToRegex(val); + PatternMatchTranslator translator; + auto regex_pattern = translator(val); + RegexMatcher matcher(regex_pattern); + if (index->SupportRegexQuery()) { - return index->RegexQuery(reg); + return index->RegexQuery(regex_pattern); } if (!index->HasRawData()) { PanicInfo(Unsupported, @@ -228,11 +221,10 @@ struct UnaryIndexFuncForMatch { // retrieve raw data to do brute force query, may be very slow. auto cnt = index->Count(); - std::regex r(reg); TargetBitmap res(cnt); for (int64_t i = 0; i < cnt; i++) { auto raw = index->Reverse_Lookup(i); - res[i] = std::regex_match(raw, r); + res[i] = matcher(raw); } return res; } diff --git a/internal/core/unittest/test_regex_query.cpp b/internal/core/unittest/test_regex_query.cpp index 485eb13888..cfafe95084 100644 --- a/internal/core/unittest/test_regex_query.cpp +++ b/internal/core/unittest/test_regex_query.cpp @@ -54,11 +54,11 @@ class GrowingSegmentRegexQueryTest : public ::testing::Test { schema = GenTestSchema(); seg = CreateGrowingSegment(schema, empty_index_meta); raw_str = { - "b", - "a", - "aaa", - "abbb", - "abcabcabc", + "b\n", + "a\n", + "aaa\n", + "abbb\n", + "abcabcabc\n", }; raw_json = { R"({"int":1})", @@ -206,11 +206,11 @@ class SealedSegmentRegexQueryTest : public ::testing::Test { schema = GenTestSchema(); seg = CreateSealedSegment(schema); raw_str = { - "b", - "a", - "aaa", - "abbb", - "abcabcabc", + "b\n", + "a\n", + "aaa\n", + "abbb\n", + "abcabcabc\n", }; raw_json = { R"({"int":1})", diff --git a/internal/core/unittest/test_regex_query_util.cpp b/internal/core/unittest/test_regex_query_util.cpp index 0ba999fec9..0945ea685a 100644 --- a/internal/core/unittest/test_regex_query_util.cpp +++ b/internal/core/unittest/test_regex_query_util.cpp @@ -13,37 +13,60 @@ #include "common/RegexQuery.h" +TEST(IsSpecial, Demo) { + std::string special_bytes(R"(\.+*?()|[]{}^$)"); + std::unordered_set specials; + for (char b : special_bytes) { + specials.insert(b); + } + for (char c = std::numeric_limits::min(); + c < std::numeric_limits::max(); + c++) { + if (specials.find(c) != specials.end()) { + EXPECT_TRUE(milvus::is_special(c)) << c << static_cast(c); + } else { + EXPECT_FALSE(milvus::is_special(c)) << c << static_cast(c); + } + } +} + TEST(TranslatePatternMatchToRegexTest, SimplePatternWithPercent) { std::string pattern = "abc%"; - std::string result = milvus::TranslatePatternMatchToRegex(pattern); - EXPECT_EQ(result, "abc.*"); + std::string result = milvus::translate_pattern_match_to_regex(pattern); + EXPECT_EQ(result, "abc[\\s\\S]*"); } TEST(TranslatePatternMatchToRegexTest, PatternWithUnderscore) { std::string pattern = "a_c"; - std::string result = milvus::TranslatePatternMatchToRegex(pattern); - EXPECT_EQ(result, "a.c"); + std::string result = milvus::translate_pattern_match_to_regex(pattern); + EXPECT_EQ(result, "a[\\s\\S]c"); } TEST(TranslatePatternMatchToRegexTest, PatternWithSpecialCharacters) { std::string pattern = "a\\%b\\_c"; - std::string result = milvus::TranslatePatternMatchToRegex(pattern); - EXPECT_EQ(result, "a\\%b\\_c"); + std::string result = milvus::translate_pattern_match_to_regex(pattern); + EXPECT_EQ(result, "a%b_c"); } TEST(TranslatePatternMatchToRegexTest, PatternWithMultiplePercentAndUnderscore) { std::string pattern = "%a_b%"; - std::string result = milvus::TranslatePatternMatchToRegex(pattern); - EXPECT_EQ(result, ".*a.b.*"); + std::string result = milvus::translate_pattern_match_to_regex(pattern); + EXPECT_EQ(result, "[\\s\\S]*a[\\s\\S]b[\\s\\S]*"); } TEST(TranslatePatternMatchToRegexTest, PatternWithRegexChar) { std::string pattern = "abc*def.ghi+"; - std::string result = milvus::TranslatePatternMatchToRegex(pattern); + std::string result = milvus::translate_pattern_match_to_regex(pattern); EXPECT_EQ(result, "abc\\*def\\.ghi\\+"); } +TEST(TranslatePatternMatchToRegexTest, MixPattern) { + std::string pattern = R"(abc\+\def%ghi_[\\)"; + std::string result = milvus::translate_pattern_match_to_regex(pattern); + EXPECT_EQ(result, R"(abc\+def[\s\S]*ghi[\s\S]\[\\)"); +} + TEST(PatternMatchTranslatorTest, InvalidTypeTest) { using namespace milvus; PatternMatchTranslator translator; @@ -63,47 +86,67 @@ TEST(PatternMatchTranslatorTest, StringTypeTest) { EXPECT_EQ(translator(pattern1), "abc"); EXPECT_EQ(translator(pattern2), "xyz"); - EXPECT_EQ(translator(pattern3), ".*a.b.*"); + EXPECT_EQ(translator(pattern3), "[\\s\\S]*a[\\s\\S]b[\\s\\S]*"); } TEST(RegexMatcherTest, DefaultBehaviorTest) { using namespace milvus; - RegexMatcher matcher; - std::regex pattern("Hello.*"); + std::string pattern("Hello.*"); + RegexMatcher matcher(pattern); int operand1 = 123; double operand2 = 3.14; bool operand3 = true; - EXPECT_FALSE(matcher(pattern, operand1)); - EXPECT_FALSE(matcher(pattern, operand2)); - EXPECT_FALSE(matcher(pattern, operand3)); + EXPECT_FALSE(matcher(operand1)); + EXPECT_FALSE(matcher(operand2)); + EXPECT_FALSE(matcher(operand3)); } TEST(RegexMatcherTest, StringMatchTest) { using namespace milvus; - RegexMatcher matcher; - std::regex pattern("Hello.*"); + std::string pattern("Hello.*"); + RegexMatcher matcher(pattern); std::string str1 = "Hello, World!"; std::string str2 = "Hi there!"; std::string str3 = "Hello, OpenAI!"; - EXPECT_TRUE(matcher(pattern, str1)); - EXPECT_FALSE(matcher(pattern, str2)); - EXPECT_TRUE(matcher(pattern, str3)); + EXPECT_TRUE(matcher(str1)); + EXPECT_FALSE(matcher(str2)); + EXPECT_TRUE(matcher(str3)); } TEST(RegexMatcherTest, StringViewMatchTest) { using namespace milvus; - RegexMatcher matcher; - std::regex pattern("Hello.*"); + std::string pattern("Hello.*"); + RegexMatcher matcher(pattern); std::string_view str1 = "Hello, World!"; std::string_view str2 = "Hi there!"; std::string_view str3 = "Hello, OpenAI!"; - EXPECT_TRUE(matcher(pattern, str1)); - EXPECT_FALSE(matcher(pattern, str2)); - EXPECT_TRUE(matcher(pattern, str3)); + EXPECT_TRUE(matcher(str1)); + EXPECT_FALSE(matcher(str2)); + EXPECT_TRUE(matcher(str3)); +} + +TEST(RegexMatcherTest, NewLine) { + GTEST_SKIP() << "TODO: matching behavior on newline"; + + using namespace milvus; + std::string pattern("Hello.*"); + RegexMatcher matcher(pattern); + + EXPECT_FALSE(matcher(std::string("Hello\n"))); +} + +TEST(RegexMatcherTest, PatternMatchWithNewLine) { + using namespace milvus; + std::string pattern("Hello%"); + PatternMatchTranslator translator; + auto rp = translator(pattern); + RegexMatcher matcher(rp); + + EXPECT_TRUE(matcher(std::string("Hello\n"))); }