fix: regex query can't handle text with newline (#32569)

issue: https://github.com/milvus-io/milvus/issues/32482

---------

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
Jiquan Long 2024-04-26 12:01:26 +08:00 committed by GitHub
parent 02ace25c68
commit ccce1e928a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 154 additions and 103 deletions

View File

@ -10,49 +10,54 @@
// or implied. See the License for the specific language governing permissions and limitations under the License // or implied. See the License for the specific language governing permissions and limitations under the License
#include <re2/re2.h> #include <re2/re2.h>
#include <regex>
#include "common/RegexQuery.h" #include "common/RegexQuery.h"
namespace milvus { namespace milvus {
std::string
ReplaceUnescapedChars(const std::string& input,
char src,
const std::string& replacement) {
std::string result;
bool escapeMode = false;
for (char c : input) { bool
if (escapeMode) { is_special(char c) {
result += '\\'; // initial special_bytes_bitmap only once.
result += c; static std::once_flag _initialized;
escapeMode = false; static std::string special_bytes(R"(\.+*?()|[]{}^$)");
static std::vector<bool> special_bytes_bitmap;
std::call_once(_initialized, []() -> void {
special_bytes_bitmap.resize(256);
for (char b : special_bytes) {
special_bytes_bitmap[b + 128] = true;
}
});
return special_bytes_bitmap[c + 128];
}
std::string
translate_pattern_match_to_regex(const std::string& pattern) {
std::string r;
r.reserve(2 * pattern.size());
bool escape_mode = false;
for (char c : pattern) {
if (escape_mode) {
if (is_special(c)) {
r += '\\';
}
r += c;
escape_mode = false;
} else { } else {
if (c == '\\') { if (c == '\\') {
escapeMode = true; escape_mode = true;
} else if (c == src) { } else if (c == '%') {
result += replacement; r += "[\\s\\S]*";
} else if (c == '_') {
r += "[\\s\\S]";
} else { } else {
result += c; if (is_special(c)) {
r += '\\';
}
r += c;
} }
} }
} }
return r;
return result;
}
std::string
TranslatePatternMatchToRegex(const std::string& pattern) {
std::string regex_pattern;
#if 0
regex_pattern = R"([\.\*\+\?\|\(\)\[\]\{\}\\])";
#else
regex_pattern = R"([\.\*\+\?\|\(\)\[\]\{\}])";
#endif
std::string regex =
std::regex_replace(pattern, std::regex(regex_pattern), R"(\$&)");
regex = ReplaceUnescapedChars(regex, '%', ".*");
regex = ReplaceUnescapedChars(regex, '_', ".");
return regex;
} }
} // namespace milvus } // namespace milvus

View File

@ -13,17 +13,17 @@
#include <string> #include <string>
#include <regex> #include <regex>
#include <boost/regex.hpp>
#include <utility>
#include "common/EasyAssert.h" #include "common/EasyAssert.h"
namespace milvus { namespace milvus {
std::string bool
ReplaceUnescapedChars(const std::string& input, is_special(char c);
char src,
const std::string& replacement);
std::string std::string
TranslatePatternMatchToRegex(const std::string& pattern); translate_pattern_match_to_regex(const std::string& pattern);
struct PatternMatchTranslator { struct PatternMatchTranslator {
template <typename T> template <typename T>
@ -37,28 +37,40 @@ struct PatternMatchTranslator {
template <> template <>
inline std::string inline std::string
PatternMatchTranslator::operator()<std::string>(const std::string& pattern) { PatternMatchTranslator::operator()<std::string>(const std::string& pattern) {
return TranslatePatternMatchToRegex(pattern); return translate_pattern_match_to_regex(pattern);
} }
struct RegexMatcher { struct RegexMatcher {
template <typename T> template <typename T>
inline bool inline bool
operator()(const std::regex& reg, const T& operand) { operator()(const T& operand) {
return false; return false;
} }
explicit RegexMatcher(const std::string& pattern) {
r_ = boost::regex(pattern);
}
private:
// avoid to construct the regex everytime.
boost::regex r_;
}; };
template <> template <>
inline bool inline bool
RegexMatcher::operator()<std::string>(const std::regex& reg, RegexMatcher::operator()(const std::string& operand) {
const std::string& operand) { // corner case:
return std::regex_match(operand, reg); // . don't match \n, but .* match \n.
// For example,
// boost::regex_match("Hello\n", boost::regex("Hello.")) returns false
// but
// boost::regex_match("Hello\n", boost::regex("Hello.*")) returns true
return boost::regex_match(operand, r_);
} }
template <> template <>
inline bool inline bool
RegexMatcher::operator()<std::string_view>(const std::regex& reg, RegexMatcher::operator()(const std::string_view& operand) {
const std::string_view& operand) { return boost::regex_match(operand.begin(), operand.end(), r_);
return std::regex_match(operand.begin(), operand.end(), reg);
} }
} // namespace milvus } // namespace milvus

View File

@ -335,15 +335,14 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() {
} }
case proto::plan::Match: { case proto::plan::Match: {
PatternMatchTranslator translator; PatternMatchTranslator translator;
RegexMatcher matcher;
auto regex_pattern = translator(val); auto regex_pattern = translator(val);
std::regex reg(regex_pattern); RegexMatcher matcher(regex_pattern);
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
if constexpr (std::is_same_v<GetType, proto::plan::Array>) { if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
res[i] = false; res[i] = false;
} else { } else {
UnaryRangeJSONCompare( UnaryRangeJSONCompare(
matcher(reg, ExprValueType(x.value()))); matcher(ExprValueType(x.value())));
} }
} }
break; break;

View File

@ -43,21 +43,11 @@ struct UnaryElementFuncForMatch {
size_t size, size_t size,
IndexInnerType val, IndexInnerType val,
TargetBitmapView res) { TargetBitmapView res) {
if constexpr (std::is_same_v<T, std::string_view>) { PatternMatchTranslator translator;
// translate the pattern match in advance, which avoid computing it every loop. auto regex_pattern = translator(val);
std::regex reg(TranslatePatternMatchToRegex(val)); RegexMatcher matcher(regex_pattern);
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
res[i] = res[i] = matcher(src[i]);
std::regex_match(std::begin(src[i]), std::end(src[i]), reg);
}
} else if constexpr (std::is_same_v<T, std::string>) {
// translate the pattern match in advance, which avoid computing it every loop.
std::regex reg(TranslatePatternMatchToRegex(val));
for (int i = 0; i < size; ++i) {
res[i] = std::regex_match(src[i], reg);
}
} else {
PanicInfo(Unsupported, "regex query is only supported on string");
} }
} }
}; };
@ -216,9 +206,12 @@ struct UnaryIndexFuncForMatch {
!std::is_same_v<T, std::string>) { !std::is_same_v<T, std::string>) {
PanicInfo(Unsupported, "regex query is only supported on string"); PanicInfo(Unsupported, "regex query is only supported on string");
} else { } else {
auto reg = TranslatePatternMatchToRegex(val); PatternMatchTranslator translator;
auto regex_pattern = translator(val);
RegexMatcher matcher(regex_pattern);
if (index->SupportRegexQuery()) { if (index->SupportRegexQuery()) {
return index->RegexQuery(reg); return index->RegexQuery(regex_pattern);
} }
if (!index->HasRawData()) { if (!index->HasRawData()) {
PanicInfo(Unsupported, PanicInfo(Unsupported,
@ -228,11 +221,10 @@ struct UnaryIndexFuncForMatch {
// retrieve raw data to do brute force query, may be very slow. // retrieve raw data to do brute force query, may be very slow.
auto cnt = index->Count(); auto cnt = index->Count();
std::regex r(reg);
TargetBitmap res(cnt); TargetBitmap res(cnt);
for (int64_t i = 0; i < cnt; i++) { for (int64_t i = 0; i < cnt; i++) {
auto raw = index->Reverse_Lookup(i); auto raw = index->Reverse_Lookup(i);
res[i] = std::regex_match(raw, r); res[i] = matcher(raw);
} }
return res; return res;
} }

View File

@ -54,11 +54,11 @@ class GrowingSegmentRegexQueryTest : public ::testing::Test {
schema = GenTestSchema(); schema = GenTestSchema();
seg = CreateGrowingSegment(schema, empty_index_meta); seg = CreateGrowingSegment(schema, empty_index_meta);
raw_str = { raw_str = {
"b", "b\n",
"a", "a\n",
"aaa", "aaa\n",
"abbb", "abbb\n",
"abcabcabc", "abcabcabc\n",
}; };
raw_json = { raw_json = {
R"({"int":1})", R"({"int":1})",
@ -206,11 +206,11 @@ class SealedSegmentRegexQueryTest : public ::testing::Test {
schema = GenTestSchema(); schema = GenTestSchema();
seg = CreateSealedSegment(schema); seg = CreateSealedSegment(schema);
raw_str = { raw_str = {
"b", "b\n",
"a", "a\n",
"aaa", "aaa\n",
"abbb", "abbb\n",
"abcabcabc", "abcabcabc\n",
}; };
raw_json = { raw_json = {
R"({"int":1})", R"({"int":1})",

View File

@ -13,37 +13,60 @@
#include "common/RegexQuery.h" #include "common/RegexQuery.h"
TEST(IsSpecial, Demo) {
std::string special_bytes(R"(\.+*?()|[]{}^$)");
std::unordered_set<char> specials;
for (char b : special_bytes) {
specials.insert(b);
}
for (char c = std::numeric_limits<int8_t>::min();
c < std::numeric_limits<int8_t>::max();
c++) {
if (specials.find(c) != specials.end()) {
EXPECT_TRUE(milvus::is_special(c)) << c << static_cast<int>(c);
} else {
EXPECT_FALSE(milvus::is_special(c)) << c << static_cast<int>(c);
}
}
}
TEST(TranslatePatternMatchToRegexTest, SimplePatternWithPercent) { TEST(TranslatePatternMatchToRegexTest, SimplePatternWithPercent) {
std::string pattern = "abc%"; std::string pattern = "abc%";
std::string result = milvus::TranslatePatternMatchToRegex(pattern); std::string result = milvus::translate_pattern_match_to_regex(pattern);
EXPECT_EQ(result, "abc.*"); EXPECT_EQ(result, "abc[\\s\\S]*");
} }
TEST(TranslatePatternMatchToRegexTest, PatternWithUnderscore) { TEST(TranslatePatternMatchToRegexTest, PatternWithUnderscore) {
std::string pattern = "a_c"; std::string pattern = "a_c";
std::string result = milvus::TranslatePatternMatchToRegex(pattern); std::string result = milvus::translate_pattern_match_to_regex(pattern);
EXPECT_EQ(result, "a.c"); EXPECT_EQ(result, "a[\\s\\S]c");
} }
TEST(TranslatePatternMatchToRegexTest, PatternWithSpecialCharacters) { TEST(TranslatePatternMatchToRegexTest, PatternWithSpecialCharacters) {
std::string pattern = "a\\%b\\_c"; std::string pattern = "a\\%b\\_c";
std::string result = milvus::TranslatePatternMatchToRegex(pattern); std::string result = milvus::translate_pattern_match_to_regex(pattern);
EXPECT_EQ(result, "a\\%b\\_c"); EXPECT_EQ(result, "a%b_c");
} }
TEST(TranslatePatternMatchToRegexTest, TEST(TranslatePatternMatchToRegexTest,
PatternWithMultiplePercentAndUnderscore) { PatternWithMultiplePercentAndUnderscore) {
std::string pattern = "%a_b%"; std::string pattern = "%a_b%";
std::string result = milvus::TranslatePatternMatchToRegex(pattern); std::string result = milvus::translate_pattern_match_to_regex(pattern);
EXPECT_EQ(result, ".*a.b.*"); EXPECT_EQ(result, "[\\s\\S]*a[\\s\\S]b[\\s\\S]*");
} }
TEST(TranslatePatternMatchToRegexTest, PatternWithRegexChar) { TEST(TranslatePatternMatchToRegexTest, PatternWithRegexChar) {
std::string pattern = "abc*def.ghi+"; std::string pattern = "abc*def.ghi+";
std::string result = milvus::TranslatePatternMatchToRegex(pattern); std::string result = milvus::translate_pattern_match_to_regex(pattern);
EXPECT_EQ(result, "abc\\*def\\.ghi\\+"); EXPECT_EQ(result, "abc\\*def\\.ghi\\+");
} }
TEST(TranslatePatternMatchToRegexTest, MixPattern) {
std::string pattern = R"(abc\+\def%ghi_[\\)";
std::string result = milvus::translate_pattern_match_to_regex(pattern);
EXPECT_EQ(result, R"(abc\+def[\s\S]*ghi[\s\S]\[\\)");
}
TEST(PatternMatchTranslatorTest, InvalidTypeTest) { TEST(PatternMatchTranslatorTest, InvalidTypeTest) {
using namespace milvus; using namespace milvus;
PatternMatchTranslator translator; PatternMatchTranslator translator;
@ -63,47 +86,67 @@ TEST(PatternMatchTranslatorTest, StringTypeTest) {
EXPECT_EQ(translator(pattern1), "abc"); EXPECT_EQ(translator(pattern1), "abc");
EXPECT_EQ(translator(pattern2), "xyz"); EXPECT_EQ(translator(pattern2), "xyz");
EXPECT_EQ(translator(pattern3), ".*a.b.*"); EXPECT_EQ(translator(pattern3), "[\\s\\S]*a[\\s\\S]b[\\s\\S]*");
} }
TEST(RegexMatcherTest, DefaultBehaviorTest) { TEST(RegexMatcherTest, DefaultBehaviorTest) {
using namespace milvus; using namespace milvus;
RegexMatcher matcher; std::string pattern("Hello.*");
std::regex pattern("Hello.*"); RegexMatcher matcher(pattern);
int operand1 = 123; int operand1 = 123;
double operand2 = 3.14; double operand2 = 3.14;
bool operand3 = true; bool operand3 = true;
EXPECT_FALSE(matcher(pattern, operand1)); EXPECT_FALSE(matcher(operand1));
EXPECT_FALSE(matcher(pattern, operand2)); EXPECT_FALSE(matcher(operand2));
EXPECT_FALSE(matcher(pattern, operand3)); EXPECT_FALSE(matcher(operand3));
} }
TEST(RegexMatcherTest, StringMatchTest) { TEST(RegexMatcherTest, StringMatchTest) {
using namespace milvus; using namespace milvus;
RegexMatcher matcher; std::string pattern("Hello.*");
std::regex pattern("Hello.*"); RegexMatcher matcher(pattern);
std::string str1 = "Hello, World!"; std::string str1 = "Hello, World!";
std::string str2 = "Hi there!"; std::string str2 = "Hi there!";
std::string str3 = "Hello, OpenAI!"; std::string str3 = "Hello, OpenAI!";
EXPECT_TRUE(matcher(pattern, str1)); EXPECT_TRUE(matcher(str1));
EXPECT_FALSE(matcher(pattern, str2)); EXPECT_FALSE(matcher(str2));
EXPECT_TRUE(matcher(pattern, str3)); EXPECT_TRUE(matcher(str3));
} }
TEST(RegexMatcherTest, StringViewMatchTest) { TEST(RegexMatcherTest, StringViewMatchTest) {
using namespace milvus; using namespace milvus;
RegexMatcher matcher; std::string pattern("Hello.*");
std::regex pattern("Hello.*"); RegexMatcher matcher(pattern);
std::string_view str1 = "Hello, World!"; std::string_view str1 = "Hello, World!";
std::string_view str2 = "Hi there!"; std::string_view str2 = "Hi there!";
std::string_view str3 = "Hello, OpenAI!"; std::string_view str3 = "Hello, OpenAI!";
EXPECT_TRUE(matcher(pattern, str1)); EXPECT_TRUE(matcher(str1));
EXPECT_FALSE(matcher(pattern, str2)); EXPECT_FALSE(matcher(str2));
EXPECT_TRUE(matcher(pattern, str3)); EXPECT_TRUE(matcher(str3));
}
TEST(RegexMatcherTest, NewLine) {
GTEST_SKIP() << "TODO: matching behavior on newline";
using namespace milvus;
std::string pattern("Hello.*");
RegexMatcher matcher(pattern);
EXPECT_FALSE(matcher(std::string("Hello\n")));
}
TEST(RegexMatcherTest, PatternMatchWithNewLine) {
using namespace milvus;
std::string pattern("Hello%");
PatternMatchTranslator translator;
auto rp = translator(pattern);
RegexMatcher matcher(rp);
EXPECT_TRUE(matcher(std::string("Hello\n")));
} }