From 18d8dc82b8d6057b0e84bce31057d39f48a425ea Mon Sep 17 00:00:00 2001 From: congqixia Date: Tue, 29 Jul 2025 11:21:35 +0800 Subject: [PATCH] feat: [GoSDK] Support search iterator v2 (#43612) Related to #37548 Also link #43122 This patch implements basic functions of search iterator v2. --------- Signed-off-by: Congqi Xia --- client/milvusclient/iterator.go | 186 +++++++++++++++ client/milvusclient/iterator_option.go | 119 ++++++++++ client/milvusclient/iterator_test.go | 303 +++++++++++++++++++++++++ 3 files changed, 608 insertions(+) create mode 100644 client/milvusclient/iterator.go create mode 100644 client/milvusclient/iterator_option.go create mode 100644 client/milvusclient/iterator_test.go diff --git a/client/milvusclient/iterator.go b/client/milvusclient/iterator.go new file mode 100644 index 0000000000..adee5cf28d --- /dev/null +++ b/client/milvusclient/iterator.go @@ -0,0 +1,186 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package milvusclient + +import ( + "context" + "fmt" + "io" + + "github.com/cockroachdb/errors" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/client/v2/entity" + "github.com/milvus-io/milvus/pkg/v2/util/merr" +) + +const ( + // IteratorKey is the const search param key in indicating enabling iterator. + IteratorKey = "iterator" + IteratorSessionTsKey = "iterator_session_ts" + IteratorSearchV2Key = "search_iter_v2" + IteratorSearchBatchSizeKey = "search_iter_batch_size" + IteratorSearchLastBoundKey = "search_iter_last_bound" + IteratorSearchIDKey = "search_iter_id" + CollectionIDKey = `collection_id` +) + +var ErrServerVersionIncompatible = errors.New("server version incompatible") + +// SearchIterator is the interface for search iterator. +type SearchIterator interface { + // Next returns next batch of iterator + // when iterator reaches the end, return `io.EOF`. + Next(ctx context.Context) (ResultSet, error) +} + +type searchIteratorV2 struct { + client *Client + option SearchIteratorOption + schema *entity.Schema +} + +func (it *searchIteratorV2) Next(ctx context.Context) (ResultSet, error) { + return it.next(ctx) +} + +func (it *searchIteratorV2) next(ctx context.Context) (ResultSet, error) { + opt := it.option.SearchOption() + req, err := opt.Request() + if err != nil { + return ResultSet{}, err + } + + var rs ResultSet + + err = it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.Search(ctx, req) + err = merr.CheckRPCCall(resp, err) + if err != nil { + return err + } + + iteratorInfo := resp.GetResults().GetSearchIteratorV2Results() + opt.annRequest.WithSearchParam(IteratorSearchIDKey, iteratorInfo.GetToken()) + opt.annRequest.WithSearchParam(IteratorSearchLastBoundKey, fmt.Sprintf("%v", iteratorInfo.GetLastBound())) + + resultSets, err := it.client.handleSearchResult(it.schema, req.GetOutputFields(), int(resp.GetResults().GetNumQueries()), resp) + if err != nil { + return err + } + rs = resultSets[0] + + if rs.IDs.Len() == 0 { + return io.EOF + } + + return nil + }) + return rs, err +} + +func (it *searchIteratorV2) setupCollectionID(ctx context.Context) error { + opt := it.option.SearchOption() + + return it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{ + CollectionName: opt.collectionName, + }) + if merr.CheckRPCCall(resp, err) != nil { + return err + } + + opt.WithSearchParam(CollectionIDKey, fmt.Sprintf("%d", resp.GetCollectionID())) + schema := &entity.Schema{} + it.schema = schema.ReadProto(resp.GetSchema()) + return nil + }) +} + +// probeCompatiblity checks if the server support SearchIteratorV2. +// It checks if the search result contains search iterator v2 results info and token. +func (it *searchIteratorV2) probeCompatiblity(ctx context.Context) error { + opt := it.option.SearchOption() + opt.annRequest.topK = 1 // ok to leave it here, will be overwritten in next iteration + opt.annRequest.WithSearchParam(IteratorSearchBatchSizeKey, "1") + req, err := opt.Request() + if err != nil { + return err + } + return it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.Search(ctx, req) + err = merr.CheckRPCCall(resp, err) + if err != nil { + return err + } + + if resp.GetResults().GetSearchIteratorV2Results() == nil || resp.GetResults().GetSearchIteratorV2Results().GetToken() == "" { + return ErrServerVersionIncompatible + } + return nil + }) +} + +// newSearchIteratorV2 creates a new search iterator V2. +// +// It sets up the collection ID and checks if the server supports search iterator V2. +// If the server does not support search iterator V2, it returns an error. +func newSearchIteratorV2(ctx context.Context, client *Client, option SearchIteratorOption) (*searchIteratorV2, error) { + iter := &searchIteratorV2{ + client: client, + option: option, + } + if err := iter.setupCollectionID(ctx); err != nil { + return nil, err + } + + if err := iter.probeCompatiblity(ctx); err != nil { + return nil, err + } + + return iter, nil +} + +type searchIteratorV1 struct { + client *Client +} + +func (s *searchIteratorV1) Next(_ context.Context) (ResultSet, error) { + return ResultSet{}, errors.New("not implemented") +} + +func newSearchIteratorV1(_ *Client) (*searchIteratorV1, error) { + // search iterator v1 is not supported + return nil, ErrServerVersionIncompatible +} + +// SearchIterator creates a search iterator from a collection. +// +// If the server supports search iterator V2, it creates a search iterator V2. +func (c *Client) SearchIterator(ctx context.Context, option SearchIteratorOption, callOptions ...grpc.CallOption) (SearchIterator, error) { + iter, err := newSearchIteratorV2(ctx, c, option) + if err == nil { + return iter, nil + } + + if !errors.Is(err, ErrServerVersionIncompatible) { + return nil, err + } + + return newSearchIteratorV1(c) +} diff --git a/client/milvusclient/iterator_option.go b/client/milvusclient/iterator_option.go new file mode 100644 index 0000000000..e6ed4e7b02 --- /dev/null +++ b/client/milvusclient/iterator_option.go @@ -0,0 +1,119 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package milvusclient + +import ( + "fmt" + + "github.com/milvus-io/milvus/client/v2/entity" + "github.com/milvus-io/milvus/client/v2/index" +) + +type SearchIteratorOption interface { + SearchOption() *searchOption +} + +type searchIteratorOption struct { + *searchOption + batchSize int +} + +func (opt *searchIteratorOption) SearchOption() *searchOption { + opt.annRequest.topK = opt.batchSize + opt.WithSearchParam(IteratorSearchBatchSizeKey, fmt.Sprintf("%d", opt.batchSize)) + return opt.searchOption +} + +func (opt *searchIteratorOption) WithBatchSize(batchSize int) *searchIteratorOption { + opt.batchSize = batchSize + return opt +} + +func (opt *searchIteratorOption) WithPartitions(partitionNames ...string) *searchIteratorOption { + opt.partitionNames = partitionNames + return opt +} + +func (opt *searchIteratorOption) WithFilter(expr string) *searchIteratorOption { + opt.annRequest.WithFilter(expr) + return opt +} + +func (opt *searchIteratorOption) WithTemplateParam(key string, val any) *searchIteratorOption { + opt.annRequest.WithTemplateParam(key, val) + return opt +} + +func (opt *searchIteratorOption) WithOffset(offset int) *searchIteratorOption { + opt.annRequest.WithOffset(offset) + return opt +} + +func (opt *searchIteratorOption) WithOutputFields(fieldNames ...string) *searchIteratorOption { + opt.outputFields = fieldNames + return opt +} + +func (opt *searchIteratorOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *searchIteratorOption { + opt.consistencyLevel = consistencyLevel + opt.useDefaultConsistencyLevel = false + return opt +} + +func (opt *searchIteratorOption) WithANNSField(annsField string) *searchIteratorOption { + opt.annRequest.WithANNSField(annsField) + return opt +} + +func (opt *searchIteratorOption) WithGroupByField(groupByField string) *searchIteratorOption { + opt.annRequest.WithGroupByField(groupByField) + return opt +} + +func (opt *searchIteratorOption) WithGroupSize(groupSize int) *searchIteratorOption { + opt.annRequest.WithGroupSize(groupSize) + return opt +} + +func (opt *searchIteratorOption) WithStrictGroupSize(strictGroupSize bool) *searchIteratorOption { + opt.annRequest.WithStrictGroupSize(strictGroupSize) + return opt +} + +func (opt *searchIteratorOption) WithIgnoreGrowing(ignoreGrowing bool) *searchIteratorOption { + opt.annRequest.WithIgnoreGrowing(ignoreGrowing) + return opt +} + +func (opt *searchIteratorOption) WithAnnParam(ap index.AnnParam) *searchIteratorOption { + opt.annRequest.WithAnnParam(ap) + return opt +} + +func (opt *searchIteratorOption) WithSearchParam(key, value string) *searchIteratorOption { + opt.annRequest.WithSearchParam(key, value) + return opt +} + +func NewSearchIteratorOption(collectionName string, vector entity.Vector) *searchIteratorOption { + return &searchIteratorOption{ + searchOption: NewSearchOption(collectionName, 100, []entity.Vector{vector}). + WithSearchParam(IteratorKey, "true"). + WithSearchParam(IteratorSearchV2Key, "true"), + batchSize: 1000, + } +} diff --git a/client/milvusclient/iterator_test.go b/client/milvusclient/iterator_test.go new file mode 100644 index 0000000000..2a11783e87 --- /dev/null +++ b/client/milvusclient/iterator_test.go @@ -0,0 +1,303 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package milvusclient + +import ( + "context" + "fmt" + "io" + "math/rand" + "testing" + + "github.com/samber/lo" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/client/v2/entity" + "github.com/milvus-io/milvus/pkg/v2/util/merr" +) + +type SearchIteratorSuite struct { + MockSuiteBase + + schema *entity.Schema +} + +func (s *SearchIteratorSuite) SetupSuite() { + s.MockSuiteBase.SetupSuite() + s.schema = entity.NewSchema(). + WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)). + WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)) +} + +func (s *SearchIteratorSuite) TestSearchIteratorInit() { + ctx := context.Background() + s.Run("success", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionID: 1, + Schema: s.schema.ProtoMessage(), + }, nil).Once() + s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { + s.Equal(collectionName, sr.GetCollectionName()) + checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool { + for _, kv := range kvs { + if kv.GetKey() == key && kv.GetValue() == value { + return true + } + } + return false + } + + s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true")) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true")) + return &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 1, + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1}), + }, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1}, + }, + }, + }, + Scores: make([]float32, 1), + Topks: []int64{1}, + Recalls: []float32{1}, + SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{ + Token: s.randString(16), + }, + }, + }, nil + }).Once() + + iter, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { + return rand.Float32() + })))) + + s.NoError(err) + _, ok := iter.(*searchIteratorV2) + s.True(ok) + }) + + s.Run("failure", func() { + s.Run("describe_fail", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("mock error")).Once() + _, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { + return rand.Float32() + })))) + s.Error(err) + }) + + s.Run("not_v2_result", func() { + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionID: 1, + Schema: s.schema.ProtoMessage(), + }, nil).Once() + s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { + s.Equal(collectionName, sr.GetCollectionName()) + return &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 1, + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1}), + }, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1}, + }, + }, + }, + Scores: make([]float32, 1), + Topks: []int64{1}, + Recalls: []float32{1}, + SearchIteratorV2Results: nil, // nil v2 results + }, + }, nil + }).Once() + + _, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { + return rand.Float32() + })))) + s.Error(err) + s.ErrorIs(err, ErrServerVersionIncompatible) + }) + }) +} + +func (s *SearchIteratorSuite) TestNext() { + ctx := context.Background() + collectionName := fmt.Sprintf("coll_%s", s.randString(6)) + + token := fmt.Sprintf("iter_token_%s", s.randString(8)) + + s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionID: 1, + Schema: s.schema.ProtoMessage(), + }, nil).Once() + s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { + s.Equal(collectionName, sr.GetCollectionName()) + checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool { + for _, kv := range kvs { + if kv.GetKey() == key && kv.GetValue() == value { + return true + } + } + return false + } + + s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true")) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true")) + return &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 1, + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1}), + }, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1}, + }, + }, + }, + Scores: make([]float32, 1), + Topks: []int64{1}, + Recalls: []float32{1}, + SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{ + Token: token, + }, + }, + }, nil + }).Once() + + iter, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { + return rand.Float32() + })))) + s.Require().NoError(err) + s.Require().NotNil(iter) + + s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { + s.Equal(collectionName, sr.GetCollectionName()) + checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool { + for _, kv := range kvs { + if kv.GetKey() == key && kv.GetValue() == value { + return true + } + } + return false + } + + s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true")) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true")) + return &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 1, + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{1}), + }, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1}, + }, + }, + }, + Scores: []float32{0.5}, + Topks: []int64{1}, + Recalls: []float32{1}, + SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{ + Token: token, + LastBound: 0.5, + }, + }, + }, nil + }).Once() + + rs, err := iter.Next(ctx) + s.NoError(err) + s.EqualValues(1, rs.IDs.Len()) + + s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { + s.Equal(collectionName, sr.GetCollectionName()) + checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool { + for _, kv := range kvs { + if kv.GetKey() == key && kv.GetValue() == value { + return true + } + } + return false + } + + s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true")) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true")) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchIDKey, token)) + s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchLastBoundKey, "0.5")) + + return &milvuspb.SearchResults{ + Status: merr.Success(), + Results: &schemapb.SearchResultData{ + NumQueries: 1, + TopK: 1, + FieldsData: []*schemapb.FieldData{ + s.getInt64FieldData("ID", []int64{}), + }, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{}, + }, + }, + }, + Scores: []float32{}, + Topks: []int64{0}, + Recalls: []float32{1.0}, + SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{ + Token: token, + LastBound: 0.5, + }, + }, + }, nil + }).Once() + + _, err = iter.Next(ctx) + s.Error(err) + s.ErrorIs(err, io.EOF) +} + +func TestSearchIterator(t *testing.T) { + suite.Run(t, new(SearchIteratorSuite)) +}