From db91d85dbc0cab74c5ad1d55d2cd2725138909a7 Mon Sep 17 00:00:00 2001 From: Spade A <71589810+SpadeA-Tang@users.noreply.github.com> Date: Mon, 14 Jul 2025 20:34:50 +0800 Subject: [PATCH] feat: more types of matches for ngram (#43081) Ref https://github.com/milvus-io/milvus/issues/42053 This PR enables ngram to support more kinds of matches such as prefix and postfix match. --------- Signed-off-by: SpadeA --- .../core/src/exec/expression/UnaryExpr.cpp | 14 +- internal/core/src/exec/expression/UnaryExpr.h | 3 + internal/core/src/exec/expression/Utils.h | 117 ----------- .../core/src/index/NgramInvertedIndex.cpp | 179 +++++++++++++--- internal/core/src/index/NgramInvertedIndex.h | 16 +- .../src/segcore/ChunkedSegmentSealedImpl.cpp | 14 +- .../src/segcore/ChunkedSegmentSealedImpl.h | 6 +- .../tantivy-binding/include/tantivy-binding.h | 2 +- .../tantivy-binding/src/index_ngram_writer.rs | 8 +- .../tantivy-binding/src/index_reader.rs | 2 +- .../tantivy-binding/src/index_reader_c.rs | 4 +- .../core/thirdparty/tantivy/tantivy-wrapper.h | 8 +- internal/core/unittest/test_ngram_query.cpp | 194 +++++++++--------- 13 files changed, 301 insertions(+), 266 deletions(-) diff --git a/internal/core/src/exec/expression/UnaryExpr.cpp b/internal/core/src/exec/expression/UnaryExpr.cpp index 9e711cdfc6..197744f6c5 100644 --- a/internal/core/src/exec/expression/UnaryExpr.cpp +++ b/internal/core/src/exec/expression/UnaryExpr.cpp @@ -1484,8 +1484,7 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImpl(EvalCtx& context) { fmt::format("match query does not support iterative filter")); } return ExecTextMatch(); - } else if (expr_->op_type_ == proto::plan::OpType::InnerMatch && - !has_offset_input_ && CanUseNgramIndex(field_id_)) { + } else if (CanExecNgramMatch(expr_->op_type_)) { auto res = ExecNgramMatch(); // If nullopt is returned, it means the query cannot be // optimized by ngram index. Forward it to the normal path. 
@@ -1933,6 +1932,15 @@ PhyUnaryRangeFilterExpr::ExecTextMatch() { return res; }; +bool +PhyUnaryRangeFilterExpr::CanExecNgramMatch(proto::plan::OpType op_type) { + return (op_type == proto::plan::OpType::InnerMatch || + op_type == proto::plan::OpType::Match || + op_type == proto::plan::OpType::PrefixMatch || + op_type == proto::plan::OpType::PostfixMatch) && + !has_offset_input_ && CanUseNgramIndex(field_id_); +} + std::optional PhyUnaryRangeFilterExpr::ExecNgramMatch() { if (!arg_inited_) { @@ -1951,7 +1959,7 @@ PhyUnaryRangeFilterExpr::ExecNgramMatch() { AssertInfo(index != nullptr, "ngram index should not be null, field_id: {}", field_id_.get()); - auto res_opt = index->InnerMatchQuery(literal, this); + auto res_opt = index->ExecuteQuery(literal, expr_->op_type_, this); if (!res_opt.has_value()) { return std::nullopt; } diff --git a/internal/core/src/exec/expression/UnaryExpr.h b/internal/core/src/exec/expression/UnaryExpr.h index b2b81c5eed..4a5857e9ca 100644 --- a/internal/core/src/exec/expression/UnaryExpr.h +++ b/internal/core/src/exec/expression/UnaryExpr.h @@ -506,6 +506,9 @@ class PhyUnaryRangeFilterExpr : public SegmentExpr { VectorPtr ExecTextMatch(); + bool + CanExecNgramMatch(proto::plan::OpType op_type); + std::optional ExecNgramMatch(); diff --git a/internal/core/src/exec/expression/Utils.h b/internal/core/src/exec/expression/Utils.h index 3c2029c7d7..241047ec17 100644 --- a/internal/core/src/exec/expression/Utils.h +++ b/internal/core/src/exec/expression/Utils.h @@ -191,122 +191,5 @@ GetValueWithCastNumber(const milvus::proto::plan::GenericValue& value_proto) { } } -enum class MatchType { - ExactMatch, - PrefixMatch, - PostfixMatch, - // The different between InnerMatch and Match is that InnerMatch is used for - // %xxx% while Match could be %xxx%xxx% - InnerMatch, - Match -}; -struct ParsedResult { - std::string literal; - MatchType type; -}; - -// Not used now, but may be used in the future for other type of match for ngram index -inline 
std::optional -parse_ngram_pattern(const std::string& pattern) { - if (pattern.empty()) { - return std::nullopt; - } - - std::vector percent_indices; - bool was_escaped = false; - for (size_t i = 0; i < pattern.length(); ++i) { - char c = pattern[i]; - if (c == '%' && !was_escaped) { - percent_indices.push_back(i); - } else if (c == '_' && !was_escaped) { - // todo(SpadeA): now not support '_' - return std::nullopt; - } - was_escaped = (c == '\\' && !was_escaped); - } - - MatchType match_type; - size_t core_start = 0; - size_t core_length = 0; - size_t percent_count = percent_indices.size(); - - if (percent_count == 0) { - match_type = MatchType::ExactMatch; - core_start = 0; - core_length = pattern.length(); - } else if (percent_count == 1) { - if (pattern.length() == 1) { - return std::nullopt; - } - - size_t idx = percent_indices[0]; - // case: %xxx - if (idx == 0 && pattern.length() > 1) { - match_type = MatchType::PrefixMatch; - core_start = 1; - core_length = pattern.length() - 1; - } else if (idx == pattern.length() - 1 && pattern.length() > 1) { - // case: xxx% - match_type = MatchType::PostfixMatch; - core_start = 0; - core_length = pattern.length() - 1; - } else { - // case: xxx%xxx - match_type = MatchType::Match; - } - } else if (percent_count == 2) { - size_t idx1 = percent_indices[0]; - size_t idx2 = percent_indices[1]; - if (idx1 == 0 && idx2 == pattern.length() - 1 && pattern.length() > 2) { - // case: %xxx% - match_type = MatchType::InnerMatch; - core_start = 1; - core_length = pattern.length() - 2; - } else { - match_type = MatchType::Match; - } - } else { - match_type = MatchType::Match; - } - - if (match_type == MatchType::Match) { - // not supported now - return std::nullopt; - } - - // Extract the literal from the pattern - std::string_view core_pattern = - std::string_view(pattern).substr(core_start, core_length); - - std::string r; - r.reserve(2 * core_pattern.size()); - bool escape_mode = false; - for (char c : core_pattern) { - if 
(escape_mode) { - if (is_special(c)) { - // todo(SpadeA): may not be suitable for ngram? Not use ngram in this case for now. - return std::nullopt; - } - r += c; - escape_mode = false; - } else { - if (c == '\\') { - escape_mode = true; - } else if (c == '%') { - // should be unreachable - } else if (c == '_') { - // should be unreachable - return std::nullopt; - } else { - if (is_special(c)) { - r += '\\'; - } - r += c; - } - } - } - return std::optional{ParsedResult{std::move(r), match_type}}; -} - } // namespace exec } // namespace milvus \ No newline at end of file diff --git a/internal/core/src/index/NgramInvertedIndex.cpp b/internal/core/src/index/NgramInvertedIndex.cpp index 9eb741096e..ee3f9c7e09 100644 --- a/internal/core/src/index/NgramInvertedIndex.cpp +++ b/internal/core/src/index/NgramInvertedIndex.cpp @@ -106,47 +106,168 @@ NgramInvertedIndex::Load(milvus::tracer::TraceContext ctx, } std::optional -NgramInvertedIndex::InnerMatchQuery(const std::string& literal, - exec::SegmentExpr* segment) { +NgramInvertedIndex::ExecuteQuery(const std::string& literal, + proto::plan::OpType op_type, + exec::SegmentExpr* segment) { if (literal.length() < min_gram_) { return std::nullopt; } + switch (op_type) { + case proto::plan::OpType::InnerMatch: { + auto predicate = [&literal](const std::string_view& data) { + return data.find(literal) != std::string::npos; + }; + bool need_post_filter = literal.length() > max_gram_; + return ExecuteQueryWithPredicate( + literal, segment, predicate, need_post_filter); + } + case proto::plan::OpType::Match: + return MatchQuery(literal, segment); + case proto::plan::OpType::PrefixMatch: { + auto predicate = [&literal](const std::string_view& data) { + return data.length() >= literal.length() && + std::equal(literal.begin(), literal.end(), data.begin()); + }; + return ExecuteQueryWithPredicate(literal, segment, predicate, true); + } + case proto::plan::OpType::PostfixMatch: { + auto predicate = [&literal](const std::string_view& 
data) { + return data.length() >= literal.length() && + std::equal( + literal.rbegin(), literal.rend(), data.rbegin()); + }; + return ExecuteQueryWithPredicate(literal, segment, predicate, true); + } + default: + LOG_WARN("unsupported op type for ngram index: {}", op_type); + return std::nullopt; + } +} + +inline void +handle_batch(const std::string_view* data, + const int32_t* offsets, + const int size, + TargetBitmapView res, + std::function predicate) { + auto next_off_option = res.find_first(); + while (next_off_option.has_value()) { + auto next_off = next_off_option.value(); + if (next_off >= size) { + return; + } + if (!predicate(data[next_off])) { + res[next_off] = false; + } + next_off_option = res.find_next(next_off); + } +} + +std::optional +NgramInvertedIndex::ExecuteQueryWithPredicate( + const std::string& literal, + exec::SegmentExpr* segment, + std::function predicate, + bool need_post_filter) { TargetBitmap bitset{static_cast(Count())}; - wrapper_->inner_match_ngram(literal, min_gram_, max_gram_, &bitset); + wrapper_->ngram_match_query(literal, min_gram_, max_gram_, &bitset); - // Post filtering: if the literal length is larger than the max_gram - // we need to filter out the bitset - if (literal.length() > max_gram_) { - auto bitset_off = 0; - TargetBitmapView res(bitset); - TargetBitmap valid(res.size(), true); - TargetBitmapView valid_res(valid.data(), valid.size()); + TargetBitmapView res(bitset); + TargetBitmap valid(res.size(), true); + TargetBitmapView valid_res(valid.data(), valid.size()); - auto execute_sub_batch = [&literal](const std::string_view* data, - const bool* valid_data, - const int32_t* offsets, - const int size, - TargetBitmapView res, - TargetBitmapView valid_res) { - auto next_off_option = res.find_first(); - while (next_off_option.has_value()) { - auto next_off = next_off_option.value(); - if (next_off >= size) { - break; - } - if (data[next_off].find(literal) == std::string::npos) { - res[next_off] = false; - } - 
next_off_option = res.find_next(next_off); - } - }; + if (need_post_filter) { + auto execute_batch = + [&predicate]( + const std::string_view* data, + // `valid_data` is not used as the results returned by ngram_match_query are all valid + const bool* _valid_data, + const int32_t* offsets, + const int size, + TargetBitmapView res, + // the same with `valid_data` + TargetBitmapView _valid_res) { + handle_batch(data, offsets, size, res, predicate); + }; segment->ProcessAllDataChunk( - execute_sub_batch, std::nullptr_t{}, res, valid_res); + execute_batch, std::nullptr_t{}, res, valid_res); } return std::optional(std::move(bitset)); } +std::vector +split_by_wildcard(const std::string& literal) { + std::vector result; + std::string r; + r.reserve(literal.size()); + bool escape_mode = false; + for (char c : literal) { + if (escape_mode) { + r += c; + escape_mode = false; + } else { + if (c == '\\') { + // consider case "\\%", we should reserve % + escape_mode = true; + } else if (c == '%' || c == '_') { + if (r.length() > 0) { + result.push_back(r); + r.clear(); + } + } else { + r += c; + } + } + } + if (r.length() > 0) { + result.push_back(r); + } + return result; +} + +std::optional +NgramInvertedIndex::MatchQuery(const std::string& literal, + exec::SegmentExpr* segment) { + TargetBitmap bitset{static_cast(Count())}; + auto literals = split_by_wildcard(literal); + for (const auto& l : literals) { + if (l.length() < min_gram_) { + return std::nullopt; + } + wrapper_->ngram_match_query(l, min_gram_, max_gram_, &bitset); + } + + TargetBitmapView res(bitset); + TargetBitmap valid(res.size(), true); + TargetBitmapView valid_res(valid.data(), valid.size()); + + PatternMatchTranslator translator; + auto regex_pattern = translator(literal); + RegexMatcher matcher(regex_pattern); + + auto predicate = [&matcher](const std::string_view& data) { + return matcher(data); + }; + + auto execute_batch = + [&predicate]( + const std::string_view* data, + // `_valid_data` is not used as 
the results returned by ngram_match_query are all valid + const bool* _valid_data, + const int32_t* offsets, + const int size, + TargetBitmapView res, + // the same with `_valid_data` + TargetBitmapView _valid_res) { + handle_batch(data, offsets, size, res, predicate); + }; + segment->ProcessAllDataChunk( + execute_batch, std::nullptr_t{}, res, valid_res); + + return std::optional(std::move(bitset)); +} + } // namespace milvus::index diff --git a/internal/core/src/index/NgramInvertedIndex.h b/internal/core/src/index/NgramInvertedIndex.h index a569a6eb3b..33418157b1 100644 --- a/internal/core/src/index/NgramInvertedIndex.h +++ b/internal/core/src/index/NgramInvertedIndex.h @@ -36,7 +36,21 @@ class NgramInvertedIndex : public InvertedIndexTantivy { BuildWithFieldData(const std::vector& datas) override; std::optional - InnerMatchQuery(const std::string& literal, exec::SegmentExpr* segment); + ExecuteQuery(const std::string& literal, + proto::plan::OpType op_type, + exec::SegmentExpr* segment); + + private: + std::optional + ExecuteQueryWithPredicate( + const std::string& literal, + exec::SegmentExpr* segment, + std::function predicate, + bool need_post_filter); + + // Match is something like xxx%xxx%xxx, xxx%xxx, %xxx%xxx, xxx_x etc. 
+ std::optional + MatchQuery(const std::string& literal, exec::SegmentExpr* segment); private: uintptr_t min_gram_{0}; diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp index caea5c5cf4..18a86fc0d5 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp @@ -185,12 +185,10 @@ ChunkedSegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) { if (auto it = info.index_params.find(index::INDEX_TYPE); it != info.index_params.end() && it->second == index::NGRAM_INDEX_TYPE) { - ngram_indexings_[field_id] = - std::move(const_cast(info).cache_index); - } else { - scalar_indexings_[field_id] = - std::move(const_cast(info).cache_index); + ngram_fields_.insert(field_id); } + scalar_indexings_[field_id] = + std::move(const_cast(info).cache_index); LoadResourceRequest request = milvus::index::IndexFactory::GetInstance().ScalarIndexLoadResource( @@ -633,8 +631,8 @@ ChunkedSegmentSealedImpl::chunk_index_impl(FieldId field_id, PinWrapper ChunkedSegmentSealedImpl::GetNgramIndex(FieldId field_id) const { std::shared_lock lck(mutex_); - auto iter = ngram_indexings_.find(field_id); - if (iter == ngram_indexings_.end()) { + auto iter = scalar_indexings_.find(field_id); + if (iter == scalar_indexings_.end()) { return PinWrapper(nullptr); } auto slot = iter->second.get(); @@ -987,6 +985,7 @@ ChunkedSegmentSealedImpl::ChunkedSegmentSealedImpl( field_data_ready_bitset_(schema->size()), index_ready_bitset_(schema->size()), binlog_index_bitset_(schema->size()), + ngram_fields_(schema->size()), scalar_indexings_(schema->size()), insert_record_(*schema, MAX_ROW_COUNT), schema_(schema), @@ -1146,6 +1145,7 @@ ChunkedSegmentSealedImpl::ClearData() { index_has_raw_data_.clear(); system_ready_count_ = 0; num_rows_ = std::nullopt; + ngram_fields_.clear(); scalar_indexings_.clear(); vector_indexings_.clear(); insert_record_.clear(); diff --git 
a/internal/core/src/segcore/ChunkedSegmentSealedImpl.h b/internal/core/src/segcore/ChunkedSegmentSealedImpl.h index 8c08a20cdb..3e4f668eeb 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.h +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.h @@ -120,7 +120,7 @@ class ChunkedSegmentSealedImpl : public SegmentSealed { bool HasNgramIndex(FieldId field_id) const override { std::shared_lock lck(mutex_); - return ngram_indexings_.find(field_id) != ngram_indexings_.end(); + return ngram_fields_.find(field_id) != ngram_fields_.end(); } PinWrapper @@ -432,8 +432,8 @@ class ChunkedSegmentSealedImpl : public SegmentSealed { // TODO: generate index for scalar std::optional num_rows_; - // ngram field index - std::unordered_map ngram_indexings_; + // fields that has ngram index + std::unordered_set ngram_fields_{}; // scalar field index std::unordered_map scalar_indexings_; diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h b/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h index 648cd61951..6767d6f851 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h +++ b/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h @@ -189,7 +189,7 @@ RustResult tantivy_term_query_keyword(void *ptr, const char *term, void *bitset) RustResult tantivy_term_query_keyword_i64(void *ptr, const char *term); -RustResult tantivy_inner_match_ngram(void *ptr, +RustResult tantivy_ngram_match_query(void *ptr, const char *literal, uintptr_t min_gram, uintptr_t max_gram, diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_ngram_writer.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_ngram_writer.rs index 2c93e474fb..529672b2d5 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_ngram_writer.rs +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_ngram_writer.rs @@ -109,7 +109,7 @@ mod tests { let reader = 
writer.create_reader(set_bitset).unwrap(); let mut res: Vec = vec![]; reader - .inner_match_ngram("ic", 2, 3, &mut res as *mut _ as *mut c_void) + .ngram_match_query("ic", 2, 3, &mut res as *mut _ as *mut c_void) .unwrap(); assert_eq!(res, vec![2, 4, 5]); } @@ -138,19 +138,19 @@ mod tests { let reader = writer.create_reader(set_bitset).unwrap(); let mut res: Vec = vec![]; reader - .inner_match_ngram("测试", 2, 3, &mut res as *mut _ as *mut c_void) + .ngram_match_query("测试", 2, 3, &mut res as *mut _ as *mut c_void) .unwrap(); assert_eq!(res, vec![0, 1, 2, 4]); let mut res: Vec = vec![]; reader - .inner_match_ngram("m测试", 2, 3, &mut res as *mut _ as *mut c_void) + .ngram_match_query("m测试", 2, 3, &mut res as *mut _ as *mut c_void) .unwrap(); assert_eq!(res, vec![0, 2]); let mut res: Vec = vec![]; reader - .inner_match_ngram("需要被测试", 2, 3, &mut res as *mut _ as *mut c_void) + .ngram_match_query("需要被测试", 2, 3, &mut res as *mut _ as *mut c_void) .unwrap(); assert_eq!(res, vec![4]); } diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_reader.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_reader.rs index b38a38423d..4ea6bbbea8 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_reader.rs +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_reader.rs @@ -300,7 +300,7 @@ impl IndexReaderWrapper { } // **Note**: literal length must be larger or equal to min_gram. 
- pub fn inner_match_ngram( + pub fn ngram_match_query( &self, literal: &str, min_gram: usize, diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_reader_c.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_reader_c.rs index f8c15bd4ab..9b8a30a171 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_reader_c.rs +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_reader_c.rs @@ -234,7 +234,7 @@ pub extern "C" fn tantivy_term_query_keyword_i64( } #[no_mangle] -pub extern "C" fn tantivy_inner_match_ngram( +pub extern "C" fn tantivy_ngram_match_query( ptr: *mut c_void, literal: *const c_char, min_gram: usize, @@ -247,7 +247,7 @@ pub extern "C" fn tantivy_inner_match_ngram( let now = std::time::Instant::now(); unsafe { (*real) - .inner_match_ngram(literal, min_gram, max_gram, bitset) + .ngram_match_query(literal, min_gram, max_gram, bitset) .into() } } diff --git a/internal/core/thirdparty/tantivy/tantivy-wrapper.h b/internal/core/thirdparty/tantivy/tantivy-wrapper.h index 17a2565268..71fe1dd4b3 100644 --- a/internal/core/thirdparty/tantivy/tantivy-wrapper.h +++ b/internal/core/thirdparty/tantivy/tantivy-wrapper.h @@ -936,19 +936,19 @@ struct TantivyIndexWrapper { } void - inner_match_ngram(const std::string& literal, + ngram_match_query(const std::string& literal, uintptr_t min_gram, uintptr_t max_gram, void* bitset) { - auto array = tantivy_inner_match_ngram( + auto array = tantivy_ngram_match_query( reader_, literal.c_str(), min_gram, max_gram, bitset); auto res = RustResultWrapper(array); AssertInfo(res.result_->success, - "TantivyIndexWrapper.inner_match_ngram: {}", + "TantivyIndexWrapper.ngram_match_query: {}", res.result_->error); AssertInfo( res.result_->value.tag == Value::Tag::None, - "TantivyIndexWrapper.inner_match_ngram: invalid result type"); + "TantivyIndexWrapper.ngram_match_query: invalid result type"); } // json query diff --git a/internal/core/unittest/test_ngram_query.cpp 
b/internal/core/unittest/test_ngram_query.cpp index 7d745da7c6..d35bccfab5 100644 --- a/internal/core/unittest/test_ngram_query.cpp +++ b/internal/core/unittest/test_ngram_query.cpp @@ -28,84 +28,6 @@ using namespace milvus::query; using namespace milvus::segcore; using namespace milvus::exec; -TEST(ConvertToNgramLiteralTest, EmptyString) { - auto result = parse_ngram_pattern(""); - ASSERT_FALSE(result.has_value()); -} - -TEST(ConvertToNgramLiteralTest, ExactMatchSimple) { - auto result = parse_ngram_pattern("abc"); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(result->literal, "abc"); - EXPECT_EQ(result->type, MatchType::ExactMatch); -} - -TEST(ConvertToNgramLiteralTest, ExactMatchWithEscapedPercent) { - auto result = parse_ngram_pattern("ab\\%cd"); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(result->literal, "ab%cd"); - EXPECT_EQ(result->type, MatchType::ExactMatch); -} - -TEST(ConvertToNgramLiteralTest, ExactMatchWithEscapedSpecialChar) { - auto result = parse_ngram_pattern("a.b"); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(result->literal, "a\\.b"); - EXPECT_EQ(result->type, MatchType::ExactMatch); -} - -TEST(ConvertToNgramLiteralTest, PrefixMatchSimple) { - auto result = parse_ngram_pattern("%abc"); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(result->literal, "abc"); - EXPECT_EQ(result->type, MatchType::PrefixMatch); -} - -TEST(ConvertToNgramLiteralTest, PostfixMatchSimple) { - auto result = parse_ngram_pattern("abc%"); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(result->literal, "abc"); - EXPECT_EQ(result->type, MatchType::PostfixMatch); -} - -TEST(ConvertToNgramLiteralTest, InnerMatchSimple) { - auto result = parse_ngram_pattern("%abc%"); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(result->literal, "abc"); - EXPECT_EQ(result->type, MatchType::InnerMatch); -} - -TEST(ConvertToNgramLiteralTest, MatchSinglePercentMiddle) { - auto result = parse_ngram_pattern("a%b"); - ASSERT_FALSE(result.has_value()); -} - -TEST(ConvertToNgramLiteralTest, 
MatchTypeReturnsNullopt) { - EXPECT_FALSE(parse_ngram_pattern("%").has_value()); - // %a%b (n=2, not %xxx%) -> Match -> nullopt - EXPECT_FALSE(parse_ngram_pattern("%a%b").has_value()); - // a%b%c (n=2, not %xxx%) -> Match -> nullopt - EXPECT_FALSE(parse_ngram_pattern("a%b%c").has_value()); - // %% (n=2, not %xxx% because length is not > 2) -> Match -> nullopt - EXPECT_FALSE(parse_ngram_pattern("%%").has_value()); - // %a%b%c% (n=3) -> Match -> nullopt - EXPECT_FALSE(parse_ngram_pattern("%a%b%c%").has_value()); -} - -TEST(ConvertToNgramLiteralTest, UnescapedUnderscoreReturnsNullopt) { - EXPECT_FALSE(parse_ngram_pattern("a_b").has_value()); - EXPECT_FALSE(parse_ngram_pattern("%a_b").has_value()); - EXPECT_FALSE(parse_ngram_pattern("a_b%").has_value()); - EXPECT_FALSE(parse_ngram_pattern("%a_b%").has_value()); -} - -TEST(ConvertToNgramLiteralTest, EscapedUnderscore) { - auto result = parse_ngram_pattern("a\\_b"); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(result->literal, "a_b"); - EXPECT_EQ(result->type, MatchType::ExactMatch); -} - auto generate_field_meta(int64_t collection_id = 1, int64_t partition_id = 2, @@ -153,7 +75,9 @@ generate_local_storage_config(const std::string& root_path) void test_ngram_with_data(const boost::container::vector& data, const std::string& literal, - const std::vector& expected_result) { + proto::plan::OpType op_type, + const std::vector& expected_result, + bool forward_to_br = false) { int64_t collection_id = 1; int64_t partition_id = 2; int64_t segment_id = 3; @@ -275,9 +199,15 @@ test_ngram_with_data(const boost::container::vector& data, 8192, 0); - auto bitset = index->InnerMatchQuery(literal, &segment_expr).value(); - for (size_t i = 0; i < nb; i++) { - ASSERT_EQ(bitset[i], expected_result[i]); + std::optional bitset_opt = + index->ExecuteQuery(literal, op_type, &segment_expr); + if (forward_to_br) { + ASSERT_TRUE(!bitset_opt.has_value()); + } else { + auto bitset = std::move(bitset_opt.value()); + for (size_t i = 0; i < nb; i++) 
{ + ASSERT_EQ(bitset[i], expected_result[i]); + } } } @@ -318,8 +248,7 @@ test_ngram_with_data(const boost::container::vector& data, AppendIndexV2(trace, cload_index_info); UpdateSealedSegmentIndex(segment.get(), cload_index_info); - auto unary_range_expr = - test::GenUnaryRangeExpr(OpType::InnerMatch, literal); + auto unary_range_expr = test::GenUnaryRangeExpr(op_type, literal); auto column_info = test::GenColumnInfo( field_id.get(), proto::schema::DataType::VarChar, false, false); unary_range_expr->set_allocated_column_info(column_info); @@ -339,39 +268,116 @@ test_ngram_with_data(const boost::container::vector& data, TEST(NgramIndex, TestNgramWikiEpisode) { boost::container::vector data; - // not hit data.push_back( "'Indira Davelba Murillo Alvarado (Tegucigalpa, " "the youngest of eight siblings. She attended primary school at the " "Escuela 14 de Julio, and her secondary studies at the Instituto " "school called \"Indi del Bosque\", where she taught the children of " "Honduran women'"); - // hit data.push_back( "Richmond Green Secondary School is a public secondary school in " "Richmond Hill, Ontario, Canada."); - // hit data.push_back( "The Gymnasium in 2002 Gymnasium Philippinum or Philippinum High " "School is an almost 500-year-old secondary school in Marburg, Hesse, " "Germany."); - // hit data.push_back( "Sir Winston Churchill Secondary School is a Canadian secondary school " "located in St. 
Catharines, Ontario."); - // not hit data.push_back("Sir Winston Churchill Secondary School"); - std::vector expected_result{false, true, true, true, false}; + // within min-max_gram + { + // inner match + std::vector expected_result{true, true, true, true, true}; + test_ngram_with_data( + data, "ary", proto::plan::OpType::InnerMatch, expected_result); - test_ngram_with_data(data, "secondary school", expected_result); + expected_result = {false, true, false, true, true}; + test_ngram_with_data( + data, "y S", proto::plan::OpType::InnerMatch, expected_result); + + expected_result = {true, true, true, true, false}; + test_ngram_with_data( + data, "y s", proto::plan::OpType::InnerMatch, expected_result); + + // prefix + expected_result = {false, false, false, true, true}; + test_ngram_with_data( + data, "Sir", proto::plan::OpType::PrefixMatch, expected_result); + + // postfix + expected_result = {false, false, false, false, true}; + test_ngram_with_data( + data, "ool", proto::plan::OpType::PostfixMatch, expected_result); + + // match + expected_result = {true, false, false, false, false}; + test_ngram_with_data( + data, "%Alv%y s%", proto::plan::OpType::Match, expected_result); + } + + // exceeds max_gram + { + // inner match + std::vector expected_result{false, true, true, true, false}; + test_ngram_with_data(data, + "secondary school", + proto::plan::OpType::InnerMatch, + expected_result); + + // prefix + expected_result = {false, false, false, true, true}; + test_ngram_with_data(data, + "Sir Winston", + proto::plan::OpType::PrefixMatch, + expected_result); + + // postfix + expected_result = {false, false, true, false, false}; + test_ngram_with_data(data, + "Germany.", + proto::plan::OpType::PostfixMatch, + expected_result); + + // match + expected_result = {true, true, true, true, false}; + test_ngram_with_data(data, + "%secondary%school%", + proto::plan::OpType::Match, + expected_result); + } } -TEST(NgramIndex, TestNgramAllFalse) { +TEST(NgramIndex, 
TestNgramSimple) { boost::container::vector data(10000, "elementary school secondary"); // all can be hit by ngram tantivy but will be filterred out by the second phase - test_ngram_with_data( - data, "secondary school", std::vector(10000, false)); + test_ngram_with_data(data, + "secondary school", + proto::plan::OpType::InnerMatch, + std::vector(10000, false)); + + test_ngram_with_data(data, + "ele", + proto::plan::OpType::PrefixMatch, + std::vector(10000, true)); + + test_ngram_with_data(data, + "%ary%sec%", + proto::plan::OpType::Match, + std::vector(10000, true)); + + // should be forwarded to brute force + test_ngram_with_data(data, + "%ary%s%", + proto::plan::OpType::Match, + std::vector(10000, true), + true); + + test_ngram_with_data(data, + "ary", + proto::plan::OpType::PostfixMatch, + std::vector(10000, true)); }