mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
fix: regex query can't handle text with newline (#32569)
issue: https://github.com/milvus-io/milvus/issues/32482 --------- Signed-off-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
parent
02ace25c68
commit
ccce1e928a
@ -10,49 +10,54 @@
|
|||||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||||
|
|
||||||
#include <re2/re2.h>
|
#include <re2/re2.h>
|
||||||
#include <regex>
|
|
||||||
|
|
||||||
#include "common/RegexQuery.h"
|
#include "common/RegexQuery.h"
|
||||||
|
|
||||||
namespace milvus {
|
namespace milvus {
|
||||||
std::string
|
|
||||||
ReplaceUnescapedChars(const std::string& input,
|
|
||||||
char src,
|
|
||||||
const std::string& replacement) {
|
|
||||||
std::string result;
|
|
||||||
bool escapeMode = false;
|
|
||||||
|
|
||||||
for (char c : input) {
|
bool
|
||||||
if (escapeMode) {
|
is_special(char c) {
|
||||||
result += '\\';
|
// initial special_bytes_bitmap only once.
|
||||||
result += c;
|
static std::once_flag _initialized;
|
||||||
escapeMode = false;
|
static std::string special_bytes(R"(\.+*?()|[]{}^$)");
|
||||||
|
static std::vector<bool> special_bytes_bitmap;
|
||||||
|
std::call_once(_initialized, []() -> void {
|
||||||
|
special_bytes_bitmap.resize(256);
|
||||||
|
for (char b : special_bytes) {
|
||||||
|
special_bytes_bitmap[b + 128] = true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return special_bytes_bitmap[c + 128];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string
|
||||||
|
translate_pattern_match_to_regex(const std::string& pattern) {
|
||||||
|
std::string r;
|
||||||
|
r.reserve(2 * pattern.size());
|
||||||
|
bool escape_mode = false;
|
||||||
|
for (char c : pattern) {
|
||||||
|
if (escape_mode) {
|
||||||
|
if (is_special(c)) {
|
||||||
|
r += '\\';
|
||||||
|
}
|
||||||
|
r += c;
|
||||||
|
escape_mode = false;
|
||||||
} else {
|
} else {
|
||||||
if (c == '\\') {
|
if (c == '\\') {
|
||||||
escapeMode = true;
|
escape_mode = true;
|
||||||
} else if (c == src) {
|
} else if (c == '%') {
|
||||||
result += replacement;
|
r += "[\\s\\S]*";
|
||||||
|
} else if (c == '_') {
|
||||||
|
r += "[\\s\\S]";
|
||||||
} else {
|
} else {
|
||||||
result += c;
|
if (is_special(c)) {
|
||||||
|
r += '\\';
|
||||||
|
}
|
||||||
|
r += c;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return r;
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string
|
|
||||||
TranslatePatternMatchToRegex(const std::string& pattern) {
|
|
||||||
std::string regex_pattern;
|
|
||||||
#if 0
|
|
||||||
regex_pattern = R"([\.\*\+\?\|\(\)\[\]\{\}\\])";
|
|
||||||
#else
|
|
||||||
regex_pattern = R"([\.\*\+\?\|\(\)\[\]\{\}])";
|
|
||||||
#endif
|
|
||||||
std::string regex =
|
|
||||||
std::regex_replace(pattern, std::regex(regex_pattern), R"(\$&)");
|
|
||||||
regex = ReplaceUnescapedChars(regex, '%', ".*");
|
|
||||||
regex = ReplaceUnescapedChars(regex, '_', ".");
|
|
||||||
return regex;
|
|
||||||
}
|
}
|
||||||
} // namespace milvus
|
} // namespace milvus
|
||||||
|
|||||||
@ -13,17 +13,17 @@
|
|||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
|
#include <boost/regex.hpp>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "common/EasyAssert.h"
|
#include "common/EasyAssert.h"
|
||||||
|
|
||||||
namespace milvus {
|
namespace milvus {
|
||||||
std::string
|
bool
|
||||||
ReplaceUnescapedChars(const std::string& input,
|
is_special(char c);
|
||||||
char src,
|
|
||||||
const std::string& replacement);
|
|
||||||
|
|
||||||
std::string
|
std::string
|
||||||
TranslatePatternMatchToRegex(const std::string& pattern);
|
translate_pattern_match_to_regex(const std::string& pattern);
|
||||||
|
|
||||||
struct PatternMatchTranslator {
|
struct PatternMatchTranslator {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -37,28 +37,40 @@ struct PatternMatchTranslator {
|
|||||||
template <>
|
template <>
|
||||||
inline std::string
|
inline std::string
|
||||||
PatternMatchTranslator::operator()<std::string>(const std::string& pattern) {
|
PatternMatchTranslator::operator()<std::string>(const std::string& pattern) {
|
||||||
return TranslatePatternMatchToRegex(pattern);
|
return translate_pattern_match_to_regex(pattern);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct RegexMatcher {
|
struct RegexMatcher {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline bool
|
inline bool
|
||||||
operator()(const std::regex& reg, const T& operand) {
|
operator()(const T& operand) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
explicit RegexMatcher(const std::string& pattern) {
|
||||||
|
r_ = boost::regex(pattern);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// avoid to construct the regex everytime.
|
||||||
|
boost::regex r_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline bool
|
inline bool
|
||||||
RegexMatcher::operator()<std::string>(const std::regex& reg,
|
RegexMatcher::operator()(const std::string& operand) {
|
||||||
const std::string& operand) {
|
// corner case:
|
||||||
return std::regex_match(operand, reg);
|
// . don't match \n, but .* match \n.
|
||||||
|
// For example,
|
||||||
|
// boost::regex_match("Hello\n", boost::regex("Hello.")) returns false
|
||||||
|
// but
|
||||||
|
// boost::regex_match("Hello\n", boost::regex("Hello.*")) returns true
|
||||||
|
return boost::regex_match(operand, r_);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline bool
|
inline bool
|
||||||
RegexMatcher::operator()<std::string_view>(const std::regex& reg,
|
RegexMatcher::operator()(const std::string_view& operand) {
|
||||||
const std::string_view& operand) {
|
return boost::regex_match(operand.begin(), operand.end(), r_);
|
||||||
return std::regex_match(operand.begin(), operand.end(), reg);
|
|
||||||
}
|
}
|
||||||
} // namespace milvus
|
} // namespace milvus
|
||||||
|
|||||||
@ -335,15 +335,14 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() {
|
|||||||
}
|
}
|
||||||
case proto::plan::Match: {
|
case proto::plan::Match: {
|
||||||
PatternMatchTranslator translator;
|
PatternMatchTranslator translator;
|
||||||
RegexMatcher matcher;
|
|
||||||
auto regex_pattern = translator(val);
|
auto regex_pattern = translator(val);
|
||||||
std::regex reg(regex_pattern);
|
RegexMatcher matcher(regex_pattern);
|
||||||
for (size_t i = 0; i < size; ++i) {
|
for (size_t i = 0; i < size; ++i) {
|
||||||
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
|
if constexpr (std::is_same_v<GetType, proto::plan::Array>) {
|
||||||
res[i] = false;
|
res[i] = false;
|
||||||
} else {
|
} else {
|
||||||
UnaryRangeJSONCompare(
|
UnaryRangeJSONCompare(
|
||||||
matcher(reg, ExprValueType(x.value())));
|
matcher(ExprValueType(x.value())));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|||||||
@ -43,21 +43,11 @@ struct UnaryElementFuncForMatch {
|
|||||||
size_t size,
|
size_t size,
|
||||||
IndexInnerType val,
|
IndexInnerType val,
|
||||||
TargetBitmapView res) {
|
TargetBitmapView res) {
|
||||||
if constexpr (std::is_same_v<T, std::string_view>) {
|
PatternMatchTranslator translator;
|
||||||
// translate the pattern match in advance, which avoid computing it every loop.
|
auto regex_pattern = translator(val);
|
||||||
std::regex reg(TranslatePatternMatchToRegex(val));
|
RegexMatcher matcher(regex_pattern);
|
||||||
for (int i = 0; i < size; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
res[i] =
|
res[i] = matcher(src[i]);
|
||||||
std::regex_match(std::begin(src[i]), std::end(src[i]), reg);
|
|
||||||
}
|
|
||||||
} else if constexpr (std::is_same_v<T, std::string>) {
|
|
||||||
// translate the pattern match in advance, which avoid computing it every loop.
|
|
||||||
std::regex reg(TranslatePatternMatchToRegex(val));
|
|
||||||
for (int i = 0; i < size; ++i) {
|
|
||||||
res[i] = std::regex_match(src[i], reg);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
PanicInfo(Unsupported, "regex query is only supported on string");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -216,9 +206,12 @@ struct UnaryIndexFuncForMatch {
|
|||||||
!std::is_same_v<T, std::string>) {
|
!std::is_same_v<T, std::string>) {
|
||||||
PanicInfo(Unsupported, "regex query is only supported on string");
|
PanicInfo(Unsupported, "regex query is only supported on string");
|
||||||
} else {
|
} else {
|
||||||
auto reg = TranslatePatternMatchToRegex(val);
|
PatternMatchTranslator translator;
|
||||||
|
auto regex_pattern = translator(val);
|
||||||
|
RegexMatcher matcher(regex_pattern);
|
||||||
|
|
||||||
if (index->SupportRegexQuery()) {
|
if (index->SupportRegexQuery()) {
|
||||||
return index->RegexQuery(reg);
|
return index->RegexQuery(regex_pattern);
|
||||||
}
|
}
|
||||||
if (!index->HasRawData()) {
|
if (!index->HasRawData()) {
|
||||||
PanicInfo(Unsupported,
|
PanicInfo(Unsupported,
|
||||||
@ -228,11 +221,10 @@ struct UnaryIndexFuncForMatch {
|
|||||||
|
|
||||||
// retrieve raw data to do brute force query, may be very slow.
|
// retrieve raw data to do brute force query, may be very slow.
|
||||||
auto cnt = index->Count();
|
auto cnt = index->Count();
|
||||||
std::regex r(reg);
|
|
||||||
TargetBitmap res(cnt);
|
TargetBitmap res(cnt);
|
||||||
for (int64_t i = 0; i < cnt; i++) {
|
for (int64_t i = 0; i < cnt; i++) {
|
||||||
auto raw = index->Reverse_Lookup(i);
|
auto raw = index->Reverse_Lookup(i);
|
||||||
res[i] = std::regex_match(raw, r);
|
res[i] = matcher(raw);
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -54,11 +54,11 @@ class GrowingSegmentRegexQueryTest : public ::testing::Test {
|
|||||||
schema = GenTestSchema();
|
schema = GenTestSchema();
|
||||||
seg = CreateGrowingSegment(schema, empty_index_meta);
|
seg = CreateGrowingSegment(schema, empty_index_meta);
|
||||||
raw_str = {
|
raw_str = {
|
||||||
"b",
|
"b\n",
|
||||||
"a",
|
"a\n",
|
||||||
"aaa",
|
"aaa\n",
|
||||||
"abbb",
|
"abbb\n",
|
||||||
"abcabcabc",
|
"abcabcabc\n",
|
||||||
};
|
};
|
||||||
raw_json = {
|
raw_json = {
|
||||||
R"({"int":1})",
|
R"({"int":1})",
|
||||||
@ -206,11 +206,11 @@ class SealedSegmentRegexQueryTest : public ::testing::Test {
|
|||||||
schema = GenTestSchema();
|
schema = GenTestSchema();
|
||||||
seg = CreateSealedSegment(schema);
|
seg = CreateSealedSegment(schema);
|
||||||
raw_str = {
|
raw_str = {
|
||||||
"b",
|
"b\n",
|
||||||
"a",
|
"a\n",
|
||||||
"aaa",
|
"aaa\n",
|
||||||
"abbb",
|
"abbb\n",
|
||||||
"abcabcabc",
|
"abcabcabc\n",
|
||||||
};
|
};
|
||||||
raw_json = {
|
raw_json = {
|
||||||
R"({"int":1})",
|
R"({"int":1})",
|
||||||
|
|||||||
@ -13,37 +13,60 @@
|
|||||||
|
|
||||||
#include "common/RegexQuery.h"
|
#include "common/RegexQuery.h"
|
||||||
|
|
||||||
|
TEST(IsSpecial, Demo) {
|
||||||
|
std::string special_bytes(R"(\.+*?()|[]{}^$)");
|
||||||
|
std::unordered_set<char> specials;
|
||||||
|
for (char b : special_bytes) {
|
||||||
|
specials.insert(b);
|
||||||
|
}
|
||||||
|
for (char c = std::numeric_limits<int8_t>::min();
|
||||||
|
c < std::numeric_limits<int8_t>::max();
|
||||||
|
c++) {
|
||||||
|
if (specials.find(c) != specials.end()) {
|
||||||
|
EXPECT_TRUE(milvus::is_special(c)) << c << static_cast<int>(c);
|
||||||
|
} else {
|
||||||
|
EXPECT_FALSE(milvus::is_special(c)) << c << static_cast<int>(c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST(TranslatePatternMatchToRegexTest, SimplePatternWithPercent) {
|
TEST(TranslatePatternMatchToRegexTest, SimplePatternWithPercent) {
|
||||||
std::string pattern = "abc%";
|
std::string pattern = "abc%";
|
||||||
std::string result = milvus::TranslatePatternMatchToRegex(pattern);
|
std::string result = milvus::translate_pattern_match_to_regex(pattern);
|
||||||
EXPECT_EQ(result, "abc.*");
|
EXPECT_EQ(result, "abc[\\s\\S]*");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TranslatePatternMatchToRegexTest, PatternWithUnderscore) {
|
TEST(TranslatePatternMatchToRegexTest, PatternWithUnderscore) {
|
||||||
std::string pattern = "a_c";
|
std::string pattern = "a_c";
|
||||||
std::string result = milvus::TranslatePatternMatchToRegex(pattern);
|
std::string result = milvus::translate_pattern_match_to_regex(pattern);
|
||||||
EXPECT_EQ(result, "a.c");
|
EXPECT_EQ(result, "a[\\s\\S]c");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TranslatePatternMatchToRegexTest, PatternWithSpecialCharacters) {
|
TEST(TranslatePatternMatchToRegexTest, PatternWithSpecialCharacters) {
|
||||||
std::string pattern = "a\\%b\\_c";
|
std::string pattern = "a\\%b\\_c";
|
||||||
std::string result = milvus::TranslatePatternMatchToRegex(pattern);
|
std::string result = milvus::translate_pattern_match_to_regex(pattern);
|
||||||
EXPECT_EQ(result, "a\\%b\\_c");
|
EXPECT_EQ(result, "a%b_c");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TranslatePatternMatchToRegexTest,
|
TEST(TranslatePatternMatchToRegexTest,
|
||||||
PatternWithMultiplePercentAndUnderscore) {
|
PatternWithMultiplePercentAndUnderscore) {
|
||||||
std::string pattern = "%a_b%";
|
std::string pattern = "%a_b%";
|
||||||
std::string result = milvus::TranslatePatternMatchToRegex(pattern);
|
std::string result = milvus::translate_pattern_match_to_regex(pattern);
|
||||||
EXPECT_EQ(result, ".*a.b.*");
|
EXPECT_EQ(result, "[\\s\\S]*a[\\s\\S]b[\\s\\S]*");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TranslatePatternMatchToRegexTest, PatternWithRegexChar) {
|
TEST(TranslatePatternMatchToRegexTest, PatternWithRegexChar) {
|
||||||
std::string pattern = "abc*def.ghi+";
|
std::string pattern = "abc*def.ghi+";
|
||||||
std::string result = milvus::TranslatePatternMatchToRegex(pattern);
|
std::string result = milvus::translate_pattern_match_to_regex(pattern);
|
||||||
EXPECT_EQ(result, "abc\\*def\\.ghi\\+");
|
EXPECT_EQ(result, "abc\\*def\\.ghi\\+");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TranslatePatternMatchToRegexTest, MixPattern) {
|
||||||
|
std::string pattern = R"(abc\+\def%ghi_[\\)";
|
||||||
|
std::string result = milvus::translate_pattern_match_to_regex(pattern);
|
||||||
|
EXPECT_EQ(result, R"(abc\+def[\s\S]*ghi[\s\S]\[\\)");
|
||||||
|
}
|
||||||
|
|
||||||
TEST(PatternMatchTranslatorTest, InvalidTypeTest) {
|
TEST(PatternMatchTranslatorTest, InvalidTypeTest) {
|
||||||
using namespace milvus;
|
using namespace milvus;
|
||||||
PatternMatchTranslator translator;
|
PatternMatchTranslator translator;
|
||||||
@ -63,47 +86,67 @@ TEST(PatternMatchTranslatorTest, StringTypeTest) {
|
|||||||
|
|
||||||
EXPECT_EQ(translator(pattern1), "abc");
|
EXPECT_EQ(translator(pattern1), "abc");
|
||||||
EXPECT_EQ(translator(pattern2), "xyz");
|
EXPECT_EQ(translator(pattern2), "xyz");
|
||||||
EXPECT_EQ(translator(pattern3), ".*a.b.*");
|
EXPECT_EQ(translator(pattern3), "[\\s\\S]*a[\\s\\S]b[\\s\\S]*");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(RegexMatcherTest, DefaultBehaviorTest) {
|
TEST(RegexMatcherTest, DefaultBehaviorTest) {
|
||||||
using namespace milvus;
|
using namespace milvus;
|
||||||
RegexMatcher matcher;
|
std::string pattern("Hello.*");
|
||||||
std::regex pattern("Hello.*");
|
RegexMatcher matcher(pattern);
|
||||||
|
|
||||||
int operand1 = 123;
|
int operand1 = 123;
|
||||||
double operand2 = 3.14;
|
double operand2 = 3.14;
|
||||||
bool operand3 = true;
|
bool operand3 = true;
|
||||||
|
|
||||||
EXPECT_FALSE(matcher(pattern, operand1));
|
EXPECT_FALSE(matcher(operand1));
|
||||||
EXPECT_FALSE(matcher(pattern, operand2));
|
EXPECT_FALSE(matcher(operand2));
|
||||||
EXPECT_FALSE(matcher(pattern, operand3));
|
EXPECT_FALSE(matcher(operand3));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(RegexMatcherTest, StringMatchTest) {
|
TEST(RegexMatcherTest, StringMatchTest) {
|
||||||
using namespace milvus;
|
using namespace milvus;
|
||||||
RegexMatcher matcher;
|
std::string pattern("Hello.*");
|
||||||
std::regex pattern("Hello.*");
|
RegexMatcher matcher(pattern);
|
||||||
|
|
||||||
std::string str1 = "Hello, World!";
|
std::string str1 = "Hello, World!";
|
||||||
std::string str2 = "Hi there!";
|
std::string str2 = "Hi there!";
|
||||||
std::string str3 = "Hello, OpenAI!";
|
std::string str3 = "Hello, OpenAI!";
|
||||||
|
|
||||||
EXPECT_TRUE(matcher(pattern, str1));
|
EXPECT_TRUE(matcher(str1));
|
||||||
EXPECT_FALSE(matcher(pattern, str2));
|
EXPECT_FALSE(matcher(str2));
|
||||||
EXPECT_TRUE(matcher(pattern, str3));
|
EXPECT_TRUE(matcher(str3));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(RegexMatcherTest, StringViewMatchTest) {
|
TEST(RegexMatcherTest, StringViewMatchTest) {
|
||||||
using namespace milvus;
|
using namespace milvus;
|
||||||
RegexMatcher matcher;
|
std::string pattern("Hello.*");
|
||||||
std::regex pattern("Hello.*");
|
RegexMatcher matcher(pattern);
|
||||||
|
|
||||||
std::string_view str1 = "Hello, World!";
|
std::string_view str1 = "Hello, World!";
|
||||||
std::string_view str2 = "Hi there!";
|
std::string_view str2 = "Hi there!";
|
||||||
std::string_view str3 = "Hello, OpenAI!";
|
std::string_view str3 = "Hello, OpenAI!";
|
||||||
|
|
||||||
EXPECT_TRUE(matcher(pattern, str1));
|
EXPECT_TRUE(matcher(str1));
|
||||||
EXPECT_FALSE(matcher(pattern, str2));
|
EXPECT_FALSE(matcher(str2));
|
||||||
EXPECT_TRUE(matcher(pattern, str3));
|
EXPECT_TRUE(matcher(str3));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RegexMatcherTest, NewLine) {
|
||||||
|
GTEST_SKIP() << "TODO: matching behavior on newline";
|
||||||
|
|
||||||
|
using namespace milvus;
|
||||||
|
std::string pattern("Hello.*");
|
||||||
|
RegexMatcher matcher(pattern);
|
||||||
|
|
||||||
|
EXPECT_FALSE(matcher(std::string("Hello\n")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RegexMatcherTest, PatternMatchWithNewLine) {
|
||||||
|
using namespace milvus;
|
||||||
|
std::string pattern("Hello%");
|
||||||
|
PatternMatchTranslator translator;
|
||||||
|
auto rp = translator(pattern);
|
||||||
|
RegexMatcher matcher(rp);
|
||||||
|
|
||||||
|
EXPECT_TRUE(matcher(std::string("Hello\n")));
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user