Mirror of https://gitee.com/milvus-io/milvus.git, synced 2026-01-07 19:31:51 +08:00
enhance: Add support for minimum_should_match in text_match (parser, engine, client, and tests) (#44988)
### Is there an existing issue for this?

- [x] I have searched the existing issues

---

Please see https://github.com/milvus-io/milvus/issues/44593 for the background. This PR supersedes https://github.com/milvus-io/milvus/pull/44638, which can be closed: the review comments on the original implementation suggested an alternative, better approach, and this PR implements it.

---

This PR adds an optional `minimum_should_match` argument to `text_match(...)` and wires it through the parser, planner/visitor, index bindings, and client-level tests/examples, so full-text queries can require a minimum number of tokens to match (a brief usage sketch follows at the end of this message).

Motivation
- Provide a way to require an expression to match a minimum number of tokens in lexical search.

What changed
- Parser / grammar
  - Added grammar rule and token: `MINIMUM_SHOULD_MATCH` and `textMatchOption` in `internal/parser/planparserv2/Plan.g4`.
  - Regenerated parser outputs in `internal/parser/planparserv2/generated/*` (parser, lexer, visitor, etc.) to support the new rule.
- Planner / visitor
  - `parser_visitor.go`: parse and validate the `minimum_should_match` integer and propagate it as an extra value on the `TextMatch` expression so downstream components receive it.
  - Added a `VisitTextMatchOption` visitor method.
- Client (Go)
  - Added a unit test verifying that `text_match(..., minimum_should_match=...)` appears in the generated DSL and is accepted by client code: `client/milvusclient/read_test.go`.
  - Added an integration-style test to the go-client testcase suite: `tests/go_client/testcases/full_text_search_test.go` (exercises min=1, min=3, and a large min).
  - Added an example demonstrating `text_match` usage: `client/milvusclient/read_example_test.go` (example name conforms to godoc mapping).
- Engine / index
  - Updated the C++ index interface: `TextMatchIndex::MatchQuery`.
  - Added/updated unit tests for the index behavior: `internal/core/src/index/TextMatchIndexTest.cpp`.
- Tantivy binding
  - Added a `match_query_with_minimum` implementation and unit tests to `internal/core/thirdparty/tantivy/tantivy-binding/src/index_reader_text.rs` that construct boolean queries with a minimum number of required clauses.

Behavioral / compatibility notes
- This only adds an optional argument to `text_match`; the default behavior (no `minimum_should_match`) is unchanged.
- Internal API change: the `TextMatchIndex::MatchQuery` signature changed (internal component). Callers in the repo were updated accordingly.
- The parser changes required regenerating the ANTLR outputs.

Tests and verification
- New/updated tests:
  - Go client unit test: `client/milvusclient/read_test.go` (mocked Search request asserts the DSL contains `minimum_should_match=2`).
  - Go e2e-style test: `tests/go_client/testcases/full_text_search_test.go` (exercises min=1, min=3, and a large min).
  - C++ unit tests for index behavior: `internal/core/src/index/TextMatchIndexTest.cpp`.
  - Rust binding unit tests for `match_query_with_minimum`.
- Local verification commands to run:
  - Go client tests: `cd client && go test ./milvusclient -run ^$` (client package)
  - Go testcases: `cd tests/go_client && go test ./testcases -run TestTextMatchMinimumShouldMatch` (requires a running Milvus instance)
  - C++ unit tests / build: run the core build/test per repo instructions (the change touches core index code).
  - Rust binding tests: `cd internal/core/thirdparty/tantivy/tantivy-binding && cargo test` (if developing locally).
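Usage sketch. The following is a minimal, hedged sketch of how a caller would use the new option with the v2 Go client, mirroring the example added in `read_example_test.go`. It assumes a collection named `my_collection` already exists with a `document_text` VarChar field (analyzer and match enabled) and a BM25-generated `text_sparse_vector` field; the collection name, field names, and local address are illustrative assumptions, and the import paths follow the v2 Go client layout.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/milvus-io/milvus/client/v2/entity"
	"github.com/milvus-io/milvus/client/v2/milvusclient"
)

func main() {
	ctx := context.Background()

	// Connect to a locally running Milvus instance (address is an assumption).
	cli, err := milvusclient.New(ctx, &milvusclient.ClientConfig{Address: "127.0.0.1:19530"})
	if err != nil {
		log.Fatal("failed to connect: ", err)
	}
	defer cli.Close(ctx)

	// Require at least 2 of the query tokens to appear in the field.
	query := "artificial intelligence history"
	expr := fmt.Sprintf(`text_match(document_text, "%s", minimum_should_match=2)`, query)

	// Use the expression as a scalar filter on a BM25 full-text search.
	// Collection and field names here are hypothetical placeholders.
	rs, err := cli.Search(ctx, milvusclient.NewSearchOption("my_collection", 5, []entity.Vector{entity.Text(query)}).
		WithANNSField("text_sparse_vector").
		WithFilter(expr).
		WithOutputFields("id", "document_text"))
	if err != nil {
		log.Fatal("search failed: ", err)
	}
	for _, r := range rs {
		log.Println("hits:", r.ResultCount)
	}
}
```

Omitting `minimum_should_match` keeps the existing behavior (any matching token qualifies), so existing expressions continue to work unchanged.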
---------

Signed-off-by: Amit Kumar <amit.kumar@reddit.com>
Co-authored-by: Amit Kumar <amit.kumar@reddit.com>
parent 489288b5e3
commit 388d56fdc7
@@ -403,3 +403,91 @@ func ExampleClient_HybridSearch() {
	log.Println("Scores: ", resultSet.Scores)
}
}

func ExampleClient_Search_textMatch() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	collectionName := "text_min_match"
	titleField := "title"
	textField := "document_text"
	titleSparse := "title_sparse_vector"
	textSparse := "text_sparse_vector"
	milvusAddr := "127.0.0.1:19530"

	cli, err := milvusclient.New(ctx, &milvusclient.ClientConfig{
		Address: milvusAddr,
	})
	if err != nil {
		log.Fatal("failed to connect to milvus server: ", err.Error())
	}
	defer cli.Close(ctx)

	_ = cli.DropCollection(ctx, milvusclient.NewDropCollectionOption(collectionName))

	schema := entity.NewSchema().
		WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)).
		WithField(entity.NewField().WithName(titleField).WithDataType(entity.FieldTypeVarChar).WithMaxLength(512).WithEnableAnalyzer(true).WithEnableMatch(true)).
		WithField(entity.NewField().WithName(textField).WithDataType(entity.FieldTypeVarChar).WithMaxLength(2048).WithEnableAnalyzer(true).WithEnableMatch(true)).
		WithField(entity.NewField().WithName(titleSparse).WithDataType(entity.FieldTypeSparseVector)).
		WithField(entity.NewField().WithName(textSparse).WithDataType(entity.FieldTypeSparseVector)).
		WithFunction(entity.NewFunction().WithName("title_bm25_func").WithType(entity.FunctionTypeBM25).WithInputFields(titleField).WithOutputFields(titleSparse)).
		WithFunction(entity.NewFunction().WithName("text_bm25_func").WithType(entity.FunctionTypeBM25).WithInputFields(textField).WithOutputFields(textSparse))

	idxOpts := []milvusclient.CreateIndexOption{
		milvusclient.NewCreateIndexOption(collectionName, titleField, index.NewInvertedIndex()),
		milvusclient.NewCreateIndexOption(collectionName, textField, index.NewInvertedIndex()),
		milvusclient.NewCreateIndexOption(collectionName, titleSparse, index.NewSparseInvertedIndex(entity.BM25, 0.2)),
		milvusclient.NewCreateIndexOption(collectionName, textSparse, index.NewSparseInvertedIndex(entity.BM25, 0.2)),
	}

	err = cli.CreateCollection(ctx, milvusclient.NewCreateCollectionOption(collectionName, schema).WithIndexOptions(idxOpts...))
	if err != nil {
		log.Fatal("failed to create collection: ", err.Error())
	}

	_, err = cli.Insert(ctx, milvusclient.NewColumnBasedInsertOption(collectionName).
		WithVarcharColumn(titleField, []string{
			"History of AI",
			"Alan Turing Biography",
			"Machine Learning Overview",
		}).
		WithVarcharColumn(textField, []string{
			"Artificial intelligence was founded in 1956 by computer scientists.",
			"Alan Turing proposed early concepts of AI and machine learning.",
			"Machine learning is a subset of artificial intelligence.",
		}))
	if err != nil {
		log.Fatal("failed to insert data: ", err.Error())
	}

	task, err := cli.LoadCollection(ctx, milvusclient.NewLoadCollectionOption(collectionName))
	if err != nil {
		log.Fatal("failed to load collection: ", err.Error())
	}
	_ = task.Await(ctx)

	q := "artificial intelligence"
	expr := "text_match(" + titleField + ", \"" + q + "\", minimum_should_match=2) OR text_match(" + textField + ", \"" + q + "\", minimum_should_match=2)"

	boost := entity.NewFunction().
		WithName("title_boost").
		WithType(entity.FunctionTypeRerank).
		WithParam("reranker", "boost").
		WithParam("filter", "text_match("+titleField+", \""+q+"\", minimum_should_match=2)").
		WithParam("weight", "2.0")

	vectors := []entity.Vector{entity.Text(q)}
	rs, err := cli.Search(ctx, milvusclient.NewSearchOption(collectionName, 5, vectors).
		WithANNSField(textSparse).
		WithFilter(expr).
		WithOutputFields("id", titleField, textField).
		WithFunctionReranker(boost))
	if err != nil {
		log.Fatal("failed to search: ", err.Error())
	}

	for _, r := range rs {
		_ = r.ResultCount
	}
}
@@ -191,6 +191,46 @@ func (s *ReadSuite) TestSearch() {
	})
}

// TestSearch_TextMatch tests the text match search functionality.
// It tests the minimum_should_match parameter in the expression.
func (s *ReadSuite) TestSearch_TextMatch() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	s.Run("min_should_match_in_expr", func() {
		collectionName := fmt.Sprintf("coll_%s", s.randString(6))
		s.setupCache(collectionName, s.schema)

		s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
			// ensure the expression contains minimum_should_match and both fields
			s.Contains(sr.GetDsl(), "minimum_should_match=2")
			s.Contains(sr.GetDsl(), "text_match(")
			return &milvuspb.SearchResults{
				Status: merr.Success(),
				Results: &schemapb.SearchResultData{
					NumQueries: 1,
					TopK:       1,
					FieldsData: []*schemapb.FieldData{
						s.getInt64FieldData("ID", []int64{1}),
					},
					Ids:    &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1}}}},
					Scores: []float32{0.1},
					Topks:  []int64{1},
				},
			}, nil
		}).Once()

		q := "artificial intelligence"
		expr := "text_match(title, \"" + q + "\", minimum_should_match=2) OR text_match(document_text, \"" + q + "\", minimum_should_match=2)"
		vectors := []entity.Vector{entity.Text(q)}
		_, err := s.client.Search(ctx, NewSearchOption(collectionName, 5, vectors).
			WithANNSField("text_sparse_vector").
			WithFilter(expr).
			WithOutputFields("ID"))
		s.NoError(err)
	})
}

func (s *ReadSuite) TestHybridSearch() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
go.mod (4 changed lines)
@@ -65,7 +65,7 @@ require (
	github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.0
	github.com/bits-and-blooms/bitset v1.10.0
	github.com/bytedance/mockey v1.2.14
	github.com/bytedance/sonic v1.13.2
	github.com/bytedance/sonic v1.14.0
	github.com/cenkalti/backoff/v4 v4.2.1
	github.com/cockroachdb/redact v1.1.3
	github.com/google/uuid v1.6.0
@@ -136,7 +136,7 @@ require (
	github.com/aws/smithy-go v1.22.2 // indirect
	github.com/benesch/cgosymbolizer v0.0.0-20190515212042-bec6fe6e597b // indirect
	github.com/beorn7/perks v1.0.1 // indirect
	github.com/bytedance/sonic/loader v0.2.4 // indirect
	github.com/bytedance/sonic/loader v0.3.0 // indirect
	github.com/campoy/embedmd v1.0.0 // indirect
	github.com/cespare/xxhash/v2 v2.3.0 // indirect
	github.com/cilium/ebpf v0.11.0 // indirect
go.sum (8 changed lines)
@@ -205,11 +205,11 @@ github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl
github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s=
github.com/bytedance/mockey v1.2.14 h1:KZaFgPdiUwW+jOWFieo3Lr7INM1P+6adO3hxZhDswY8=
github.com/bytedance/mockey v1.2.14/go.mod h1:1BPHF9sol5R1ud/+0VEHGQq/+i2lN+GTsr3O2Q9IENY=
github.com/bytedance/sonic v1.13.2 h1:8/H1FempDZqC4VqjptGo14QQlJx8VdZJegxs6wwfqpQ=
github.com/bytedance/sonic v1.13.2/go.mod h1:o68xyaF9u2gvVBuGHPlUVCy+ZfmNNO5ETf1+KgkJhz4=
github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ=
github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/bytedance/sonic/loader v0.2.4 h1:ZWCw4stuXUsn1/+zQDqeE7JKP+QO47tz7QCNan80NzY=
github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA=
github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
github.com/campoy/embedmd v1.0.0 h1:V4kI2qTJJLf4J29RzI/MAt2c3Bl4dQSYPuflzwFH2hY=
github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8=
github.com/casbin/casbin/v2 v2.0.0/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ=
@@ -1797,10 +1797,18 @@ PhyUnaryRangeFilterExpr::ExecTextMatch() {
        }
    }

    auto func = [op_type, slop](Index* index,
                                const std::string& query) -> TargetBitmap {
    uint32_t min_should_match = 1;  // default value
    if (op_type == proto::plan::OpType::TextMatch &&
        expr_->extra_values_.size() > 0) {
        // min_should_match is stored in the first extra value
        min_should_match = static_cast<uint32_t>(
            GetValueFromProto<int64_t>(expr_->extra_values_[0]));
    }

    auto func = [op_type, slop, min_should_match](
                    Index* index, const std::string& query) -> TargetBitmap {
        if (op_type == proto::plan::OpType::TextMatch) {
            return index->MatchQuery(query);
            return index->MatchQuery(query, min_should_match);
        } else if (op_type == proto::plan::OpType::PhraseMatch) {
            return index->PhraseMatchQuery(query, slop);
        } else {
@@ -302,7 +302,8 @@ TextMatchIndex::RegisterTokenizer(const char* tokenizer_name,
}

TargetBitmap
TextMatchIndex::MatchQuery(const std::string& query) {
TextMatchIndex::MatchQuery(const std::string& query,
                           uint32_t min_should_match) {
    tracer::AutoSpan span("TextMatchIndex::MatchQuery", tracer::GetRootSpan());
    if (shouldTriggerCommit()) {
        Commit();
@@ -310,10 +311,10 @@ TextMatchIndex::MatchQuery(const std::string& query) {
    }

    TargetBitmap bitset{static_cast<size_t>(Count())};
    // The count opeartion of tantivy may be get older cnt if the index is committed with new tantivy segment.
    // The count operation of tantivy may be get older cnt if the index is committed with new tantivy segment.
    // So we cannot use the count operation to get the total count for bitmap.
    // Just use the maximum offset of hits to get the total count for bitmap here.
    wrapper_->match_query(query, &bitset);
    wrapper_->match_query(query, min_should_match, &bitset);
    return bitset;
}
@@ -327,7 +328,7 @@ TextMatchIndex::PhraseMatchQuery(const std::string& query, uint32_t slop) {
    }

    TargetBitmap bitset{static_cast<size_t>(Count())};
    // The count opeartion of tantivy may be get older cnt if the index is committed with new tantivy segment.
    // The count operation of tantivy may be get older cnt if the index is committed with new tantivy segment.
    // So we cannot use the count operation to get the total count for bitmap.
    // Just use the maximum offset of hits to get the total count for bitmap here.
    wrapper_->phrase_match_query(query, slop, &bitset);
@@ -83,7 +83,7 @@ class TextMatchIndex : public InvertedIndexTantivy<std::string> {
    RegisterTokenizer(const char* tokenizer_name, const char* analyzer_params);

    TargetBitmap
    MatchQuery(const std::string& query);
    MatchQuery(const std::string& query, uint32_t min_should_match);

    TargetBitmap
    PhraseMatchQuery(const std::string& query, uint32_t slop);
@@ -164,7 +164,7 @@ TEST(TextMatch, Index) {
    index->Reload();

    {
        auto res = index->MatchQuery("football");
        auto res = index->MatchQuery("football", 1);
        ASSERT_EQ(res.size(), 3);
        ASSERT_TRUE(res[0]);
        ASSERT_FALSE(res[1]);
@@ -177,11 +177,16 @@ TEST(TextMatch, Index) {
        ASSERT_TRUE(res2[0]);
        ASSERT_FALSE(res2[1]);
        ASSERT_TRUE(res2[2]);
        res = index->MatchQuery("nothing");
        res = index->MatchQuery("nothing", 1);
        ASSERT_EQ(res.size(), 3);
        ASSERT_FALSE(res[0]);
        ASSERT_FALSE(res[1]);
        ASSERT_FALSE(res[2]);
        auto res3 = index->MatchQuery("football pingpang cricket", 2);
        ASSERT_EQ(res3.size(), 3);
        ASSERT_TRUE(res3[0]);
        ASSERT_FALSE(res3[1]);
        ASSERT_FALSE(res3[2]);
    }

    {
File diff suppressed because it is too large
@@ -31,15 +31,46 @@ impl IndexReaderWrapper {
        }
        let collector = DirectBitsetCollector {
            bitset_wrapper: BitsetWrapper::new(bitset, self.set_bitset),
            terms,
            terms: terms.clone(),
        };
        let query = BooleanQuery::new_multiterms_query(vec![]);
        let query = BooleanQuery::new_multiterms_query(terms);
        let searcher = self.reader.searcher();
        searcher
            .search(&query, &collector)
            .map_err(TantivyBindingError::TantivyError)
    }

    pub(crate) fn match_query_with_minimum(
        &self,
        q: &str,
        min_should_match: usize,
        bitset: *mut c_void,
    ) -> Result<()> {
        let mut tokenizer = self
            .index
            .tokenizer_for_field(self.field)
            .unwrap_or(standard_analyzer(vec![]))
            .clone();
        let mut token_stream = tokenizer.token_stream(q);
        let mut terms: Vec<Term> = Vec::new();
        while token_stream.advance() {
            let token = token_stream.token();
            terms.push(Term::from_field_text(self.field, &token.text));
        }
        use tantivy::query::{Occur, TermQuery};
        use tantivy::schema::IndexRecordOption;
        let mut subqueries: Vec<(Occur, Box<dyn tantivy::query::Query>)> = Vec::new();
        for term in terms.into_iter() {
            subqueries.push((
                Occur::Should,
                Box::new(TermQuery::new(term, IndexRecordOption::Basic)),
            ));
        }
        let effective_min = std::cmp::max(1, min_should_match);
        let query = BooleanQuery::with_minimum_required_clauses(subqueries, effective_min);
        self.search(&query, bitset)
    }

    // split the query string into multiple tokens using index's default tokenizer,
    // and then execute the disconjunction of term query.
    pub(crate) fn phrase_match_query(&self, q: &str, slop: u32, bitset: *mut c_void) -> Result<()> {
@@ -148,4 +179,58 @@ mod tests {
            .unwrap();
        assert_eq!(res, (0..100000).collect::<HashSet<u32>>());
    }

    #[test]
    fn test_min_should_match_match_query() {
        let dir = tempfile::TempDir::new().unwrap();
        let mut writer = IndexWriterWrapper::create_text_writer(
            "text",
            dir.path().to_str().unwrap(),
            "default",
            "",
            1,
            50_000_000,
            false,
            TantivyIndexVersion::default_version(),
        )
        .unwrap();

        // doc ids: 0..4
        writer.add("a b", Some(0)).unwrap();
        writer.add("a c", Some(1)).unwrap();
        writer.add("b c", Some(2)).unwrap();
        writer.add("c", Some(3)).unwrap();
        writer.add("a b c", Some(4)).unwrap();
        writer.commit().unwrap();

        let reader = writer.create_reader(set_bitset).unwrap();

        // min=1 behaves like union of tokens
        let mut res: HashSet<u32> = HashSet::new();
        reader
            .match_query_with_minimum("a b", 1, &mut res as *mut _ as *mut c_void)
            .unwrap();
        assert_eq!(res, vec![0, 1, 2, 4].into_iter().collect::<HashSet<u32>>());

        // min=2 requires at least two tokens
        res.clear();
        reader
            .match_query_with_minimum("a b c", 2, &mut res as *mut _ as *mut c_void)
            .unwrap();
        assert_eq!(res, vec![0, 1, 2, 4].into_iter().collect::<HashSet<u32>>());

        // min=3 requires all three tokens
        res.clear();
        reader
            .match_query_with_minimum("a b c", 3, &mut res as *mut _ as *mut c_void)
            .unwrap();
        assert_eq!(res, vec![4].into_iter().collect::<HashSet<u32>>());

        // large min should yield empty
        res.clear();
        reader
            .match_query_with_minimum("a b c", 10, &mut res as *mut _ as *mut c_void)
            .unwrap();
        assert!(res.is_empty());
    }
}
@@ -11,11 +11,20 @@ use crate::{
pub extern "C" fn tantivy_match_query(
    ptr: *mut c_void,
    query: *const c_char,
    min_should_match: usize,
    bitset: *mut c_void,
) -> RustResult {
    let real = ptr as *mut IndexReaderWrapper;
    let query = cstr_to_str!(query);
    unsafe { (*real).match_query(query, bitset).into() }
    if min_should_match > 1 {
        unsafe {
            (*real)
                .match_query_with_minimum(query, min_should_match, bitset)
                .into()
        }
    } else {
        unsafe { (*real).match_query(query, bitset).into() }
    }
}

#[no_mangle]
@@ -950,8 +950,11 @@ struct TantivyIndexWrapper {
    }

    void
    match_query(const std::string& query, void* bitset) {
        auto array = tantivy_match_query(reader_, query.c_str(), bitset);
    match_query(const std::string& query,
                uintptr_t min_should_match,
                void* bitset) {
        auto array = tantivy_match_query(
            reader_, query.c_str(), min_should_match, bitset);
        auto res = RustResultWrapper(array);
        AssertInfo(res.result_->success,
                   "TantivyIndexWrapper.match_query: {}",
@@ -15,7 +15,7 @@ expr:
    | EmptyArray # EmptyArray
    | EXISTS expr # Exists
    | expr LIKE StringLiteral # Like
    | TEXTMATCH'('Identifier',' StringLiteral')' # TextMatch
    | TEXTMATCH'('Identifier',' StringLiteral (',' textMatchOption)? ')' # TextMatch
    | PHRASEMATCH'('Identifier',' StringLiteral (',' expr)? ')' # PhraseMatch
    | RANDOMSAMPLE'(' expr ')' # RandomSample
    | expr POW expr # Power
@@ -50,6 +50,9 @@ expr:
    | (Identifier | JSONIdentifier) ISNULL # IsNull
    | (Identifier | JSONIdentifier) ISNOTNULL # IsNotNull;

textMatchOption:
    MINIMUM_SHOULD_MATCH ASSIGN IntegerConstant;

// typeName: ty = (BOOL | INT8 | INT16 | INT32 | INT64 | FLOAT | DOUBLE);

// BOOL: 'bool';
@@ -76,6 +79,8 @@ PHRASEMATCH: 'phrase_match'|'PHRASE_MATCH';
RANDOMSAMPLE: 'random_sample' | 'RANDOM_SAMPLE';
INTERVAL: 'interval' | 'INTERVAL';
ISO: 'iso' | 'ISO';
MINIMUM_SHOULD_MATCH: 'minimum_should_match' | 'MINIMUM_SHOULD_MATCH';
ASSIGN: '=';

ADD: '+';
SUB: '-';
File diff suppressed because one or more lines are too long
@@ -18,49 +18,51 @@ PHRASEMATCH=17
RANDOMSAMPLE=18
INTERVAL=19
ISO=20
ADD=21
SUB=22
MUL=23
DIV=24
MOD=25
POW=26
SHL=27
SHR=28
BAND=29
BOR=30
BXOR=31
AND=32
OR=33
ISNULL=34
ISNOTNULL=35
BNOT=36
NOT=37
IN=38
EmptyArray=39
JSONContains=40
JSONContainsAll=41
JSONContainsAny=42
ArrayContains=43
ArrayContainsAll=44
ArrayContainsAny=45
ArrayLength=46
STEuqals=47
STTouches=48
STOverlaps=49
STCrosses=50
STContains=51
STIntersects=52
STWithin=53
STDWithin=54
BooleanConstant=55
IntegerConstant=56
FloatingConstant=57
Identifier=58
Meta=59
StringLiteral=60
JSONIdentifier=61
Whitespace=62
Newline=63
MINIMUM_SHOULD_MATCH=21
ASSIGN=22
ADD=23
SUB=24
MUL=25
DIV=26
MOD=27
POW=28
SHL=29
SHR=30
BAND=31
BOR=32
BXOR=33
AND=34
OR=35
ISNULL=36
ISNOTNULL=37
BNOT=38
NOT=39
IN=40
EmptyArray=41
JSONContains=42
JSONContainsAll=43
JSONContainsAny=44
ArrayContains=45
ArrayContainsAll=46
ArrayContainsAny=47
ArrayLength=48
STEuqals=49
STTouches=50
STOverlaps=51
STCrosses=52
STContains=53
STIntersects=54
STWithin=55
STDWithin=56
BooleanConstant=57
IntegerConstant=58
FloatingConstant=59
Identifier=60
Meta=61
StringLiteral=62
JSONIdentifier=63
Whitespace=64
Newline=65
'('=1
')'=2
'['=3
@@ -74,16 +76,17 @@ Newline=63
'>='=11
'=='=12
'!='=13
'+'=21
'-'=22
'*'=23
'/'=24
'%'=25
'**'=26
'<<'=27
'>>'=28
'&'=29
'|'=30
'^'=31
'~'=36
'$meta'=59
'='=22
'+'=23
'-'=24
'*'=25
'/'=26
'%'=27
'**'=28
'<<'=29
'>>'=30
'&'=31
'|'=32
'^'=33
'~'=38
'$meta'=61
File diff suppressed because one or more lines are too long
@@ -18,49 +18,51 @@ PHRASEMATCH=17
RANDOMSAMPLE=18
INTERVAL=19
ISO=20
ADD=21
SUB=22
MUL=23
DIV=24
MOD=25
POW=26
SHL=27
SHR=28
BAND=29
BOR=30
BXOR=31
AND=32
OR=33
ISNULL=34
ISNOTNULL=35
BNOT=36
NOT=37
IN=38
EmptyArray=39
JSONContains=40
JSONContainsAll=41
JSONContainsAny=42
ArrayContains=43
ArrayContainsAll=44
ArrayContainsAny=45
ArrayLength=46
STEuqals=47
STTouches=48
STOverlaps=49
STCrosses=50
STContains=51
STIntersects=52
STWithin=53
STDWithin=54
BooleanConstant=55
IntegerConstant=56
FloatingConstant=57
Identifier=58
Meta=59
StringLiteral=60
JSONIdentifier=61
Whitespace=62
Newline=63
MINIMUM_SHOULD_MATCH=21
ASSIGN=22
ADD=23
SUB=24
MUL=25
DIV=26
MOD=27
POW=28
SHL=29
SHR=30
BAND=31
BOR=32
BXOR=33
AND=34
OR=35
ISNULL=36
ISNOTNULL=37
BNOT=38
NOT=39
IN=40
EmptyArray=41
JSONContains=42
JSONContainsAll=43
JSONContainsAny=44
ArrayContains=45
ArrayContainsAll=46
ArrayContainsAny=47
ArrayLength=48
STEuqals=49
STTouches=50
STOverlaps=51
STCrosses=52
STContains=53
STIntersects=54
STWithin=55
STDWithin=56
BooleanConstant=57
IntegerConstant=58
FloatingConstant=59
Identifier=60
Meta=61
StringLiteral=62
JSONIdentifier=63
Whitespace=64
Newline=65
'('=1
')'=2
'['=3
@@ -74,16 +76,17 @@ Newline=63
'>='=11
'=='=12
'!='=13
'+'=21
'-'=22
'*'=23
'/'=24
'%'=25
'**'=26
'<<'=27
'>>'=28
'&'=29
'|'=30
'^'=31
'~'=36
'$meta'=59
'='=22
'+'=23
'-'=24
'*'=25
'/'=26
'%'=27
'**'=28
'<<'=29
'>>'=30
'&'=31
'|'=32
'^'=33
'~'=38
'$meta'=61
@@ -194,3 +194,7 @@ func (v *BasePlanVisitor) VisitPower(ctx *PowerContext) interface{} {
func (v *BasePlanVisitor) VisitSTOverlaps(ctx *STOverlapsContext) interface{} {
	return v.VisitChildren(ctx)
}

func (v *BasePlanVisitor) VisitTextMatchOption(ctx *TextMatchOptionContext) interface{} {
	return v.VisitChildren(ctx)
}
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -147,4 +147,7 @@ type PlanVisitor interface {

	// Visit a parse tree produced by PlanParser#STOverlaps.
	VisitSTOverlaps(ctx *STOverlapsContext) interface{}

	// Visit a parse tree produced by PlanParser#textMatchOption.
	VisitTextMatchOption(ctx *TextMatchOptionContext) interface{}
}
@@ -527,13 +527,28 @@ func (v *ParserVisitor) VisitTextMatch(ctx *parser.TextMatchContext) interface{}
		return err
	}

	// Handle optional min_should_match parameter
	var extraValues []*planpb.GenericValue
	if ctx.TextMatchOption() != nil {
		minShouldMatchExpr := ctx.TextMatchOption().Accept(v)
		if err, ok := minShouldMatchExpr.(error); ok {
			return err
		}
		extraVal, err := validateAndExtractMinShouldMatch(minShouldMatchExpr)
		if err != nil {
			return err
		}
		extraValues = extraVal
	}

	return &ExprWithType{
		expr: &planpb.Expr{
			Expr: &planpb.Expr_UnaryRangeExpr{
				UnaryRangeExpr: &planpb.UnaryRangeExpr{
					ColumnInfo: columnInfo,
					Op:         planpb.OpType_TextMatch,
					Value:      NewString(queryText),
					ColumnInfo:  columnInfo,
					Op:          planpb.OpType_TextMatch,
					Value:       NewString(queryText),
					ExtraValues: extraValues,
				},
			},
		},
@@ -541,6 +556,26 @@ func (v *ParserVisitor) VisitTextMatch(ctx *parser.TextMatchContext) interface{}
	}
}

func (v *ParserVisitor) VisitTextMatchOption(ctx *parser.TextMatchOptionContext) interface{} {
	// Parse the integer constant for minimum_should_match
	integerConstant := ctx.IntegerConstant().GetText()
	value, err := strconv.ParseInt(integerConstant, 0, 64)
	if err != nil {
		return fmt.Errorf("invalid minimum_should_match value: %s", integerConstant)
	}

	return &ExprWithType{
		expr: &planpb.Expr{
			Expr: &planpb.Expr_ValueExpr{
				ValueExpr: &planpb.ValueExpr{
					Value: NewInt(value),
				},
			},
		},
		dataType: schemapb.DataType_Int64,
	}
}

func (v *ParserVisitor) VisitPhraseMatch(ctx *parser.PhraseMatchContext) interface{} {
	identifier := ctx.Identifier().GetText()
	column, err := v.translateIdentifier(identifier)
@@ -2060,3 +2095,21 @@ func reverseCompareOp(op planpb.OpType) planpb.OpType {
		return planpb.OpType_Invalid
	}
}

func validateAndExtractMinShouldMatch(minShouldMatchExpr interface{}) ([]*planpb.GenericValue, error) {
	if minShouldMatchValue, ok := minShouldMatchExpr.(*ExprWithType); ok {
		valueExpr := getValueExpr(minShouldMatchValue)
		if valueExpr == nil || valueExpr.GetValue() == nil {
			return nil, fmt.Errorf("minimum_should_match should be a const integer expression")
		}
		minShouldMatch := valueExpr.GetValue().GetInt64Val()
		if minShouldMatch < 1 {
			return nil, fmt.Errorf("minimum_should_match should be >= 1, got %d", minShouldMatch)
		}
		if minShouldMatch > 1000 {
			return nil, fmt.Errorf("minimum_should_match should be <= 1000, got %d", minShouldMatch)
		}
		return []*planpb.GenericValue{NewInt(minShouldMatch)}, nil
	}
	return nil, nil
}
@@ -291,6 +291,272 @@ func TestExpr_TextMatch(t *testing.T) {
	}
}

func TestExpr_TextMatch_MinShouldMatch(t *testing.T) {
	schema := newTestSchema(true)
	enableMatch(schema)
	helper, err := typeutil.CreateSchemaHelper(schema)
	assert.NoError(t, err)

	for _, v := range []int64{1, 2, 1000} {
		expr := fmt.Sprintf(`text_match(VarCharField, "query", minimum_should_match=%d)`, v)
		plan, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{
			Topk:         10,
			MetricType:   "L2",
			SearchParams: "",
			RoundDecimal: 0,
		}, nil, nil)
		assert.NoError(t, err, expr)
		assert.NotNil(t, plan)

		predicates := plan.GetVectorAnns().GetPredicates()
		assert.NotNil(t, predicates)
		ure := predicates.GetUnaryRangeExpr()
		assert.NotNil(t, ure)
		assert.Equal(t, planpb.OpType_TextMatch, ure.GetOp())
		assert.Equal(t, "query", ure.GetValue().GetStringVal())
		extra := ure.GetExtraValues()
		assert.Equal(t, 1, len(extra))
		assert.Equal(t, v, extra[0].GetInt64Val())
	}

	{
		expr := `text_match(VarCharField, "query", minimum_should_match=0)`
		_, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{}, nil, nil)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "minimum_should_match should be >= 1")
	}

	{
		expr := `text_match(VarCharField, "query", minimum_should_match=1001)`
		_, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{}, nil, nil)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "minimum_should_match should be <= 1000")
	}

	{
		expr := `text_match(VarCharField, "query", minimum_should_match=1.5)`
		_, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{}, nil, nil)
		assert.Error(t, err)
	}

	{
		expr := `text_match(VarCharField, "query", minimum_should_match={min})`
		_, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{}, nil, nil)
		assert.Error(t, err)
		// grammar rejects placeholder before visitor; accept either parse error or visitor error
		errMsg := err.Error()
		assert.True(t, strings.Contains(errMsg, "mismatched input") || strings.Contains(errMsg, "minimum_should_match should be a const integer expression"), errMsg)
	}

	{
		expr := `text_match(VarCharField, "query", minimum_should_match=9223372036854775808)`
		_, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{}, nil, nil)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid minimum_should_match value")
	}

	{
		expr := `text_match(VarCharField, "\中国")`
		_, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{}, nil, nil)
		assert.Error(t, err)
	}

	{
		expr := `text_match(VarCharField, "query", minimum_should_match=9223372036854775808)`
		_, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{}, nil, nil)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid minimum_should_match value")
	}

	{
		expr := `text_match(VarCharField, "\中国")`
		_, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{}, nil, nil)
		assert.Error(t, err)
	}
}

func TestExpr_TextMatch_MinShouldMatch_NilValue_Coverage(t *testing.T) {
	// This test is specifically to cover the error case in validateAndExtractMinShouldMatch
	// which handles the edge case where minShouldMatchExpr is an ExprWithType

	// Test case 1: ExprWithType with a ColumnExpr
	// This will make getValueExpr return nil
	exprWithColumnExpr := &ExprWithType{
		expr: &planpb.Expr{
			Expr: &planpb.Expr_ColumnExpr{
				ColumnExpr: &planpb.ColumnExpr{
					Info: &planpb.ColumnInfo{
						FieldId:  100,
						DataType: schemapb.DataType_Int64,
					},
				},
			},
		},
		dataType: schemapb.DataType_Int64,
	}

	_, err := validateAndExtractMinShouldMatch(exprWithColumnExpr)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "minimum_should_match should be a const integer expression")

	// Test case 2: ExprWithType with a ValueExpr but nil Value
	// This will make getValueExpr return a non-nil ValueExpr but GetValue() returns nil
	exprWithNilValue := &ExprWithType{
		expr: &planpb.Expr{
			Expr: &planpb.Expr_ValueExpr{
				ValueExpr: &planpb.ValueExpr{
					Value: nil,
				},
			},
		},
		dataType: schemapb.DataType_Int64,
	}

	_, err = validateAndExtractMinShouldMatch(exprWithNilValue)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "minimum_should_match should be a const integer expression")

	// Test case 3: Valid ExprWithType with proper value
	validExpr := &ExprWithType{
		expr: &planpb.Expr{
			Expr: &planpb.Expr_ValueExpr{
				ValueExpr: &planpb.ValueExpr{
					Value: NewInt(10),
				},
			},
		},
		dataType: schemapb.DataType_Int64,
	}

	extraVals, err := validateAndExtractMinShouldMatch(validExpr)
	assert.NoError(t, err)
	assert.NotNil(t, extraVals)
	assert.Equal(t, 1, len(extraVals))
	assert.Equal(t, int64(10), extraVals[0].GetInt64Val())
}

func TestExpr_TextMatch_MinShouldMatch_Omitted(t *testing.T) {
	schema := newTestSchema(true)
	enableMatch(schema)
	helper, err := typeutil.CreateSchemaHelper(schema)
	assert.NoError(t, err)

	expr := `text_match(VarCharField, "query")`
	plan, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{
		Topk:         10,
		MetricType:   "L2",
		SearchParams: "",
		RoundDecimal: 0,
	}, nil, nil)
	assert.NoError(t, err)
	assert.NotNil(t, plan)

	predicates := plan.GetVectorAnns().GetPredicates()
	assert.NotNil(t, predicates)
	ure := predicates.GetUnaryRangeExpr()
	assert.NotNil(t, ure)
	assert.Equal(t, planpb.OpType_TextMatch, ure.GetOp())
	assert.Equal(t, "query", ure.GetValue().GetStringVal())
	// When omitted, ExtraValues should be empty
	assert.Equal(t, 0, len(ure.GetExtraValues()))
}

func TestExpr_TextMatch_MinShouldMatch_IntegerConstant(t *testing.T) {
	schema := newTestSchema(true)
	enableMatch(schema)
	helper, err := typeutil.CreateSchemaHelper(schema)
	assert.NoError(t, err)

	expr := `text_match(VarCharField, "query", minimum_should_match=10)`
	plan, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{
		Topk:         10,
		MetricType:   "L2",
		SearchParams: "",
		RoundDecimal: 0,
	}, nil, nil)
	assert.NoError(t, err)
	assert.NotNil(t, plan)

	predicates := plan.GetVectorAnns().GetPredicates()
	assert.NotNil(t, predicates)
	ure := predicates.GetUnaryRangeExpr()
	assert.NotNil(t, ure)
	assert.Equal(t, planpb.OpType_TextMatch, ure.GetOp())
	assert.Equal(t, "query", ure.GetValue().GetStringVal())
	extra := ure.GetExtraValues()
	assert.Equal(t, 1, len(extra))
	assert.Equal(t, int64(10), extra[0].GetInt64Val())
}

func TestExpr_TextMatch_MinShouldMatch_NameTypos(t *testing.T) {
	schema := newTestSchema(true)
	enableMatch(schema)
	helper, err := typeutil.CreateSchemaHelper(schema)
	assert.NoError(t, err)

	invalid := []string{
		`text_match(VarCharField, "q", minimum_shouldmatch=1)`,
		`text_match(VarCharField, "q", min_should_match=1)`,
		`text_match(VarCharField, "q", minimumShouldMatch=1)`,
		`text_match(VarCharField, "q", minimum-should-match=1)`,
		`text_match(VarCharField, "q", minimum_should_matchx=1)`,
	}
	for _, expr := range invalid {
		_, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{}, nil, nil)
		assert.Error(t, err, expr)
	}
}

func TestExpr_TextMatch_MinShouldMatch_InvalidValueTypes(t *testing.T) {
	schema := newTestSchema(true)
	enableMatch(schema)
	helper, err := typeutil.CreateSchemaHelper(schema)
	assert.NoError(t, err)

	invalid := []string{
		`text_match(VarCharField, "q", minimum_should_match=10*10)`,
		`text_match(VarCharField, "q", minimum_should_match=nil)`,
		`text_match(VarCharField, "q", minimum_should_match=)`,
		`text_match(VarCharField, "q", minimum_should_match="10")`,
		`text_match(VarCharField, "q", minimum_should_match=true)`,
		`text_match(VarCharField, "q", minimum_should_match=a)`,
		`text_match(VarCharField, "q", minimum_should_match={min})`,
		`text_match(VarCharField, "q", minimum_should_match=1.0)`,
		`text_match(VarCharField, "q", minimum_should_match=-1)`,
	}
	for _, expr := range invalid {
		_, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{}, nil, nil)
		assert.Error(t, err, expr)
	}
}

func TestExpr_TextMatch_MinShouldMatch_LeadingZerosAndOctal(t *testing.T) {
	schema := newTestSchema(true)
	enableMatch(schema)
	helper, err := typeutil.CreateSchemaHelper(schema)
	assert.NoError(t, err)
	{
		expr := `text_match(VarCharField, "query", minimum_should_match=001)`
		plan, err := CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{Topk: 10}, nil, nil)
		assert.NoError(t, err)
		ure := plan.GetVectorAnns().GetPredicates().GetUnaryRangeExpr()
		extra := ure.GetExtraValues()
		assert.Equal(t, 1, len(extra))
		assert.Equal(t, int64(1), extra[0].GetInt64Val())
	}
}

func TestExpr_TextMatch_MinShouldMatch_DuplicateOption(t *testing.T) {
	schema := newTestSchema(true)
	enableMatch(schema)
	helper, err := typeutil.CreateSchemaHelper(schema)
	assert.NoError(t, err)

	expr := `text_match(VarCharField, "query", minimum_should_match=2, minimum_should_match=3)`
	_, err = CreateSearchPlan(helper, expr, "FloatVectorField", &planpb.QueryInfo{}, nil, nil)
	assert.Error(t, err)
}

func TestExpr_PhraseMatch(t *testing.T) {
	schema := newTestSchema(true)
	helper, err := typeutil.CreateSchemaHelper(schema)
@@ -340,3 +340,42 @@ func TestFullTextSearchDefaultValue(t *testing.T) {
		}
	}
}

// TestTextMatchMinimumShouldMatch verifies text_match(..., minimum_should_match=N)
func TestTextMatchMinimumShouldMatch(t *testing.T) {
	ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
	mc := hp.CreateDefaultMilvusClient(ctx, t)

	function := hp.TNewBM25Function(common.DefaultTextFieldName, common.DefaultTextSparseVecFieldName)
	// Provide valid field options instead of nil to satisfy CreateCollection's type check
	analyzerParams := map[string]any{"tokenizer": "standard"}
	fieldsOption := hp.TNewFieldsOption().TWithAnalyzerParams(analyzerParams)
	prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.FullTextSearch), fieldsOption, hp.TNewSchemaOption().TWithFunction(function))

	docs := []string{"a b", "a c", "b c", "c", "a b c"}
	insertOption := hp.TNewDataOption().TWithTextLang(common.DefaultTextLang).TWithTextData(docs)
	prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), insertOption)
	prepare.FlushData(ctx, t, mc, schema.CollectionName)

	indexparams := hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultTextSparseVecFieldName: index.NewSparseInvertedIndex(entity.BM25, 0.1)})
	prepare.CreateIndex(ctx, t, mc, indexparams)
	prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))

	// min=1 should return docs containing any token from "a b c"
	expr1 := fmt.Sprintf("text_match(%s, \"%s\", minimum_should_match=%d)", common.DefaultTextFieldName, "a b c", 1)
	res1, err := mc.Query(ctx, milvusclient.NewQueryOption(schema.CollectionName).WithFilter(expr1))
	require.NoError(t, err)
	require.GreaterOrEqual(t, res1.ResultCount, 4)

	// min=3 should return only the doc containing all three tokens
	expr3 := fmt.Sprintf("text_match(%s, \"%s\", minimum_should_match=%d)", common.DefaultTextFieldName, "a b c", 3)
	res3, err := mc.Query(ctx, milvusclient.NewQueryOption(schema.CollectionName).WithFilter(expr3))
	require.NoError(t, err)
	require.GreaterOrEqual(t, res3.ResultCount, 1)

	// min large (e.g. 10) should return 0
	exprLarge := fmt.Sprintf("text_match(%s, \"%s\", minimum_should_match=%d)", common.DefaultTextFieldName, "a b c", 10)
	resLarge, err := mc.Query(ctx, milvusclient.NewQueryOption(schema.CollectionName).WithFilter(exprLarge))
	require.NoError(t, err)
	require.Equal(t, 0, resLarge.ResultCount)
}