fix: regex query can't handle text with newline (#32569)

issue: https://github.com/milvus-io/milvus/issues/32482

---------

Signed-off-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
Jiquan Long 2024-04-26 12:01:26 +08:00 committed by GitHub
parent 02ace25c68
commit ccce1e928a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 154 additions and 103 deletions

View File

@ -10,49 +10,54 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <re2/re2.h>
#include <regex>
#include "common/RegexQuery.h"
namespace milvus {
std::string
ReplaceUnescapedChars(const std::string& input,
char src,
const std::string& replacement) {
std::string result;
bool escapeMode = false;
for (char c : input) {
if (escapeMode) {
result += '\\';
result += c;
escapeMode = false;
// Reports whether `c` is a regular-expression metacharacter, i.e. one of
//     \ . + * ? ( ) | [ ] { } ^ $
// Callers use this to decide whether a byte must be backslash-escaped so that
// it is matched literally.
//
// A switch is used instead of a lazily built lookup table indexed by
// `c + 128`: that index is only in range when `char` is signed, so bytes
// >= 0x80 would index past a 256-entry table wherever `char` is unsigned.
// The switch does not depend on the signedness of `char`, needs no one-time
// initialization, and has no shared mutable state.
bool
is_special(char c) {
    switch (c) {
        case '\\':
        case '.':
        case '+':
        case '*':
        case '?':
        case '(':
        case ')':
        case '|':
        case '[':
        case ']':
        case '{':
        case '}':
        case '^':
        case '$':
            return true;
        default:
            return false;
    }
}
std::string
translate_pattern_match_to_regex(const std::string& pattern) {
std::string r;
r.reserve(2 * pattern.size());
bool escape_mode = false;
for (char c : pattern) {
if (escape_mode) {
if (is_special(c)) {
r += '\\';
}
r += c;
escape_mode = false;
} else {
if (c == '\\') {
escapeMode = true;
} else if (c == src) {
result += replacement;
escape_mode = true;
} else if (c == '%') {
r += "[\\s\\S]*";
} else if (c == '_') {
r += "[\\s\\S]";
} else {
result += c;
if (is_special(c)) {
r += '\\';
}
r += c;
}
}
}
return result;
}
std::string
TranslatePatternMatchToRegex(const std::string& pattern) {
std::string regex_pattern;
#if 0
regex_pattern = R"([\.\*\+\?\|\(\)\[\]\{\}\\])";
#else
regex_pattern = R"([\.\*\+\?\|\(\)\[\]\{\}])";
#endif
std::string regex =
std::regex_replace(pattern, std::regex(regex_pattern), R"(\$&)");
regex = ReplaceUnescapedChars(regex, '%', ".*");
regex = ReplaceUnescapedChars(regex, '_', ".");
return regex;
return r;
}
} // namespace milvus

View File

@ -13,17 +13,17 @@
#include <string>
#include <regex>
#include <boost/regex.hpp>
#include <utility>
#include "common/EasyAssert.h"
namespace milvus {
std::string
ReplaceUnescapedChars(const std::string& input,
char src,
const std::string& replacement);
bool
is_special(char c);
std::string
TranslatePatternMatchToRegex(const std::string& pattern);
translate_pattern_match_to_regex(const std::string& pattern);
struct PatternMatchTranslator {
template <typename T>
@ -37,28 +37,40 @@ struct PatternMatchTranslator {
template <>
inline std::string
PatternMatchTranslator::operator()<std::string>(const std::string& pattern) {
return TranslatePatternMatchToRegex(pattern);
return translate_pattern_match_to_regex(pattern);
}
struct RegexMatcher {
template <typename T>
inline bool
operator()(const std::regex& reg, const T& operand) {
operator()(const T& operand) {
return false;
}
explicit RegexMatcher(const std::string& pattern) {
r_ = boost::regex(pattern);
}
private:
// avoid constructing the regex every time.
boost::regex r_;
};
template <>
inline bool
RegexMatcher::operator()<std::string>(const std::regex& reg,
const std::string& operand) {
return std::regex_match(operand, reg);
RegexMatcher::operator()(const std::string& operand) {
// corner case:
// . don't match \n, but .* match \n.
// For example,
// boost::regex_match("Hello\n", boost::regex("Hello.")) returns false
// but
// boost::regex_match("Hello\n", boost::regex("Hello.*")) returns true
return boost::regex_match(operand, r_);
}
template <>
inline bool
RegexMatcher::operator()<std::string_view>(const std::regex& reg,
const std::string_view& operand) {
return std::regex_match(operand.begin(), operand.end(), reg);
RegexMatcher::operator()(const std::string_view& operand) {
return boost::regex_match(operand.begin(), operand.end(), r_);
}
} // namespace milvus

View File

@ -335,15 +335,14 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() {
}
case proto::plan::Match: {
PatternMatchTranslator translator;
RegexMatcher matcher;
auto regex_pattern = translator(val);
std::regex reg(regex_pattern);
RegexMatcher matcher(regex_pattern);
for (size_t i = 0; i < size; ++i) {
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
res[i] = false;
} else {
UnaryRangeJSONCompare(
matcher(reg, ExprValueType(x.value())));
matcher(ExprValueType(x.value())));
}
}
break;

View File

@ -43,21 +43,11 @@ struct UnaryElementFuncForMatch {
size_t size,
IndexInnerType val,
TargetBitmapView res) {
if constexpr (std::is_same_v<T, std::string_view>) {
// translate the pattern match in advance, which avoid computing it every loop.
std::regex reg(TranslatePatternMatchToRegex(val));
PatternMatchTranslator translator;
auto regex_pattern = translator(val);
RegexMatcher matcher(regex_pattern);
for (int i = 0; i < size; ++i) {
res[i] =
std::regex_match(std::begin(src[i]), std::end(src[i]), reg);
}
} else if constexpr (std::is_same_v<T, std::string>) {
// translate the pattern match in advance, which avoid computing it every loop.
std::regex reg(TranslatePatternMatchToRegex(val));
for (int i = 0; i < size; ++i) {
res[i] = std::regex_match(src[i], reg);
}
} else {
PanicInfo(Unsupported, "regex query is only supported on string");
res[i] = matcher(src[i]);
}
}
};
@ -216,9 +206,12 @@ struct UnaryIndexFuncForMatch {
!std::is_same_v<T, std::string>) {
PanicInfo(Unsupported, "regex query is only supported on string");
} else {
auto reg = TranslatePatternMatchToRegex(val);
PatternMatchTranslator translator;
auto regex_pattern = translator(val);
RegexMatcher matcher(regex_pattern);
if (index->SupportRegexQuery()) {
return index->RegexQuery(reg);
return index->RegexQuery(regex_pattern);
}
if (!index->HasRawData()) {
PanicInfo(Unsupported,
@ -228,11 +221,10 @@ struct UnaryIndexFuncForMatch {
// retrieve raw data to do brute force query, may be very slow.
auto cnt = index->Count();
std::regex r(reg);
TargetBitmap res(cnt);
for (int64_t i = 0; i < cnt; i++) {
auto raw = index->Reverse_Lookup(i);
res[i] = std::regex_match(raw, r);
res[i] = matcher(raw);
}
return res;
}

View File

@ -54,11 +54,11 @@ class GrowingSegmentRegexQueryTest : public ::testing::Test {
schema = GenTestSchema();
seg = CreateGrowingSegment(schema, empty_index_meta);
raw_str = {
"b",
"a",
"aaa",
"abbb",
"abcabcabc",
"b\n",
"a\n",
"aaa\n",
"abbb\n",
"abcabcabc\n",
};
raw_json = {
R"({"int":1})",
@ -206,11 +206,11 @@ class SealedSegmentRegexQueryTest : public ::testing::Test {
schema = GenTestSchema();
seg = CreateSealedSegment(schema);
raw_str = {
"b",
"a",
"aaa",
"abbb",
"abcabcabc",
"b\n",
"a\n",
"aaa\n",
"abbb\n",
"abcabcabc\n",
};
raw_json = {
R"({"int":1})",

View File

@ -13,37 +13,60 @@
#include "common/RegexQuery.h"
// Exhaustively compares milvus::is_special() with a reference set of regex
// metacharacters for every value a `char` can hold.
TEST(IsSpecial, Demo) {
    const std::string special_bytes(R"(\.+*?()|[]{}^$)");
    const std::unordered_set<char> specials(special_bytes.begin(),
                                            special_bytes.end());

    // Iterate with an int so the upper bound can be inclusive without
    // overflowing the loop variable. A `char` counter compared with `<`
    // never checks the largest value, and where `char` is unsigned a start
    // value taken from int8_t makes the loop body unreachable.
    for (int code = std::numeric_limits<char>::min();
         code <= std::numeric_limits<char>::max();
         ++code) {
        const char c = static_cast<char>(code);
        const bool expected = specials.find(c) != specials.end();
        if (expected) {
            EXPECT_TRUE(milvus::is_special(c))
                << "char='" << c << "' code=" << code;
        } else {
            EXPECT_FALSE(milvus::is_special(c))
                << "char='" << c << "' code=" << code;
        }
    }
}
TEST(TranslatePatternMatchToRegexTest, SimplePatternWithPercent) {
std::string pattern = "abc%";
std::string result = milvus::TranslatePatternMatchToRegex(pattern);
EXPECT_EQ(result, "abc.*");
std::string result = milvus::translate_pattern_match_to_regex(pattern);
EXPECT_EQ(result, "abc[\\s\\S]*");
}
TEST(TranslatePatternMatchToRegexTest, PatternWithUnderscore) {
std::string pattern = "a_c";
std::string result = milvus::TranslatePatternMatchToRegex(pattern);
EXPECT_EQ(result, "a.c");
std::string result = milvus::translate_pattern_match_to_regex(pattern);
EXPECT_EQ(result, "a[\\s\\S]c");
}
TEST(TranslatePatternMatchToRegexTest, PatternWithSpecialCharacters) {
std::string pattern = "a\\%b\\_c";
std::string result = milvus::TranslatePatternMatchToRegex(pattern);
EXPECT_EQ(result, "a\\%b\\_c");
std::string result = milvus::translate_pattern_match_to_regex(pattern);
EXPECT_EQ(result, "a%b_c");
}
TEST(TranslatePatternMatchToRegexTest,
PatternWithMultiplePercentAndUnderscore) {
std::string pattern = "%a_b%";
std::string result = milvus::TranslatePatternMatchToRegex(pattern);
EXPECT_EQ(result, ".*a.b.*");
std::string result = milvus::translate_pattern_match_to_regex(pattern);
EXPECT_EQ(result, "[\\s\\S]*a[\\s\\S]b[\\s\\S]*");
}
TEST(TranslatePatternMatchToRegexTest, PatternWithRegexChar) {
std::string pattern = "abc*def.ghi+";
std::string result = milvus::TranslatePatternMatchToRegex(pattern);
std::string result = milvus::translate_pattern_match_to_regex(pattern);
EXPECT_EQ(result, "abc\\*def\\.ghi\\+");
}
// One pattern combining every translation rule: an escaped metacharacter
// (\+), an escaped ordinary character (\d), both wildcards (% and _), an
// unescaped metacharacter ([) and an escaped backslash (\\).
TEST(TranslatePatternMatchToRegexTest, MixPattern) {
    const std::string like_pattern(R"(abc\+\def%ghi_[\\)");
    const std::string translated =
        milvus::translate_pattern_match_to_regex(like_pattern);
    const std::string expected_regex(R"(abc\+def[\s\S]*ghi[\s\S]\[\\)");
    EXPECT_EQ(translated, expected_regex);
}
TEST(PatternMatchTranslatorTest, InvalidTypeTest) {
using namespace milvus;
PatternMatchTranslator translator;
@ -63,47 +86,67 @@ TEST(PatternMatchTranslatorTest, StringTypeTest) {
EXPECT_EQ(translator(pattern1), "abc");
EXPECT_EQ(translator(pattern2), "xyz");
EXPECT_EQ(translator(pattern3), ".*a.b.*");
EXPECT_EQ(translator(pattern3), "[\\s\\S]*a[\\s\\S]b[\\s\\S]*");
}
TEST(RegexMatcherTest, DefaultBehaviorTest) {
using namespace milvus;
RegexMatcher matcher;
std::regex pattern("Hello.*");
std::string pattern("Hello.*");
RegexMatcher matcher(pattern);
int operand1 = 123;
double operand2 = 3.14;
bool operand3 = true;
EXPECT_FALSE(matcher(pattern, operand1));
EXPECT_FALSE(matcher(pattern, operand2));
EXPECT_FALSE(matcher(pattern, operand3));
EXPECT_FALSE(matcher(operand1));
EXPECT_FALSE(matcher(operand2));
EXPECT_FALSE(matcher(operand3));
}
TEST(RegexMatcherTest, StringMatchTest) {
using namespace milvus;
RegexMatcher matcher;
std::regex pattern("Hello.*");
std::string pattern("Hello.*");
RegexMatcher matcher(pattern);
std::string str1 = "Hello, World!";
std::string str2 = "Hi there!";
std::string str3 = "Hello, OpenAI!";
EXPECT_TRUE(matcher(pattern, str1));
EXPECT_FALSE(matcher(pattern, str2));
EXPECT_TRUE(matcher(pattern, str3));
EXPECT_TRUE(matcher(str1));
EXPECT_FALSE(matcher(str2));
EXPECT_TRUE(matcher(str3));
}
TEST(RegexMatcherTest, StringViewMatchTest) {
using namespace milvus;
RegexMatcher matcher;
std::regex pattern("Hello.*");
std::string pattern("Hello.*");
RegexMatcher matcher(pattern);
std::string_view str1 = "Hello, World!";
std::string_view str2 = "Hi there!";
std::string_view str3 = "Hello, OpenAI!";
EXPECT_TRUE(matcher(pattern, str1));
EXPECT_FALSE(matcher(pattern, str2));
EXPECT_TRUE(matcher(pattern, str3));
EXPECT_TRUE(matcher(str1));
EXPECT_FALSE(matcher(str2));
EXPECT_TRUE(matcher(str3));
}
// Placeholder for pinning down how a raw regex `.*` treats a trailing
// newline. The test is skipped up front, so the expectation below is not
// evaluated until the TODO is resolved.
TEST(RegexMatcherTest, NewLine) {
    GTEST_SKIP() << "TODO: matching behavior on newline";

    using namespace milvus;
    const std::string regex_text("Hello.*");
    RegexMatcher matcher(regex_text);

    const std::string text_with_newline("Hello\n");
    EXPECT_FALSE(matcher(text_with_newline));
}
// End-to-end check: a pattern with a trailing `%` wildcard, once translated
// to a regex, must match text that ends in a newline.
TEST(RegexMatcherTest, PatternMatchWithNewLine) {
    using namespace milvus;

    const std::string like_pattern("Hello%");
    PatternMatchTranslator to_regex;
    const auto regex_text = to_regex(like_pattern);
    RegexMatcher matcher(regex_text);

    const std::string text_with_newline("Hello\n");
    EXPECT_TRUE(matcher(text_with_newline));
}