diff --git a/client/milvusclient/iterator.go b/client/milvusclient/iterator.go index 7924e74118..4bf60ec6a8 100644 --- a/client/milvusclient/iterator.go +++ b/client/milvusclient/iterator.go @@ -20,10 +20,13 @@ import ( "context" "fmt" "io" + "strconv" + "strings" "github.com/cockroachdb/errors" "google.golang.org/grpc" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/client/v2/entity" "github.com/milvus-io/milvus/pkg/v2/util/merr" @@ -211,3 +214,231 @@ func (c *Client) SearchIterator(ctx context.Context, option SearchIteratorOption return newSearchIteratorV1(c) } + +// QueryIterator is the interface for query iterator. +type QueryIterator interface { + // Next returns next batch of iterator + // when iterator reaches the end, return `io.EOF`. + Next(ctx context.Context) (ResultSet, error) +} + +type queryIterator struct { + client *Client + option QueryIteratorOption + schema *entity.Schema + + // pagination state + expr string // base expression from option + outputFields []string // override output fields(force include pk field) + pkField *entity.Field + lastPK any + batchSize int + limit int64 + + // cached results + cached ResultSet +} + +// composeIteratorExpr builds the filter expression for pagination. +// It combines the user's original expression with a PK range filter. +func (it *queryIterator) composeIteratorExpr() string { + if it.lastPK == nil { + return it.expr + } + + expr := strings.TrimSpace(it.expr) + pkName := it.pkField.Name + + switch it.pkField.DataType { + case entity.FieldTypeInt64: + pkFilter := fmt.Sprintf("%s > %d", pkName, it.lastPK) + if len(expr) == 0 { + return pkFilter + } + return fmt.Sprintf("(%s) and %s", expr, pkFilter) + case entity.FieldTypeVarChar: + pkFilter := fmt.Sprintf(`%s > "%s"`, pkName, it.lastPK) + if len(expr) == 0 { + return pkFilter + } + return fmt.Sprintf(`(%s) and %s`, expr, pkFilter) + default: + return it.expr + } +} + +// fetchNextBatch fetches the next batch of data from the server. +func (it *queryIterator) fetchNextBatch(ctx context.Context) (ResultSet, error) { + req, err := it.option.Request() + if err != nil { + return ResultSet{}, err + } + + // override expression and limit for pagination + req.Expr = it.composeIteratorExpr() + req.OutputFields = it.outputFields + req.QueryParams = append(req.QueryParams, + &commonpb.KeyValuePair{Key: spLimit, Value: strconv.Itoa(it.batchSize)}, + ) + + var resultSet ResultSet + err = it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.Query(ctx, req) + err = merr.CheckRPCCall(resp, err) + if err != nil { + return err + } + + columns, err := it.client.parseSearchResult(it.schema, resp.GetOutputFields(), resp.GetFieldsData(), 0, 0, -1) + if err != nil { + return err + } + resultSet = ResultSet{ + sch: it.schema, + Fields: columns, + } + if len(columns) > 0 { + resultSet.ResultCount = columns[0].Len() + } + + return nil + }) + + return resultSet, err +} + +// cacheNextBatch returns the next batch and updates the cache. +func (it *queryIterator) cacheNextBatch(rs ResultSet) (ResultSet, error) { + var result ResultSet + if rs.ResultCount > it.batchSize { + result = rs.Slice(0, it.batchSize) + it.cached = rs.Slice(it.batchSize, rs.ResultCount) + } else { + result = rs + it.cached = ResultSet{} + } + + if result.ResultCount == 0 { + return ResultSet{}, io.EOF + } + + // extract and update the last PK for pagination + pkColumn := result.GetColumn(it.pkField.Name) + if pkColumn == nil { + // try to find PK in Fields + for _, col := range result.Fields { + if col.Name() == it.pkField.Name { + pkColumn = col + break + } + } + } + + if pkColumn != nil && pkColumn.Len() > 0 { + pk, err := pkColumn.Get(pkColumn.Len() - 1) + if err != nil { + return ResultSet{}, errors.Wrapf(err, "failed to get last pk value") + } + it.lastPK = pk + } + + return result, nil +} + +// Next returns the next batch of results. +func (it *queryIterator) Next(ctx context.Context) (ResultSet, error) { + // limit reached, return EOF + if it.limit == 0 { + return ResultSet{}, io.EOF + } + + // if cache is empty, fetch new data + if it.cached.ResultCount == 0 { + rs, err := it.fetchNextBatch(ctx) + if err != nil { + return ResultSet{}, err + } + it.cached = rs + } + + // if still no data, return EOF + if it.cached.ResultCount == 0 { + return ResultSet{}, io.EOF + } + + result, err := it.cacheNextBatch(it.cached) + if err != nil { + return ResultSet{}, err + } + + // handle overall limit + if it.limit != Unlimited { + if int64(result.ResultCount) > it.limit { + result = result.Slice(0, int(it.limit)) + } + it.limit -= int64(result.ResultCount) + } + + return result, nil +} + +// newQueryIterator creates a new query iterator. +func newQueryIterator(ctx context.Context, client *Client, option QueryIteratorOption) (*queryIterator, error) { + req, err := option.Request() + if err != nil { + return nil, err + } + + collection, err := client.getCollection(ctx, req.GetCollectionName()) + if err != nil { + return nil, err + } + + pkField := collection.Schema.PKField() + if pkField == nil { + return nil, errors.New("primary key field not found in schema") + } + + // ensure PK field is included in output fields for pagination + outputFields := req.GetOutputFields() + hasPK := false + for _, f := range outputFields { + if f == pkField.Name { + hasPK = true + break + } + } + if !hasPK && len(outputFields) > 0 { + // modify the underlying option to include PK field + outputFields = append(outputFields, pkField.Name) + } + + iter := &queryIterator{ + client: client, + option: option, + schema: collection.Schema, + expr: req.GetExpr(), + outputFields: outputFields, + pkField: pkField, + batchSize: option.BatchSize(), + limit: option.Limit(), + } + + // init: fetch the first batch to validate parameters + rs, err := iter.fetchNextBatch(ctx) + if err != nil { + return nil, err + } + iter.cached = rs + + return iter, nil +} + +// QueryIterator creates a query iterator from a collection. +func (c *Client) QueryIterator(ctx context.Context, option QueryIteratorOption, callOptions ...grpc.CallOption) (QueryIterator, error) { + if err := option.ValidateParams(); err != nil { + return nil, err + } + + return newQueryIterator(ctx, c, option) +} diff --git a/client/milvusclient/iterator_option.go b/client/milvusclient/iterator_option.go index 3a75eb2888..1d2bd6f5ba 100644 --- a/client/milvusclient/iterator_option.go +++ b/client/milvusclient/iterator_option.go @@ -19,6 +19,7 @@ package milvusclient import ( "fmt" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/client/v2/entity" "github.com/milvus-io/milvus/client/v2/index" ) @@ -146,3 +147,99 @@ func NewSearchIteratorOption(collectionName string, vector entity.Vector) *searc iteratorLimit: Unlimited, } } + +// QueryIteratorOption is the interface for query iterator options. +type QueryIteratorOption interface { + // Request returns the query request when iterate query + Request() (*milvuspb.QueryRequest, error) + // BatchSize returns the batch size for each query iteration + BatchSize() int + // Limit returns the overall limit of entries to iterate + Limit() int64 + // ValidateParams performs the static params validation + ValidateParams() error +} + +type queryIteratorOption struct { + collectionName string + partitionNames []string + outputFields []string + expr string + batchSize int + iteratorLimit int64 + consistencyLevel entity.ConsistencyLevel + useDefaultConsistencyLevel bool +} + +func (opt *queryIteratorOption) Request() (*milvuspb.QueryRequest, error) { + return &milvuspb.QueryRequest{ + CollectionName: opt.collectionName, + PartitionNames: opt.partitionNames, + OutputFields: opt.outputFields, + Expr: opt.expr, + ConsistencyLevel: opt.consistencyLevel.CommonConsistencyLevel(), + UseDefaultConsistency: opt.useDefaultConsistencyLevel, + QueryParams: entity.MapKvPairs(map[string]string{IteratorKey: "true", "reduce_stop_for_best": "true"}), + }, nil +} + +func (opt *queryIteratorOption) BatchSize() int { + return opt.batchSize +} + +func (opt *queryIteratorOption) Limit() int64 { + return opt.iteratorLimit +} + +func (opt *queryIteratorOption) ValidateParams() error { + if opt.batchSize <= 0 { + return fmt.Errorf("batch size must be greater than 0") + } + return nil +} + +func (opt *queryIteratorOption) WithBatchSize(batchSize int) *queryIteratorOption { + opt.batchSize = batchSize + return opt +} + +func (opt *queryIteratorOption) WithPartitions(partitionNames ...string) *queryIteratorOption { + opt.partitionNames = partitionNames + return opt +} + +func (opt *queryIteratorOption) WithFilter(expr string) *queryIteratorOption { + opt.expr = expr + return opt +} + +func (opt *queryIteratorOption) WithOutputFields(fieldNames ...string) *queryIteratorOption { + opt.outputFields = fieldNames + return opt +} + +func (opt *queryIteratorOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *queryIteratorOption { + opt.consistencyLevel = consistencyLevel + opt.useDefaultConsistencyLevel = false + return opt +} + +// WithIteratorLimit sets the limit of entries to iterate +// if limit < 0, then it will be set to Unlimited +func (opt *queryIteratorOption) WithIteratorLimit(limit int64) *queryIteratorOption { + if limit < 0 { + limit = Unlimited + } + opt.iteratorLimit = limit + return opt +} + +func NewQueryIteratorOption(collectionName string) *queryIteratorOption { + return &queryIteratorOption{ + collectionName: collectionName, + batchSize: 1000, + iteratorLimit: Unlimited, + useDefaultConsistencyLevel: true, + consistencyLevel: entity.ClBounded, + } +} diff --git a/client/milvusclient/iterator_test.go b/client/milvusclient/iterator_test.go index b76763b34d..c1944f3555 100644 --- a/client/milvusclient/iterator_test.go +++ b/client/milvusclient/iterator_test.go @@ -416,3 +416,291 @@ func (s *SearchIteratorSuite) TestNextWithLimit() { func TestSearchIterator(t *testing.T) { suite.Run(t, new(SearchIteratorSuite)) } + +type QueryIteratorSuite struct { + MockSuiteBase + + schema *entity.Schema +} + +func (s *QueryIteratorSuite) SetupSuite() { + s.MockSuiteBase.SetupSuite() + s.schema = entity.NewSchema(). + WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)). + WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)). + WithField(entity.NewField().WithName("Name").WithDataType(entity.FieldTypeVarChar).WithMaxLength(256)) +} + +func (s *QueryIteratorSuite) TestQueryIteratorInit() { + ctx := context.Background() + s.Run("success", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionID: 1, + Schema: s.schema.ProtoMessage(), + }, nil).Once() + s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { + s.Equal(collectionName, qr.GetCollectionName()) + return &milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1, 2, 3}), + s.getVarcharFieldData("Name", []string{"a", "b", "c"}), + }, + }, nil + }).Once() + + iter, err := s.client.QueryIterator(ctx, NewQueryIteratorOption(collectionName). + WithOutputFields("ID", "Name"). + WithBatchSize(10)) + + s.NoError(err) + s.NotNil(iter) + }) + + s.Run("failure", func() { + s.Run("option_error", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + _, err := s.client.QueryIterator(ctx, NewQueryIteratorOption(collectionName).WithBatchSize(-1)) + s.Error(err) + }) + + s.Run("describe_fail", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("mock error")).Once() + _, err := s.client.QueryIterator(ctx, NewQueryIteratorOption(collectionName)) + s.Error(err) + }) + + s.Run("query_fail", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionID: 1, + Schema: s.schema.ProtoMessage(), + }, nil).Once() + s.mock.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("mock query error")).Once() + + _, err := s.client.QueryIterator(ctx, NewQueryIteratorOption(collectionName)) + s.Error(err) + }) + }) +} + +func (s *QueryIteratorSuite) TestQueryIteratorNext() { + ctx := context.Background() + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionID: 1, + Schema: s.schema.ProtoMessage(), + }, nil).Once() + + // first query for init + s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { + s.Equal(collectionName, qr.GetCollectionName()) + return &milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1, 2, 3}), + s.getVarcharFieldData("Name", []string{"a", "b", "c"}), + }, + }, nil + }).Once() + + iter, err := s.client.QueryIterator(ctx, NewQueryIteratorOption(collectionName). + WithOutputFields("ID", "Name"). + WithBatchSize(3)) + s.Require().NoError(err) + s.Require().NotNil(iter) + + // first Next should return cached data + rs, err := iter.Next(ctx) + s.NoError(err) + s.EqualValues(3, rs.ResultCount) + + // second query + s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { + s.Equal(collectionName, qr.GetCollectionName()) + // verify pagination expression contains PK filter + s.Contains(qr.GetExpr(), "ID > 3") + return &milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{4, 5}), + s.getVarcharFieldData("Name", []string{"d", "e"}), + }, + }, nil + }).Once() + + rs, err = iter.Next(ctx) + s.NoError(err) + s.EqualValues(2, rs.ResultCount) + + // third query - empty result + s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { + s.Equal(collectionName, qr.GetCollectionName()) + s.Contains(qr.GetExpr(), "ID > 5") + return &milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{}, + }, nil + }).Once() + + _, err = iter.Next(ctx) + s.Error(err) + s.ErrorIs(err, io.EOF) +} + +func (s *QueryIteratorSuite) TestQueryIteratorWithLimit() { + ctx := context.Background() + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionID: 1, + Schema: s.schema.ProtoMessage(), + }, nil).Once() + + // first query for init + s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { + return &milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1, 2, 3, 4, 5}), + s.getVarcharFieldData("Name", []string{"a", "b", "c", "d", "e"}), + }, + }, nil + }).Once() + + iter, err := s.client.QueryIterator(ctx, NewQueryIteratorOption(collectionName). + WithOutputFields("ID", "Name"). + WithBatchSize(5). + WithIteratorLimit(7)) + s.Require().NoError(err) + s.Require().NotNil(iter) + + // first Next - returns 5 items + rs, err := iter.Next(ctx) + s.NoError(err) + s.EqualValues(5, rs.ResultCount) + + // second query + s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { + return &milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{6, 7, 8, 9, 10}), + s.getVarcharFieldData("Name", []string{"f", "g", "h", "i", "j"}), + }, + }, nil + }).Once() + + // second Next - returns only 2 items due to limit (7 - 5 = 2) + rs, err = iter.Next(ctx) + s.NoError(err) + s.EqualValues(2, rs.ResultCount, "should return sliced result due to limit") + + // third Next - limit reached, should return EOF + _, err = iter.Next(ctx) + s.Error(err) + s.ErrorIs(err, io.EOF, "limit reached, return EOF") +} + +func (s *QueryIteratorSuite) TestQueryIteratorWithVarCharPK() { + ctx := context.Background() + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + schemaVarCharPK := entity.NewSchema(). + WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeVarChar).WithIsPrimaryKey(true).WithMaxLength(64)). + WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionID: 1, + Schema: schemaVarCharPK.ProtoMessage(), + }, nil).Once() + + s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { + return &milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{ + s.getVarcharFieldData("ID", []string{"a", "b", "c"}), + }, + }, nil + }).Once() + + iter, err := s.client.QueryIterator(ctx, NewQueryIteratorOption(collectionName). + WithOutputFields("ID"). + WithBatchSize(3)) + s.Require().NoError(err) + s.Require().NotNil(iter) + + rs, err := iter.Next(ctx) + s.NoError(err) + s.EqualValues(3, rs.ResultCount) + + // second query - verify varchar PK filter + s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { + s.Contains(qr.GetExpr(), `ID > "c"`) + return &milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{}, + }, nil + }).Once() + + _, err = iter.Next(ctx) + s.Error(err) + s.ErrorIs(err, io.EOF) +} + +func (s *QueryIteratorSuite) TestQueryIteratorWithFilter() { + ctx := context.Background() + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionID: 1, + Schema: s.schema.ProtoMessage(), + }, nil).Once() + + s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { + s.Equal(`Name == "test"`, qr.GetExpr()) + return &milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1, 2}), + s.getVarcharFieldData("Name", []string{"test", "test"}), + }, + }, nil + }).Once() + + iter, err := s.client.QueryIterator(ctx, NewQueryIteratorOption(collectionName). + WithFilter(`Name == "test"`). + WithOutputFields("ID", "Name"). + WithBatchSize(10)) + s.Require().NoError(err) + s.Require().NotNil(iter) + + rs, err := iter.Next(ctx) + s.NoError(err) + s.EqualValues(2, rs.ResultCount) + + // second query - filter combined with PK filter + s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { + s.Contains(qr.GetExpr(), `Name == "test"`) + s.Contains(qr.GetExpr(), "ID > 2") + return &milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{}, + }, nil + }).Once() + + _, err = iter.Next(ctx) + s.Error(err) + s.ErrorIs(err, io.EOF) +} + +func TestQueryIterator(t *testing.T) { + suite.Run(t, new(QueryIteratorSuite)) +} diff --git a/client/milvusclient/results.go b/client/milvusclient/results.go index 2fe0dcf9e4..08fb173c3a 100644 --- a/client/milvusclient/results.go +++ b/client/milvusclient/results.go @@ -59,7 +59,6 @@ func (rs ResultSet) Len() int { func (rs ResultSet) Slice(start, end int) ResultSet { result := ResultSet{ sch: rs.sch, - IDs: rs.IDs.Slice(start, end), Fields: lo.Map(rs.Fields, func(column column.Column, _ int) column.Column { return column.Slice(start, end) }), @@ -67,12 +66,26 @@ func (rs ResultSet) Slice(start, end int) ResultSet { Err: rs.Err, } + // Handle IDs - may be nil for Query results + if rs.IDs != nil { + result.IDs = rs.IDs.Slice(start, end) + result.ResultCount = result.IDs.Len() + } else if len(result.Fields) > 0 { + result.ResultCount = result.Fields[0].Len() + } + if rs.GroupByValue != nil { result.GroupByValue = rs.GroupByValue.Slice(start, end) } - result.ResultCount = result.IDs.Len() - result.Scores = rs.Scores[start : start+result.ResultCount] + // Handle Scores - may be nil or empty for Query results + if len(rs.Scores) > 0 && result.ResultCount > 0 { + scoreEnd := start + result.ResultCount + if scoreEnd > len(rs.Scores) { + scoreEnd = len(rs.Scores) + } + result.Scores = rs.Scores[start:scoreEnd] + } return result } diff --git a/tests/go_client/common/response_checker.go b/tests/go_client/common/response_checker.go index e7cc2fa244..4c77cc9ad8 100644 --- a/tests/go_client/common/response_checker.go +++ b/tests/go_client/common/response_checker.go @@ -281,7 +281,7 @@ func WithExpOutputFields(expOutputFields []string) CheckIteratorOption { } } -// check queryIterator: result limit, each batch size, output fields +// check searchIterator: result limit, each batch size, output fields func CheckSearchIteratorResult(ctx context.Context, t *testing.T, itr client.SearchIterator, expLimit int, opts ...CheckIteratorOption) { opt := &checkIteratorOpt{} for _, o := range opts { @@ -319,6 +319,44 @@ func CheckSearchIteratorResult(ctx context.Context, t *testing.T, itr client.Sea } } +// check queryIterator: result limit, each batch size, output fields +func CheckQueryIteratorResult(ctx context.Context, t *testing.T, itr client.QueryIterator, expLimit int, opts ...CheckIteratorOption) { + opt := &checkIteratorOpt{} + for _, o := range opts { + o(opt) + } + actualLimit := 0 + var actualBatchSize []int + for { + rs, err := itr.Next(ctx) + if err != nil { + if err == io.EOF { + break + } else { + log.Error("QueryIterator next gets error", zap.Error(err)) + break + } + } + + if opt.expBatchSize != nil { + actualBatchSize = append(actualBatchSize, rs.ResultCount) + } + var actualOutputFields []string + if opt.expOutputFields != nil { + for _, column := range rs.Fields { + actualOutputFields = append(actualOutputFields, column.Name()) + } + require.ElementsMatch(t, opt.expOutputFields, actualOutputFields) + } + actualLimit = actualLimit + rs.ResultCount + } + require.Equal(t, expLimit, actualLimit) + if opt.expBatchSize != nil { + log.Debug("QueryIterator result len", zap.Any("result len", actualBatchSize)) + require.True(t, EqualIntSlice(opt.expBatchSize, actualBatchSize)) + } +} + // check expected columns should be contains in actual columns func CheckPartialResult(t *testing.T, expColumns []column.Column, actualColumns []column.Column) { for _, expColumn := range expColumns { diff --git a/tests/go_client/testcases/query_iterator_test.go b/tests/go_client/testcases/query_iterator_test.go new file mode 100644 index 0000000000..bbf90880ba --- /dev/null +++ b/tests/go_client/testcases/query_iterator_test.go @@ -0,0 +1,425 @@ +package testcases + +import ( + "fmt" + "io" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/client/v2/entity" + client "github.com/milvus-io/milvus/client/v2/milvusclient" + "github.com/milvus-io/milvus/pkg/v2/log" + "github.com/milvus-io/milvus/tests/go_client/common" + hp "github.com/milvus-io/milvus/tests/go_client/testcases/helper" +) + +// TestQueryIteratorDefault tests query iterator with default parameters +func TestQueryIteratorDefault(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb*2).TWithStart(common.DefaultNb)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + // query iterator with default batch + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName)) + common.CheckErr(t, err, true) + common.CheckQueryIteratorResult(ctx, t, itr, common.DefaultNb*3, common.WithExpBatchSize(hp.GenBatchSizes(common.DefaultNb*3, common.DefaultBatchSize))) +} + +// TestQueryIteratorHitEmpty tests query iterator on empty collection +func TestQueryIteratorHitEmpty(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> index -> load (no data inserted) + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + // query iterator with default batch + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName)) + common.CheckErr(t, err, true) + rs, err := itr.Next(ctx) + require.Empty(t, rs.Fields) + require.ErrorIs(t, err, io.EOF) + common.CheckQueryIteratorResult(ctx, t, itr, 0, common.WithExpBatchSize(hp.GenBatchSizes(0, common.DefaultBatchSize))) +} + +// TestQueryIteratorBatchSize tests query iterator with different batch sizes +func TestQueryIteratorBatchSize(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + nb := 201 + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(nb)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + type batchStruct struct { + batch int + expBatchSize []int + } + batchStructs := []batchStruct{ + {batch: nb / 2, expBatchSize: hp.GenBatchSizes(nb, nb/2)}, + {batch: nb, expBatchSize: hp.GenBatchSizes(nb, nb)}, + {batch: nb + 1, expBatchSize: hp.GenBatchSizes(nb, nb+1)}, + } + + for _, _batchStruct := range batchStructs { + // query iterator with different batch sizes + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithBatchSize(_batchStruct.batch)) + common.CheckErr(t, err, true) + common.CheckQueryIteratorResult(ctx, t, itr, nb, common.WithExpBatchSize(_batchStruct.expBatchSize)) + } +} + +// TestQueryIteratorOutputAllFields tests query iterator with all fields output +func TestQueryIteratorOutputAllFields(t *testing.T) { + t.Parallel() + for _, dynamic := range [2]bool{false, true} { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.AllFields), hp.TNewFieldsOption(), + hp.TNewSchemaOption().TWithEnableDynamicField(dynamic), hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption()) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + var allFieldsName []string + for _, field := range schema.Fields { + allFieldsName = append(allFieldsName, field.Name) + } + if dynamic { + allFieldsName = append(allFieldsName, common.DefaultDynamicFieldName) + } + + // output * fields + nbFilter := 1001 + batch := 500 + expr := fmt.Sprintf("%s < %d", common.DefaultInt64FieldName, nbFilter) + + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithBatchSize(batch).WithOutputFields("*").WithFilter(expr)) + common.CheckErr(t, err, true) + common.CheckQueryIteratorResult(ctx, t, itr, nbFilter, common.WithExpBatchSize(hp.GenBatchSizes(nbFilter, batch)), common.WithExpOutputFields(allFieldsName)) + } +} + +// TestQueryIteratorSparseVecFields tests query iterator with sparse vector fields +func TestQueryIteratorSparseVecFields(t *testing.T) { + t.Parallel() + for _, withRows := range [2]bool{true, false} { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), hp.TNewFieldsOption(), + hp.TNewSchemaOption().TWithEnableDynamicField(true), hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema).TWithIsRows(withRows), hp.TNewDataOption()) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + fieldsName := []string{common.DefaultDynamicFieldName} + for _, field := range schema.Fields { + fieldsName = append(fieldsName, field.Name) + } + + // output * fields + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithBatchSize(400).WithOutputFields("*")) + common.CheckErr(t, err, true) + common.CheckQueryIteratorResult(ctx, t, itr, common.DefaultNb, common.WithExpBatchSize(hp.GenBatchSizes(common.DefaultNb, 400)), common.WithExpOutputFields(fieldsName)) + } +} + +// TestQueryIteratorInvalid tests query iterator with invalid parameters +func TestQueryIteratorInvalid(t *testing.T) { + nb := 201 + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption(), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(nb)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + // query iterator with not exist collection name + _, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(common.GenRandomString("c", 5))) + common.CheckErr(t, err, false, "collection not found", "can't find collection") + + // query iterator with not exist partition name + _, errPar := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithPartitions(common.GenRandomString("p", 5))) + common.CheckErr(t, errPar, false, "partition name", "not found") + + // query iterator with not exist partition name + _, errPar = mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithPartitions(common.GenRandomString("p", 5), common.DefaultPartition)) + common.CheckErr(t, errPar, false, "partition name", "not found") + + // query iterator with count(*) + _, errOutput := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithOutputFields(common.QueryCountFieldName)) + common.CheckErr(t, errOutput, false, "count entities with pagination is not allowed", "count(*)") + + // query iterator with invalid batch size + for _, batch := range []int{-1, 0} { + _, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithBatchSize(batch)) + common.CheckErr(t, err, false, "batch size", "must be greater than 0", "cannot less than 1") + } +} + +// TestQueryIteratorInvalidExpr tests query iterator with invalid expressions +func TestQueryIteratorInvalidExpr(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VecJSON), hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + for _, _invalidExprs := range common.InvalidExpressions { + t.Log(_invalidExprs) + _, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithFilter(_invalidExprs.Expr)) + common.CheckErr(t, err, _invalidExprs.ErrNil, _invalidExprs.ErrMsg, "") + } +} + +// TestQueryIteratorOutputFieldDynamic tests query iterator with non-existed field when dynamic enabled or not +func TestQueryIteratorOutputFieldDynamic(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + nb := 201 + + for _, dynamic := range [2]bool{true, false} { + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), + hp.TNewSchemaOption().TWithEnableDynamicField(dynamic), hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(nb)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + // query iterator with not existed output fields: if dynamic, non-existent field are equivalent to dynamic field + itr, errOutput := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithOutputFields("aaa")) + if dynamic { + common.CheckErr(t, errOutput, true) + expFields := []string{common.DefaultInt64FieldName, "aaa"} + common.CheckQueryIteratorResult(ctx, t, itr, nb, common.WithExpBatchSize(hp.GenBatchSizes(nb, common.DefaultBatchSize)), common.WithExpOutputFields(expFields)) + } else { + common.CheckErr(t, errOutput, false, "field aaa not exist", "field not exist") + } + } +} + +// TestQueryIteratorExpr tests query iterator with various expressions +func TestQueryIteratorExpr(t *testing.T) { + type exprCount struct { + expr string + count int + } + capacity := common.TestCapacity + exprLimits := []exprCount{ + {expr: fmt.Sprintf("%s in [0, 1, 2]", common.DefaultInt64FieldName), count: 3}, + {expr: fmt.Sprintf("%s >= 1000 || %s > 2000", common.DefaultInt64FieldName, common.DefaultInt64FieldName), count: 2000}, + {expr: fmt.Sprintf("%s >= 1000 and %s < 2000", common.DefaultInt64FieldName, common.DefaultInt64FieldName), count: 1000}, + + // json and dynamic field filter expr: == < in bool/ list/ int + // {expr: fmt.Sprintf("%s['number'] == 0", common.DefaultJSONFieldName), count: 1500}, + // {expr: fmt.Sprintf("%s['number'] < 100 and %s['number'] != 0", common.DefaultJSONFieldName, common.DefaultJSONFieldName), count: 99}, + {expr: fmt.Sprintf("%s < 100", common.DefaultDynamicNumberField), count: 100}, + {expr: "dynamicNumber % 2 == 0", count: 1500}, + {expr: fmt.Sprintf("%s == false", common.DefaultDynamicBoolField), count: 1500}, + {expr: fmt.Sprintf("%s in ['1', '2'] ", common.DefaultDynamicStringField), count: 2}, + {expr: fmt.Sprintf("%s['string'] in ['1', '2', '5'] ", common.DefaultJSONFieldName), count: 3}, + {expr: fmt.Sprintf("%s['list'] == [1, 2] ", common.DefaultJSONFieldName), count: 1}, + {expr: fmt.Sprintf("%s['list'][0] < 10 ", common.DefaultJSONFieldName), count: 10}, + {expr: fmt.Sprintf("%s[\"dynamicList\"] != [2, 3]", common.DefaultDynamicFieldName), count: 0}, + + // json contains + {expr: fmt.Sprintf("json_contains (%s['list'], 2)", common.DefaultJSONFieldName), count: 1}, + {expr: fmt.Sprintf("json_contains (%s['number'], 0)", common.DefaultJSONFieldName), count: 0}, + {expr: fmt.Sprintf("JSON_CONTAINS_ANY (%s['list'], [1, 3])", common.DefaultJSONFieldName), count: 2}, + // string like + {expr: "dynamicString like '1%' ", count: 1111}, + + // key exist + {expr: fmt.Sprintf("exists %s['list']", common.DefaultJSONFieldName), count: common.DefaultNb}, + {expr: "exists a ", count: 0}, + {expr: fmt.Sprintf("exists %s ", common.DefaultDynamicStringField), count: common.DefaultNb}, + + // data type not match and no error + {expr: fmt.Sprintf("%s['number'] == '0' ", common.DefaultJSONFieldName), count: 0}, + + // json field + {expr: fmt.Sprintf("%s >= 1500", common.DefaultJSONFieldName), count: 1500}, // json >= 1500 + {expr: fmt.Sprintf("%s > 1499.5", common.DefaultJSONFieldName), count: 1500}, // json >= 1500.0 + {expr: fmt.Sprintf("%s like '21%%'", common.DefaultJSONFieldName), count: 100}, // json like '21%' + {expr: fmt.Sprintf("%s == [1503, 1504]", common.DefaultJSONFieldName), count: 1}, // json == [1,2] + {expr: fmt.Sprintf("%s[0] > 1", common.DefaultJSONFieldName), count: 1500}, // json[0] > 1 + {expr: fmt.Sprintf("%s[0][0] > 1", common.DefaultJSONFieldName), count: 0}, // json == [1,2] + {expr: fmt.Sprintf("%s[0] == false", common.DefaultBoolArrayField), count: common.DefaultNb / 2}, // array[0] == + {expr: fmt.Sprintf("%s[0] > 0", common.DefaultInt64ArrayField), count: common.DefaultNb - 1}, // array[0] > + {expr: fmt.Sprintf("array_contains (%s, %d)", common.DefaultInt16ArrayField, capacity), count: capacity}, // array_contains(array, 1) + {expr: fmt.Sprintf("json_contains (%s, 1)", common.DefaultInt32ArrayField), count: 2}, // json_contains(array, 1) + {expr: fmt.Sprintf("array_contains (%s, 1000000)", common.DefaultInt32ArrayField), count: 0}, // array_contains(array, 1) + {expr: fmt.Sprintf("json_contains_all (%s, [90, 91])", common.DefaultInt64ArrayField), count: 91}, // json_contains_all(array, [x]) + {expr: fmt.Sprintf("json_contains_any (%s, [0, 100, 10])", common.DefaultFloatArrayField), count: 101}, // json_contains_any (array, [x]) + {expr: fmt.Sprintf("%s == [0, 1]", common.DefaultDoubleArrayField), count: 0}, // array == + {expr: fmt.Sprintf("array_length(%s) == %d", common.DefaultDoubleArrayField, capacity), count: common.DefaultNb}, // array_length + } + + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.AllFields), hp.TNewFieldsOption(), + hp.TNewSchemaOption().TWithEnableDynamicField(true), hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption()) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + batch := 500 + for _, exprLimit := range exprLimits { + rs, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(exprLimit.expr).WithOutputFields("count(*)")) + common.CheckErr(t, err, true) + expectCount, err := rs.GetColumn("count(*)").GetAsInt64(0) + common.CheckErr(t, err, true) + + log.Info("case expr is", zap.String("expr", exprLimit.expr), zap.Int64("expectedCount", expectCount)) + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithBatchSize(batch).WithFilter(exprLimit.expr)) + common.CheckErr(t, err, true) + common.CheckQueryIteratorResult(ctx, t, itr, int(expectCount), common.WithExpBatchSize(hp.GenBatchSizes(int(expectCount), batch))) + } +} + +// TestQueryIteratorPartitions tests query iterator with partition filtering +func TestQueryIteratorPartitions(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create collection + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption(), + hp.TWithConsistencyLevel(entity.ClStrong)) + + // create partition + pName := "p1" + err := mc.CreatePartition(ctx, client.NewCreatePartitionOption(schema.CollectionName, pName)) + common.CheckErr(t, err, true) + + // insert [0, nb) into partition: _default + nb := 1500 + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(nb)) + + // insert [nb, nb*2) into partition: p1 + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema).TWithPartitionName(pName), hp.TNewDataOption().TWithNb(nb).TWithStart(nb)) + + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + // query iterator with partition + expr := fmt.Sprintf("%s < %d", common.DefaultInt64FieldName, nb) + mParLimit := map[string]int{ + common.DefaultPartition: nb, + pName: 0, + } + for par, limit := range mParLimit { + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithFilter(expr).WithPartitions(par)) + common.CheckErr(t, err, true) + common.CheckQueryIteratorResult(ctx, t, itr, limit, common.WithExpBatchSize(hp.GenBatchSizes(limit, common.DefaultBatchSize))) + } +} + +// TestQueryIteratorWithLimit tests query iterator with limit +func TestQueryIteratorWithLimit(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb*2)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + // query iterator with limit + limit := int64(2000) + batch := 500 + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithIteratorLimit(limit).WithBatchSize(batch)) + common.CheckErr(t, err, true) + common.CheckQueryIteratorResult(ctx, t, itr, int(limit), common.WithExpBatchSize(hp.GenBatchSizes(int(limit), batch))) +} + +// TestQueryIteratorGrowing tests query iterator on growing segments +func TestQueryIteratorGrowing(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> index -> load -> insert (growing) + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption(), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb*2)) + + // query iterator growing + limit := int64(1000) + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithIteratorLimit(limit).WithBatchSize(100)) + common.CheckErr(t, err, true) + common.CheckQueryIteratorResult(ctx, t, itr, int(limit), common.WithExpBatchSize(hp.GenBatchSizes(int(limit), 100))) +} + +// TestQueryIteratorConsistencyLevel tests query iterator with different consistency levels +func TestQueryIteratorConsistencyLevel(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption(), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + // query iterator with different consistency levels + for _, cl := range []entity.ConsistencyLevel{entity.ClStrong, entity.ClBounded, entity.ClEventually} { + itr, err := mc.QueryIterator(ctx, client.NewQueryIteratorOption(schema.CollectionName).WithConsistencyLevel(cl).WithBatchSize(500)) + common.CheckErr(t, err, true) + actualLimit := 0 + for { + rs, err := itr.Next(ctx) + if err != nil { + if err == io.EOF { + break + } + log.Error("QueryIterator next gets error", zap.Error(err)) + break + } + actualLimit = actualLimit + rs.ResultCount + } + require.LessOrEqual(t, actualLimit, common.DefaultNb) + } +}