diff --git a/client/common/version.go b/client/common/version.go index c85bd417c6..882fd48d91 100644 --- a/client/common/version.go +++ b/client/common/version.go @@ -18,5 +18,5 @@ package common const ( // SDKVersion const value for current version - SDKVersion = `2.5.5` + SDKVersion = `2.5.6` ) diff --git a/client/milvusclient/iterator.go b/client/milvusclient/iterator.go new file mode 100644 index 0000000000..7924e74118 --- /dev/null +++ b/client/milvusclient/iterator.go @@ -0,0 +1,213 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package milvusclient + +import ( + "context" + "fmt" + "io" + + "github.com/cockroachdb/errors" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/client/v2/entity" + "github.com/milvus-io/milvus/pkg/v2/util/merr" +) + +const ( + // IteratorKey is the const search param key in indicating enabling iterator. 
+ IteratorKey = "iterator" + IteratorSessionTsKey = "iterator_session_ts" + IteratorSearchV2Key = "search_iter_v2" + IteratorSearchBatchSizeKey = "search_iter_batch_size" + IteratorSearchLastBoundKey = "search_iter_last_bound" + IteratorSearchIDKey = "search_iter_id" + CollectionIDKey = `collection_id` + + // Unlimited + Unlimited int64 = -1 +) + +var ErrServerVersionIncompatible = errors.New("server version incompatible") + +// SearchIterator is the interface for search iterator. +type SearchIterator interface { + // Next returns next batch of iterator + // when iterator reaches the end, return `io.EOF`. + Next(ctx context.Context) (ResultSet, error) +} + +type searchIteratorV2 struct { + client *Client + option SearchIteratorOption + schema *entity.Schema + limit int64 +} + +func (it *searchIteratorV2) Next(ctx context.Context) (ResultSet, error) { + // limit reached, return EOF + if it.limit == 0 { + return ResultSet{}, io.EOF + } + + rs, err := it.next(ctx) + if err != nil { + return rs, err + } + + if it.limit == Unlimited { + return rs, err + } + + if int64(rs.Len()) > it.limit { + rs = rs.Slice(0, int(it.limit)) + } + it.limit -= int64(rs.Len()) + return rs, nil +} + +func (it *searchIteratorV2) next(ctx context.Context) (ResultSet, error) { + opt := it.option.SearchOption() + req, err := opt.Request() + if err != nil { + return ResultSet{}, err + } + + var rs ResultSet + + err = it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.Search(ctx, req) + err = merr.CheckRPCCall(resp, err) + if err != nil { + return err + } + + iteratorInfo := resp.GetResults().GetSearchIteratorV2Results() + opt.annRequest.WithSearchParam(IteratorSearchIDKey, iteratorInfo.GetToken()) + opt.annRequest.WithSearchParam(IteratorSearchLastBoundKey, fmt.Sprintf("%v", iteratorInfo.GetLastBound())) + + resultSets, err := it.client.handleSearchResult(it.schema, req.GetOutputFields(), int(resp.GetResults().GetNumQueries()), resp) + if 
err != nil { + return err + } + rs = resultSets[0] + + if rs.IDs.Len() == 0 { + return io.EOF + } + + return nil + }) + return rs, err +} + +func (it *searchIteratorV2) setupCollectionID(ctx context.Context) error { + opt := it.option.SearchOption() + + return it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{ + CollectionName: opt.collectionName, + }) + if merr.CheckRPCCall(resp, err) != nil { + return err + } + + opt.WithSearchParam(CollectionIDKey, fmt.Sprintf("%d", resp.GetCollectionID())) + schema := &entity.Schema{} + it.schema = schema.ReadProto(resp.GetSchema()) + return nil + }) +} + +// probeCompatiblity checks if the server support SearchIteratorV2. +// It checks if the search result contains search iterator v2 results info and token. +func (it *searchIteratorV2) probeCompatiblity(ctx context.Context) error { + opt := it.option.SearchOption() + opt.annRequest.topK = 1 // ok to leave it here, will be overwritten in next iteration + opt.annRequest.WithSearchParam(IteratorSearchBatchSizeKey, "1") + req, err := opt.Request() + if err != nil { + return err + } + return it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.Search(ctx, req) + err = merr.CheckRPCCall(resp, err) + if err != nil { + return err + } + + if resp.GetResults().GetSearchIteratorV2Results() == nil || resp.GetResults().GetSearchIteratorV2Results().GetToken() == "" { + return ErrServerVersionIncompatible + } + return nil + }) +} + +// newSearchIteratorV2 creates a new search iterator V2. +// +// It sets up the collection ID and checks if the server supports search iterator V2. +// If the server does not support search iterator V2, it returns an error. 
+func newSearchIteratorV2(ctx context.Context, client *Client, option SearchIteratorOption) (*searchIteratorV2, error) { + iter := &searchIteratorV2{ + client: client, + option: option, + limit: option.Limit(), + } + if err := iter.setupCollectionID(ctx); err != nil { + return nil, err + } + + if err := iter.probeCompatiblity(ctx); err != nil { + return nil, err + } + + return iter, nil +} + +type searchIteratorV1 struct { + client *Client +} + +func (s *searchIteratorV1) Next(_ context.Context) (ResultSet, error) { + return ResultSet{}, errors.New("not implemented") +} + +func newSearchIteratorV1(_ *Client) (*searchIteratorV1, error) { + // search iterator v1 is not supported + return nil, ErrServerVersionIncompatible +} + +// SearchIterator creates a search iterator from a collection. +// +// If the server supports search iterator V2, it creates a search iterator V2. +func (c *Client) SearchIterator(ctx context.Context, option SearchIteratorOption, callOptions ...grpc.CallOption) (SearchIterator, error) { + if err := option.ValidateParams(); err != nil { + return nil, err + } + + iter, err := newSearchIteratorV2(ctx, c, option) + if err == nil { + return iter, nil + } + + if !errors.Is(err, ErrServerVersionIncompatible) { + return nil, err + } + + return newSearchIteratorV1(c) +} diff --git a/client/milvusclient/iterator_option.go b/client/milvusclient/iterator_option.go new file mode 100644 index 0000000000..3a75eb2888 --- /dev/null +++ b/client/milvusclient/iterator_option.go @@ -0,0 +1,148 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package milvusclient + +import ( + "fmt" + + "github.com/milvus-io/milvus/client/v2/entity" + "github.com/milvus-io/milvus/client/v2/index" +) + +type SearchIteratorOption interface { + // SearchOption returns the search option when iterate search + SearchOption() *searchOption + // Limit returns the overall limit of entries to iterate + Limit() int64 + // ValidateParams performs the static params validation + ValidateParams() error +} + +type searchIteratorOption struct { + *searchOption + batchSize int + iteratorLimit int64 +} + +func (opt *searchIteratorOption) SearchOption() *searchOption { + opt.annRequest.topK = opt.batchSize + opt.WithSearchParam(IteratorSearchBatchSizeKey, fmt.Sprintf("%d", opt.batchSize)) + return opt.searchOption +} + +func (opt *searchIteratorOption) Limit() int64 { + return opt.iteratorLimit +} + +// ValidateParams performs the static params validation +func (opt *searchIteratorOption) ValidateParams() error { + if opt.batchSize <= 0 { + return fmt.Errorf("batch size must be greater than 0") + } + return nil +} + +func (opt *searchIteratorOption) WithBatchSize(batchSize int) *searchIteratorOption { + opt.batchSize = batchSize + return opt +} + +func (opt *searchIteratorOption) WithPartitions(partitionNames ...string) *searchIteratorOption { + opt.partitionNames = partitionNames + return opt +} + +func (opt *searchIteratorOption) WithFilter(expr string) *searchIteratorOption { + opt.annRequest.WithFilter(expr) + return opt +} + +func (opt *searchIteratorOption) WithTemplateParam(key string, val any) *searchIteratorOption 
{ + opt.annRequest.WithTemplateParam(key, val) + return opt +} + +func (opt *searchIteratorOption) WithOffset(offset int) *searchIteratorOption { + opt.annRequest.WithOffset(offset) + return opt +} + +func (opt *searchIteratorOption) WithOutputFields(fieldNames ...string) *searchIteratorOption { + opt.outputFields = fieldNames + return opt +} + +func (opt *searchIteratorOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *searchIteratorOption { + opt.consistencyLevel = consistencyLevel + opt.useDefaultConsistencyLevel = false + return opt +} + +func (opt *searchIteratorOption) WithANNSField(annsField string) *searchIteratorOption { + opt.annRequest.WithANNSField(annsField) + return opt +} + +func (opt *searchIteratorOption) WithGroupByField(groupByField string) *searchIteratorOption { + opt.annRequest.WithGroupByField(groupByField) + return opt +} + +func (opt *searchIteratorOption) WithGroupSize(groupSize int) *searchIteratorOption { + opt.annRequest.WithGroupSize(groupSize) + return opt +} + +func (opt *searchIteratorOption) WithStrictGroupSize(strictGroupSize bool) *searchIteratorOption { + opt.annRequest.WithStrictGroupSize(strictGroupSize) + return opt +} + +func (opt *searchIteratorOption) WithIgnoreGrowing(ignoreGrowing bool) *searchIteratorOption { + opt.annRequest.WithIgnoreGrowing(ignoreGrowing) + return opt +} + +func (opt *searchIteratorOption) WithAnnParam(ap index.AnnParam) *searchIteratorOption { + opt.annRequest.WithAnnParam(ap) + return opt +} + +func (opt *searchIteratorOption) WithSearchParam(key, value string) *searchIteratorOption { + opt.annRequest.WithSearchParam(key, value) + return opt +} + +// WithIteratorLimit sets the limit of entries to iterate +// if limit < 0, then it will be set to Unlimited +func (opt *searchIteratorOption) WithIteratorLimit(limit int64) *searchIteratorOption { + if limit < 0 { + limit = Unlimited + } + opt.iteratorLimit = limit + return opt +} + +func NewSearchIteratorOption(collectionName 
string, vector entity.Vector) *searchIteratorOption { + return &searchIteratorOption{ + searchOption: NewSearchOption(collectionName, 1000, []entity.Vector{vector}). + WithSearchParam(IteratorKey, "true"). + WithSearchParam(IteratorSearchV2Key, "true"), + batchSize: 1000, + iteratorLimit: Unlimited, + } +} diff --git a/client/milvusclient/iterator_test.go b/client/milvusclient/iterator_test.go new file mode 100644 index 0000000000..b76763b34d --- /dev/null +++ b/client/milvusclient/iterator_test.go @@ -0,0 +1,418 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package milvusclient + +import ( + "context" + "fmt" + "io" + "math/rand" + "testing" + + "github.com/samber/lo" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/client/v2/entity" + "github.com/milvus-io/milvus/pkg/v2/util/merr" +) + +type SearchIteratorSuite struct { + MockSuiteBase + + schema *entity.Schema +} + +func (s *SearchIteratorSuite) SetupSuite() { + s.MockSuiteBase.SetupSuite() + s.schema = entity.NewSchema(). 
+ WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)). + WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)) +} + +func (s *SearchIteratorSuite) TestSearchIteratorInit() { + ctx := context.Background() + s.Run("success", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionID: 1, + Schema: s.schema.ProtoMessage(), + }, nil).Once() + s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { + s.Equal(collectionName, sr.GetCollectionName()) + checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool { + for _, kv := range kvs { + if kv.GetKey() == key && kv.GetValue() == value { + return true + } + } + return false + } + + s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true")) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true")) + return &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 1, + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1}), + }, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1}, + }, + }, + }, + Scores: make([]float32, 1), + Topks: []int64{1}, + Recalls: []float32{1}, + SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{ + Token: s.randString(16), + }, + }, + }, nil + }).Once() + + iter, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { + return rand.Float32() + })))) + + s.NoError(err) + _, ok := iter.(*searchIteratorV2) + s.True(ok) + }) + + s.Run("failure", func() { + s.Run("option_error", func() { + collectionName 
:= fmt.Sprintf("coll_%s", s.randString(6)) + + _, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { + return rand.Float32() + }))).WithBatchSize(-1).WithIteratorLimit(-2)) + s.Error(err) + }) + + s.Run("describe_fail", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("mock error")).Once() + _, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { + return rand.Float32() + })))) + s.Error(err) + }) + + s.Run("not_v2_result", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionID: 1, + Schema: s.schema.ProtoMessage(), + }, nil).Once() + s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { + s.Equal(collectionName, sr.GetCollectionName()) + return &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 1, + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1}), + }, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1}, + }, + }, + }, + Scores: make([]float32, 1), + Topks: []int64{1}, + Recalls: []float32{1}, + SearchIteratorV2Results: nil, // nil v2 results + }, + }, nil + }).Once() + + _, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { + return rand.Float32() + })))) + s.Error(err) + s.ErrorIs(err, ErrServerVersionIncompatible) + }) + }) +} + +func (s *SearchIteratorSuite) TestNext() { + ctx := context.Background() + collectionName := 
fmt.Sprintf("coll_%s", s.randString(6)) + + token := fmt.Sprintf("iter_token_%s", s.randString(8)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionID: 1, + Schema: s.schema.ProtoMessage(), + }, nil).Once() + s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { + s.Equal(collectionName, sr.GetCollectionName()) + checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool { + for _, kv := range kvs { + if kv.GetKey() == key && kv.GetValue() == value { + return true + } + } + return false + } + + s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true")) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true")) + return &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 1, + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1}), + }, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1}, + }, + }, + }, + Scores: make([]float32, 1), + Topks: []int64{1}, + Recalls: []float32{1}, + SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{ + Token: token, + }, + }, + }, nil + }).Once() + + iter, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { + return rand.Float32() + })))) + s.Require().NoError(err) + s.Require().NotNil(iter) + + s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { + s.Equal(collectionName, sr.GetCollectionName()) + checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool { + for _, kv := range kvs { + if kv.GetKey() == key && kv.GetValue() == value { + return true + } + } + 
return false + } + + s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true")) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true")) + return &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 1, + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1}), + }, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1}, + }, + }, + }, + Scores: []float32{0.5}, + Topks: []int64{1}, + Recalls: []float32{1}, + SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{ + Token: token, + LastBound: 0.5, + }, + }, + }, nil + }).Once() + + rs, err := iter.Next(ctx) + s.NoError(err) + s.EqualValues(1, rs.IDs.Len()) + + s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { + s.Equal(collectionName, sr.GetCollectionName()) + checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool { + for _, kv := range kvs { + if kv.GetKey() == key && kv.GetValue() == value { + return true + } + } + return false + } + + s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true")) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true")) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchIDKey, token)) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchLastBoundKey, "0.5")) + + return &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 1, + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{}), + }, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{}, + }, + }, + }, + Scores: []float32{}, + Topks: []int64{0}, + Recalls: []float32{1.0}, + SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{ + Token: token, + LastBound: 
0.5, + }, + }, + }, nil + }).Once() + + _, err = iter.Next(ctx) + s.Error(err) + s.ErrorIs(err, io.EOF) +} + +func (s *SearchIteratorSuite) TestNextWithLimit() { + ctx := context.Background() + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + token := fmt.Sprintf("iter_token_%s", s.randString(8)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionID: 1, + Schema: s.schema.ProtoMessage(), + }, nil).Once() + s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { + s.Equal(collectionName, sr.GetCollectionName()) + checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool { + for _, kv := range kvs { + if kv.GetKey() == key && kv.GetValue() == value { + return true + } + } + return false + } + + s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true")) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true")) + return &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 1, + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1}), + }, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1}, + }, + }, + }, + Scores: make([]float32, 1), + Topks: []int64{5}, + Recalls: []float32{1}, + SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{ + Token: token, + }, + }, + }, nil + }).Once() + + iter, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { + return rand.Float32() + }))).WithIteratorLimit(6).WithBatchSize(5)) + s.Require().NoError(err) + s.Require().NotNil(iter) + + s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, 
error) { + s.Equal(collectionName, sr.GetCollectionName()) + checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool { + for _, kv := range kvs { + if kv.GetKey() == key && kv.GetValue() == value { + return true + } + } + return false + } + + s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true")) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true")) + return &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 1, + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1, 2, 3, 4, 5}), + }, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5}, + }, + }, + }, + Scores: []float32{0.5, 0.4, 0.3, 0.2, 0.1}, + Topks: []int64{5}, + Recalls: []float32{1}, + SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{ + Token: token, + LastBound: 0.5, + }, + }, + }, nil + }).Times(2) + + rs, err := iter.Next(ctx) + s.NoError(err) + s.EqualValues(5, rs.IDs.Len(), "first batch, return all results") + + rs, err = iter.Next(ctx) + s.NoError(err) + s.EqualValues(1, rs.IDs.Len(), "second batch, return sliced results") + + _, err = iter.Next(ctx) + s.Error(err) + s.ErrorIs(err, io.EOF, "limit reached, return EOF") +} + +func TestSearchIterator(t *testing.T) { + suite.Run(t, new(SearchIteratorSuite)) +} diff --git a/client/milvusclient/results.go b/client/milvusclient/results.go index 0f16fecd59..e8e76798ad 100644 --- a/client/milvusclient/results.go +++ b/client/milvusclient/results.go @@ -21,6 +21,7 @@ import ( "runtime/debug" "github.com/cockroachdb/errors" + "github.com/samber/lo" "github.com/milvus-io/milvus/client/v2/column" "github.com/milvus-io/milvus/client/v2/entity" @@ -51,6 +52,31 @@ func (rs *ResultSet) GetColumn(fieldName string) column.Column { return nil } +func (rs ResultSet) Len() int { + return rs.ResultCount +} + +func (rs ResultSet) Slice(start, end int) 
ResultSet { + result := ResultSet{ + sch: rs.sch, + IDs: rs.IDs.Slice(start, end), + Fields: lo.Map(rs.Fields, func(column column.Column, _ int) column.Column { + return column.Slice(start, end) + }), + // Recall will not be sliced + Err: rs.Err, + } + + if rs.GroupByValue != nil { + result.GroupByValue = rs.GroupByValue.Slice(start, end) + } + + result.ResultCount = result.IDs.Len() + result.Scores = rs.Scores[start : start+result.ResultCount] + + return result +} + // Unmarshal puts dataset into receiver in row based way. // `receiver` shall be a slice of pointer of model struct // eg, []*Records, in which type `Record` defines the row data. diff --git a/tests/go_client/base/milvus_client.go b/tests/go_client/base/milvus_client.go index 48fee24cb2..36061b10b1 100644 --- a/tests/go_client/base/milvus_client.go +++ b/tests/go_client/base/milvus_client.go @@ -323,6 +323,11 @@ func (mc *MilvusClient) HybridSearch(ctx context.Context, option client.HybridSe return resultSets, err } +func (mc *MilvusClient) SearchIterator(ctx context.Context, option client.SearchIteratorOption, callOptions ...grpc.CallOption) (client.SearchIterator, error) { + searchIterator, err := mc.mClient.SearchIterator(ctx, option, callOptions...) + return searchIterator, err +} + // ListResourceGroups list all resource groups func (mc *MilvusClient) ListResourceGroups(ctx context.Context, option client.ListResourceGroupsOption, callOptions ...grpc.CallOption) ([]string, error) { resourceGroups, err := mc.mClient.ListResourceGroups(ctx, option, callOptions...) 
diff --git a/tests/go_client/common/consts.go b/tests/go_client/common/consts.go index 0da518ddf2..bd64b79a3b 100644 --- a/tests/go_client/common/consts.go +++ b/tests/go_client/common/consts.go @@ -71,6 +71,7 @@ const ( MaxTopK = 16384 MaxVectorFieldNum = 4 MaxShardNum = 16 + DefaultBatchSize = 1000 ) const ( diff --git a/tests/go_client/common/response_checker.go b/tests/go_client/common/response_checker.go index 501ea4cf03..d463ec63c9 100644 --- a/tests/go_client/common/response_checker.go +++ b/tests/go_client/common/response_checker.go @@ -1,7 +1,9 @@ package common import ( + "context" "fmt" + "io" "reflect" "strings" "testing" @@ -182,6 +184,63 @@ func CheckQueryResult(t *testing.T, expColumns []column.Column, actualColumns [] } } +type CheckIteratorOption func(opt *checkIteratorOpt) + +type checkIteratorOpt struct { + expBatchSize []int + expOutputFields []string +} + +func WithExpBatchSize(expBatchSize []int) CheckIteratorOption { + return func(opt *checkIteratorOpt) { + opt.expBatchSize = expBatchSize + } +} + +func WithExpOutputFields(expOutputFields []string) CheckIteratorOption { + return func(opt *checkIteratorOpt) { + opt.expOutputFields = expOutputFields + } +} + +// check queryIterator: result limit, each batch size, output fields +func CheckSearchIteratorResult(ctx context.Context, t *testing.T, itr client.SearchIterator, expLimit int, opts ...CheckIteratorOption) { + opt := &checkIteratorOpt{} + for _, o := range opts { + o(opt) + } + actualLimit := 0 + var actualBatchSize []int + for { + rs, err := itr.Next(ctx) + if err != nil { + if err == io.EOF { + break + } else { + log.Error("SearchIterator next gets error", zap.Error(err)) + break + } + } + + if opt.expBatchSize != nil { + actualBatchSize = append(actualBatchSize, rs.ResultCount) + } + var actualOutputFields []string + if opt.expOutputFields != nil { + for _, column := range rs.Fields { + actualOutputFields = append(actualOutputFields, column.Name()) + } + require.ElementsMatch(t, 
opt.expOutputFields, actualOutputFields) + } + actualLimit = actualLimit + rs.ResultCount + } + require.Equal(t, expLimit, actualLimit) + if opt.expBatchSize != nil { + log.Debug("SearchIterator result len", zap.Any("result len", actualBatchSize)) + require.True(t, EqualIntSlice(opt.expBatchSize, actualBatchSize)) + } +} + // GenColumnDataOption -- create column data -- type checkIndexOpt struct { state index.IndexState diff --git a/tests/go_client/common/utils.go b/tests/go_client/common/utils.go index df6957e5cf..dbbf114a4d 100644 --- a/tests/go_client/common/utils.go +++ b/tests/go_client/common/utils.go @@ -137,7 +137,7 @@ type InvalidExprStruct struct { } var InvalidExpressions = []InvalidExprStruct{ - {Expr: "id in [0]", ErrNil: true, ErrMsg: "fieldName(id) not found"}, // not exist field but no error + {Expr: "id in [0]", ErrNil: true, ErrMsg: "fieldName(id) not found"}, // not exist field but no error, because enable dynamic {Expr: "int64 in not [0]", ErrNil: false, ErrMsg: "cannot parse expression"}, // wrong term expr keyword {Expr: "int64 < floatVec", ErrNil: false, ErrMsg: "not supported"}, // unsupported compare field {Expr: "floatVec in [0]", ErrNil: false, ErrMsg: "cannot be casted to FloatVector"}, // value and field type mismatch @@ -218,3 +218,15 @@ func GenText(lang string) string { func IsZeroValue(value interface{}) bool { return reflect.DeepEqual(value, reflect.Zero(reflect.TypeOf(value)).Interface()) } + +func EqualIntSlice(a []int, b []int) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/tests/go_client/testcases/helper/read_helper.go b/tests/go_client/testcases/helper/read_helper.go index fcd66ca684..b88caffe84 100644 --- a/tests/go_client/testcases/helper/read_helper.go +++ b/tests/go_client/testcases/helper/read_helper.go @@ -2,6 +2,7 @@ package helper import ( "github.com/milvus-io/milvus/client/v2/entity" + 
// GenBatchSizes returns the batch sizes expected when limit entries are read
// in batches of at most batch entries: limit/batch full batches followed by
// one smaller batch holding the remainder, if there is one.
//
// A non-positive limit yields an empty, non-nil slice. A non-positive batch
// is a programming error and aborts through log.Fatal. The earlier guard only
// rejected batch == 0, which let negative batches through and made a negative
// limit panic on a negative slice capacity.
func GenBatchSizes(limit int, batch int) []int {
	if batch <= 0 {
		log.Fatal("Batch should be larger than 0")
	}
	if limit <= 0 {
		return []int{}
	}

	fullBatches := limit / batch
	remainder := limit % batch

	batchSizes := make([]int, 0, fullBatches+1)
	for i := 0; i < fullBatches; i++ {
		batchSizes = append(batchSizes, batch)
	}
	if remainder > 0 {
		batchSizes = append(batchSizes, remainder)
	}
	return batchSizes
}
mc, hp.NewLoadParams(schema.CollectionName)) + + vector := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + + // search iterator default + itr, err := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector)) + common.CheckErr(t, err, true) + actualLimit := 0 + for { + rs, err := itr.Next(ctx) + if err != nil { + if err == io.EOF { + break + } else { + log.Error("SearchIterator next gets error", zap.Error(err)) + break + } + } + actualLimit = actualLimit + rs.ResultCount + } + require.LessOrEqual(t, actualLimit, common.DefaultNb*2) + + // search iterator with limit + limit := 2000 + itr, err = mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithIteratorLimit(int64(limit))) + common.CheckErr(t, err, true) + common.CheckSearchIteratorResult(ctx, t, itr, limit, common.WithExpBatchSize(hp.GenBatchSizes(limit, common.DefaultBatchSize))) +} + +func TestSearchIteratorGrowing(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption(), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb*2)) + + // search iterator growing + vector := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + // wait limit support + limit := 1000 + itr, err := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithIteratorLimit(int64(limit)).WithBatchSize(100)) + common.CheckErr(t, err, true) + common.CheckSearchIteratorResult(ctx, t, itr, limit, 
common.WithExpBatchSize(hp.GenBatchSizes(limit, 100))) +} + +func TestSearchIteratorHitEmpty(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption(), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + // search + vector := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + itr, err := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector)) + common.CheckErr(t, err, true) + common.CheckSearchIteratorResult(ctx, t, itr, 0, common.WithExpBatchSize(hp.GenBatchSizes(0, common.DefaultBatchSize))) +} + +func TestSearchIteratorBatchSize(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption(), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + // search iterator with special limit: 0, -1, -2 + vector := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + itr, err := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithIteratorLimit(0)) + common.CheckErr(t, err, true) + common.CheckSearchIteratorResult(ctx, t, itr, 0, 
common.WithExpBatchSize(hp.GenBatchSizes(0, common.DefaultBatchSize))) + + for _, _limit := range []int64{-1, -2} { + itr, err = mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithIteratorLimit(_limit)) + common.CheckErr(t, err, true) + actualLimit := 0 + for { + rs, err := itr.Next(ctx) + if err != nil { + if err == io.EOF { + break + } + log.Error("SearchIterator next gets error", zap.Error(err)) + break + } + actualLimit = actualLimit + rs.ResultCount + require.LessOrEqual(t, rs.ResultCount, common.DefaultBatchSize) + } + require.LessOrEqual(t, actualLimit, common.DefaultNb) + } + + // search iterator + type batchStruct struct { + batch int + expBatchSize []int + } + limit := 201 + batchStructs := []batchStruct{ + {batch: limit / 2, expBatchSize: hp.GenBatchSizes(limit, limit/2)}, + {batch: limit, expBatchSize: hp.GenBatchSizes(limit, limit)}, + {batch: limit + 1, expBatchSize: hp.GenBatchSizes(limit, limit+1)}, + } + + for _, _batchStruct := range batchStructs { + vector := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + itr, err := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithIteratorLimit(int64(limit)).WithBatchSize(_batchStruct.batch)) + common.CheckErr(t, err, true) + common.CheckSearchIteratorResult(ctx, t, itr, limit, common.WithExpBatchSize(_batchStruct.expBatchSize)) + } +} + +func TestSearchIteratorOutputAllFields(t *testing.T) { + t.Parallel() + for _, dynamic := range [2]bool{false, true} { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.AllFields), hp.TNewFieldsOption(), + hp.TNewSchemaOption().TWithEnableDynamicField(dynamic), hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption()) + prepare.FlushData(ctx, t, mc, 
schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + var allFieldsName []string + for _, field := range schema.Fields { + allFieldsName = append(allFieldsName, field.Name) + } + if dynamic { + allFieldsName = append(allFieldsName, common.DefaultDynamicFieldName) + } + vector := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + itr, err := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithANNSField(common.DefaultFloatVecFieldName). + WithOutputFields("*").WithIteratorLimit(100).WithBatchSize(12)) + common.CheckErr(t, err, true) + common.CheckSearchIteratorResult(ctx, t, itr, 100, common.WithExpBatchSize(hp.GenBatchSizes(100, 12)), common.WithExpOutputFields(allFieldsName)) + } +} + +func TestQueryIteratorOutputSparseFieldsRows(t *testing.T) { + t.Parallel() + // connect + for _, withRows := range [2]bool{true, false} { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), hp.TNewFieldsOption(), + hp.TNewSchemaOption().TWithEnableDynamicField(true), hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema).TWithIsRows(withRows), hp.TNewDataOption()) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + fieldsName := []string{common.DefaultDynamicFieldName} + for _, field := range schema.Fields { + fieldsName = append(fieldsName, field.Name) + } + + // output * fields + vector := common.GenSparseVector(common.DefaultDim) + itr, err := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, 
vector).WithOutputFields("*").WithIteratorLimit(200).WithBatchSize(120)) + common.CheckErr(t, err, true) + common.CheckSearchIteratorResult(ctx, t, itr, 200, common.WithExpBatchSize(hp.GenBatchSizes(200, 120)), common.WithExpOutputFields(fieldsName)) + } +} + +func TestSearchIteratorInvalid(t *testing.T) { + nb := 201 + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption(), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(nb)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + // search iterator with not exist collection name + vector := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + _, err := mc.SearchIterator(ctx, client.NewSearchIteratorOption(common.GenRandomString("c", 5), vector)) + common.CheckErr(t, err, false, "collection not found") + + // search iterator with not exist partition name + _, err = mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithPartitions(common.GenRandomString("p", 5))) + common.CheckErr(t, err, false, "not found") + _, err = mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithPartitions(common.DefaultPartition, common.GenRandomString("p", 5))) + common.CheckErr(t, err, false, "not found") + + // search iterator with not exist vector field name + _, err = mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithANNSField(common.GenRandomString("f", 5))) + common.CheckErr(t, err, false, "failed to get field schema by name") + + // search iterator 
with count(*) + _, err = mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithOutputFields(common.QueryCountFieldName)) + common.CheckErr(t, err, false, "field count(*) not exist") + + // search iterator with invalid batch size + for _, batch := range []int{-1, 0, -2} { + _, err := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithBatchSize(batch)) + common.CheckErr(t, err, false, "batch size must be greater than 0") + } + + itr, err2 := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithBatchSize(common.MaxTopK+1)) + common.CheckErr(t, err2, true) + _, err2 = itr.Next(ctx) + common.CheckErr(t, err2, false, "batch size is invalid, it should be in range [1, 16384]") + + // search iterator with invalid offset + for _, offset := range []int{-2, -1, common.MaxTopK + 1} { + _, err := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithOffset(offset)) + common.CheckErr(t, err, false, "it should be in range [1, 16384]") + } +} + +func TestSearchIteratorWithInvalidExpr(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VecJSON), hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + vector := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + for _, _invalidExprs := range common.InvalidExpressions { + t.Log(_invalidExprs) + _, err := 
mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithFilter(_invalidExprs.Expr)) + common.CheckErr(t, err, _invalidExprs.ErrNil, _invalidExprs.ErrMsg, "") + } +} + +func TestSearchIteratorTemplateKey(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption(), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb*2)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + vector := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + + // search iterator default + value := 2000 + itr, err := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithIteratorLimit(100).WithBatchSize(10). 
+ WithFilter(fmt.Sprintf("%s < {key}", common.DefaultInt64FieldName)).WithTemplateParam("key", value)) + common.CheckErr(t, err, true) + actualLimit := 0 + for { + rs, err := itr.Next(ctx) + if err != nil { + if err == io.EOF { + break + } + log.Error("SearchIterator next gets error", zap.Error(err)) + break + } + actualLimit = actualLimit + rs.ResultCount + require.Equal(t, 10, rs.ResultCount) + + // check result ids < value + for _, id := range rs.IDs.(*column.ColumnInt64).Data() { + require.Less(t, id, int64(value)) + } + } + require.LessOrEqual(t, actualLimit, 100) +} + +func TestSearchIteratorGroupBy(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption(), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + vector := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + _, err := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithGroupByField(common.DefaultInt64FieldName). 
+ WithIteratorLimit(500).WithBatchSize(100)) + common.CheckErr(t, err, false, "Not allowed to do groupBy when doing iteration") +} + +func TestSearchIteratorIgnoreGrowing(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption(), + hp.TWithConsistencyLevel(entity.ClStrong)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb)) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + // growing pk [DefaultNb, DefaultNb*2] + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption().TWithNb(common.DefaultNb).TWithStart(common.DefaultNb)) + + // search iterator growing + vector := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + itr, err := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithIgnoreGrowing(true).WithIteratorLimit(100).WithBatchSize(10)) + common.CheckErr(t, err, true) + actualLimit := 0 + for { + rs, err := itr.Next(ctx) + if err != nil { + if err == io.EOF { + break + } + log.Error("SearchIterator next gets error", zap.Error(err)) + break + } + actualLimit = actualLimit + rs.ResultCount + for _, id := range rs.IDs.(*column.ColumnInt64).Data() { + require.Less(t, id, int64(common.DefaultNb)) + } + } + require.LessOrEqual(t, actualLimit, 100) +} + +func TestSearchIteratorNull(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + 
vecField := entity.NewField().WithName(common.DefaultFloatVecFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim) + int32NullField := entity.NewField().WithName(common.DefaultInt32FieldName).WithDataType(entity.FieldTypeInt32).WithNullable(true) + schema := entity.NewSchema().WithName(common.GenRandomString("null_int32", 10)).WithField(pkField).WithField(vecField).WithField(int32NullField) + errCreate := mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, errCreate, true) + + prepare := hp.CollPrepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + // Generate test data with boundary values + nb := common.DefaultNb * 3 + pkColumn := hp.GenColumnData(nb, entity.FieldTypeInt64, *hp.TNewDataOption()) + vecColumn := hp.GenColumnData(nb, entity.FieldTypeFloatVector, *hp.TNewDataOption()) + int32Values := make([]int32, 0, nb) + validData := make([]bool, 0, nb) + + // Generate JSON documents + for i := 0; i < nb; i++ { + _mod := i % 2 + if _mod == 0 { + validData = append(validData, false) + } else { + int32Values = append(int32Values, int32(i)) + validData = append(validData, true) + } + } + nullColumn, err := column.NewNullableColumnInt32(common.DefaultInt32FieldName, int32Values, validData) + common.CheckErr(t, err, true) + _, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName, pkColumn, vecColumn, nullColumn)) + common.CheckErr(t, err, true) + + // search iterator with null expr + expr := fmt.Sprintf("%s is null", common.DefaultInt32FieldName) + vector := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + itr, err2 := mc.SearchIterator(ctx, client.NewSearchIteratorOption(schema.CollectionName, vector).WithFilter(expr).WithIteratorLimit(100).WithBatchSize(10).WithOutputFields(common.DefaultInt32FieldName)) + common.CheckErr(t, err2, 
true) + actualLimit := 0 + for { + rs, err := itr.Next(ctx) + if err != nil { + if err == io.EOF { + break + } + log.Error("SearchIterator next gets error", zap.Error(err)) + break + } + actualLimit = actualLimit + rs.ResultCount + require.Equal(t, 10, rs.ResultCount) + for _, field := range rs.Fields { + if field.Name() == common.DefaultInt32FieldName { + for i := 0; i < field.Len(); i++ { + isNull, err := field.IsNull(i) + common.CheckErr(t, err, true) + require.True(t, isNull) + } + } + } + } + require.LessOrEqual(t, actualLimit, 100) +}