mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
### Is there an existing issue for this? - [x] I have searched the existing issues --- Please see: https://github.com/milvus-io/milvus/issues/44593 for the background This PR makes https://github.com/milvus-io/milvus/pull/44638 redundant, which can be closed. The PR comments for the original implementation suggested an alternative and a better approach, this new PR has that implementation. --- This PR - Adds an optional `minimum_should_match` argument to `text_match(...)` and wires it through the parser, planner/visitor, index bindings, and client-level tests/examples so full-text queries can require a minimum number of tokens to match. Motivation - Provide a way to require an expression to match a minimum number of tokens in lexical search. What changed - Parser / grammar - Added grammar rule and token: `MINIMUM_SHOULD_MATCH` and `textMatchOption` in `internal/parser/planparserv2/Plan.g4`. - Regenerated parser outputs: `internal/parser/planparserv2/generated/*` (parser, lexer, visitor, etc.) to support the new rule. - Planner / visitor - `parser_visitor.go`: parse and validate the `minimum_should_match` integer; propagate as an extra value on the `TextMatch` expression so downstream components receive it. - Added `VisitTextMatchOption` visitor method handling. - Client (Golang) - Added a unit test to verify `text_match(..., minimum_should_match=...)` appears in the generated DSL and is accepted by client code: `client/milvusclient/read_test.go` (new test coverage). - Added an integration-style test for the feature to the go-client testcase suite: `tests/go_client/testcases/full_text_search_test.go` (exercise min=1, min=3, large min). - Added an example demonstrating `text_match` usage: `client/milvusclient/read_example_test.go` (example name conforms to godoc mapping). - Engine / index - Updated C++ index interface: `TextMatchIndex::MatchQuery` - Added/updated unit tests for the index behavior: `internal/core/src/index/TextMatchIndexTest.cpp`. - Tantivy binding - Added `match_query_with_minimum` implementation and unit tests to `internal/core/thirdparty/tantivy/tantivy-binding/src/index_reader_text.rs` that construct boolean queries with minimum required clauses. Behavioral / compatibility notes - This adds an optional argument to `text_match` only; default behavior (no `minimum_should_match`) is unchanged. - Internal API change: `TextMatchIndex::MatchQuery` signature changed (internal component). Callers in the repo were updated accordingly. - Parser changes required regenerating ANTLR outputs Tests and verification - New/updated tests: - Go client unit test: `client/milvusclient/read_test.go` (mocked Search request asserts DSL contains `minimum_should_match=2`). - Go e2e-style test: `tests/go_client/testcases/full_text_search_test.go` (exercises min=1, 3 and a large min). - C++ unit tests for index behavior: `internal/core/src/index/TextMatchIndexTest.cpp`. - Rust binding unit tests for `match_query_with_minimum`. - Local verification commands to run: - Go client tests: `cd client && go test ./milvusclient -run ^$` (client package) - Go testcases: `cd tests/go_client && go test ./testcases -run TestTextMatchMinimumShouldMatch` (requires a running Milvus instance) - C++ unit tests / build: run core build/test per repo instructions (the change touches core index code). - Rust binding tests: `cd internal/core/thirdparty/tantivy/tantivy-binding && cargo test` (if developing locally). --------- Signed-off-by: Amit Kumar <amit.kumar@reddit.com> Co-authored-by: Amit Kumar <amit.kumar@reddit.com>
332 lines
11 KiB
Go
332 lines
11 KiB
Go
package milvusclient
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math/rand"
|
|
"testing"
|
|
|
|
"github.com/samber/lo"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/suite"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/client/v2/entity"
|
|
"github.com/milvus-io/milvus/client/v2/index"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
|
)
|
|
|
|
type ReadSuite struct {
|
|
MockSuiteBase
|
|
|
|
schema *entity.Schema
|
|
schemaDyn *entity.Schema
|
|
}
|
|
|
|
func (s *ReadSuite) SetupSuite() {
|
|
s.MockSuiteBase.SetupSuite()
|
|
s.schema = entity.NewSchema().
|
|
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
|
|
WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
|
|
|
|
s.schemaDyn = entity.NewSchema().WithDynamicFieldEnabled(true).
|
|
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
|
|
WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
|
|
}
|
|
|
|
func (s *ReadSuite) TestSearch() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
s.Run("success", func() {
|
|
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
|
s.setupCache(collectionName, s.schema)
|
|
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
|
|
s.Equal(collectionName, sr.GetCollectionName())
|
|
s.ElementsMatch([]string{partitionName}, sr.GetPartitionNames())
|
|
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: 1,
|
|
TopK: 10,
|
|
FieldsData: []*schemapb.FieldData{
|
|
s.getInt64FieldData("ID", []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
|
|
},
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
|
},
|
|
},
|
|
},
|
|
Scores: make([]float32, 10),
|
|
Topks: []int64{10},
|
|
Recalls: []float32{0.9},
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
ap := index.NewCustomAnnParam()
|
|
ap.WithExtraParam("custom_level", 1)
|
|
_, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
|
|
entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
|
return rand.Float32()
|
|
})),
|
|
}).WithPartitions(partitionName).
|
|
WithFilter("id > {tmpl_id}").
|
|
WithTemplateParam("tmpl_id", 100).
|
|
WithGroupByField("group_by").
|
|
WithSearchParam("ignore_growing", "true").
|
|
WithAnnParam(ap),
|
|
)
|
|
s.NoError(err)
|
|
})
|
|
|
|
s.Run("dynamic_schema", func() {
|
|
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
|
s.setupCache(collectionName, s.schemaDyn)
|
|
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: 1,
|
|
TopK: 2,
|
|
FieldsData: []*schemapb.FieldData{
|
|
s.getInt64FieldData("ID", []int64{1, 2}),
|
|
s.getJSONBytesFieldData("$meta", [][]byte{
|
|
[]byte(`{"A": 123, "B": "456"}`),
|
|
[]byte(`{"B": "abc", "A": 456}`),
|
|
}, true),
|
|
},
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1, 2},
|
|
},
|
|
},
|
|
},
|
|
Scores: make([]float32, 2),
|
|
Topks: []int64{2},
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
_, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
|
|
entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
|
return rand.Float32()
|
|
})),
|
|
}).WithPartitions(partitionName))
|
|
s.NoError(err)
|
|
})
|
|
|
|
s.Run("bad_result", func() {
|
|
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
|
s.setupCache(collectionName, s.schema)
|
|
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
|
|
s.Equal(collectionName, sr.GetCollectionName())
|
|
s.ElementsMatch([]string{partitionName}, sr.GetPartitionNames())
|
|
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: 1,
|
|
TopK: 10,
|
|
FieldsData: []*schemapb.FieldData{
|
|
s.getInt64FieldData("ID", []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
|
|
},
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
|
},
|
|
},
|
|
},
|
|
Scores: make([]float32, 10),
|
|
Topks: []int64{},
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
ap := index.NewCustomAnnParam()
|
|
ap.WithExtraParam("custom_level", 1)
|
|
rss, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
|
|
entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
|
return rand.Float32()
|
|
})),
|
|
}).WithPartitions(partitionName).
|
|
WithFilter("id > {tmpl_id}").
|
|
WithTemplateParam("tmpl_id", 100).
|
|
WithGroupByField("group_by").
|
|
WithSearchParam("ignore_growing", "true").
|
|
WithAnnParam(ap),
|
|
)
|
|
s.NoError(err)
|
|
s.Len(rss, 1)
|
|
rs := rss[0]
|
|
s.Error(rs.Err)
|
|
})
|
|
|
|
s.Run("failure", func() {
|
|
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
s.setupCache(collectionName, s.schemaDyn)
|
|
|
|
_, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{nonSupportData{}}))
|
|
s.Error(err)
|
|
|
|
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
|
|
return nil, merr.WrapErrServiceInternal("mocked")
|
|
}).Once()
|
|
|
|
_, err = s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
|
|
entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
|
return rand.Float32()
|
|
})),
|
|
}))
|
|
s.Error(err)
|
|
})
|
|
}
|
|
|
|
// TestSearch_TextMatch tests the text match search functionality.
|
|
// It tests the minimum_should_match parameter in the expression.
|
|
func (s *ReadSuite) TestSearch_TextMatch() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
s.Run("min_should_match_in_expr", func() {
|
|
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
s.setupCache(collectionName, s.schema)
|
|
|
|
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
|
|
// ensure the expression contains minimum_should_match and both fields
|
|
s.Contains(sr.GetDsl(), "minimum_should_match=2")
|
|
s.Contains(sr.GetDsl(), "text_match(")
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: 1,
|
|
TopK: 1,
|
|
FieldsData: []*schemapb.FieldData{
|
|
s.getInt64FieldData("ID", []int64{1}),
|
|
},
|
|
Ids: &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1}}}},
|
|
Scores: []float32{0.1},
|
|
Topks: []int64{1},
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
q := "artificial intelligence"
|
|
expr := "text_match(title, \"" + q + "\", minimum_should_match=2) OR text_match(document_text, \"" + q + "\", minimum_should_match=2)"
|
|
vectors := []entity.Vector{entity.Text(q)}
|
|
_, err := s.client.Search(ctx, NewSearchOption(collectionName, 5, vectors).
|
|
WithANNSField("text_sparse_vector").
|
|
WithFilter(expr).
|
|
WithOutputFields("ID"))
|
|
s.NoError(err)
|
|
})
|
|
}
|
|
|
|
func (s *ReadSuite) TestHybridSearch() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
s.Run("success", func() {
|
|
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
|
s.setupCache(collectionName, s.schema)
|
|
|
|
s.mock.EXPECT().HybridSearch(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, hsr *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
|
|
s.Equal(collectionName, hsr.GetCollectionName())
|
|
s.ElementsMatch([]string{partitionName}, hsr.GetPartitionNames())
|
|
s.ElementsMatch([]string{"*"}, hsr.GetOutputFields())
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Success(),
|
|
Results: &schemapb.SearchResultData{
|
|
NumQueries: 1,
|
|
TopK: 2,
|
|
FieldsData: []*schemapb.FieldData{
|
|
s.getInt64FieldData("ID", []int64{1, 2}),
|
|
s.getJSONBytesFieldData("$meta", [][]byte{
|
|
[]byte(`{"A": 123, "B": "456"}`),
|
|
[]byte(`{"B": "abc", "A": 456}`),
|
|
}, true),
|
|
},
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1, 2},
|
|
},
|
|
},
|
|
},
|
|
Scores: make([]float32, 2),
|
|
Topks: []int64{2},
|
|
},
|
|
}, nil
|
|
}).Once()
|
|
|
|
_, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, 5, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
|
return rand.Float32()
|
|
}))).WithFilter("ID > 100"), NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
|
return rand.Float32()
|
|
})))).WithConsistencyLevel(entity.ClStrong).WithPartitons(partitionName).WithReranker(NewRRFReranker()).WithOutputFields("*"))
|
|
s.NoError(err)
|
|
})
|
|
|
|
s.Run("failure", func() {
|
|
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
s.setupCache(collectionName, s.schemaDyn)
|
|
|
|
_, err := s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, 5, NewAnnRequest("vector", 10, nonSupportData{})))
|
|
s.Error(err)
|
|
|
|
s.mock.EXPECT().HybridSearch(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, hsr *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
|
|
return nil, merr.WrapErrServiceInternal("mocked")
|
|
}).Once()
|
|
|
|
_, err = s.client.HybridSearch(ctx, NewHybridSearchOption(collectionName, 5, NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
|
return rand.Float32()
|
|
}))).WithFilter("ID > 100"), NewAnnRequest("vector", 10, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
|
return rand.Float32()
|
|
})))))
|
|
s.Error(err)
|
|
})
|
|
}
|
|
|
|
func (s *ReadSuite) TestQuery() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
s.Run("success", func() {
|
|
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
|
s.setupCache(collectionName, s.schema)
|
|
|
|
s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) {
|
|
s.Equal(collectionName, qr.GetCollectionName())
|
|
|
|
return &milvuspb.QueryResults{}, nil
|
|
}).Once()
|
|
|
|
rs, err := s.client.Query(ctx, NewQueryOption(collectionName).WithPartitions(partitionName))
|
|
s.NoError(err)
|
|
s.NotNil(rs.sch)
|
|
})
|
|
|
|
s.Run("bad_request", func() {
|
|
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
|
s.setupCache(collectionName, s.schema)
|
|
|
|
_, err := s.client.Query(ctx, NewQueryOption(collectionName).WithFilter("id > {tmpl_id}").WithTemplateParam("tmpl_id", struct{}{}))
|
|
s.Error(err)
|
|
})
|
|
}
|
|
|
|
func TestRead(t *testing.T) {
|
|
suite.Run(t, new(ReadSuite))
|
|
}
|