feat: [GoSDK] Support search iterator v2 (#43612)

Related to #37548

Also link #43122

This patch implements basic functions of search iterator v2.

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
congqixia 2025-07-29 11:21:35 +08:00 committed by GitHub
parent c9412434c8
commit 18d8dc82b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 608 additions and 0 deletions

View File

@ -0,0 +1,186 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package milvusclient
import (
"context"
"fmt"
"io"
"github.com/cockroachdb/errors"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
const (
// IteratorKey is the const search param key in indicating enabling iterator.
IteratorKey = "iterator"
IteratorSessionTsKey = "iterator_session_ts"
IteratorSearchV2Key = "search_iter_v2"
IteratorSearchBatchSizeKey = "search_iter_batch_size"
IteratorSearchLastBoundKey = "search_iter_last_bound"
IteratorSearchIDKey = "search_iter_id"
CollectionIDKey = `collection_id`
)
var ErrServerVersionIncompatible = errors.New("server version incompatible")
// SearchIterator is the interface for search iterator.
type SearchIterator interface {
// Next returns next batch of iterator
// when iterator reaches the end, return `io.EOF`.
Next(ctx context.Context) (ResultSet, error)
}
type searchIteratorV2 struct {
client *Client
option SearchIteratorOption
schema *entity.Schema
}
func (it *searchIteratorV2) Next(ctx context.Context) (ResultSet, error) {
return it.next(ctx)
}
func (it *searchIteratorV2) next(ctx context.Context) (ResultSet, error) {
opt := it.option.SearchOption()
req, err := opt.Request()
if err != nil {
return ResultSet{}, err
}
var rs ResultSet
err = it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Search(ctx, req)
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
iteratorInfo := resp.GetResults().GetSearchIteratorV2Results()
opt.annRequest.WithSearchParam(IteratorSearchIDKey, iteratorInfo.GetToken())
opt.annRequest.WithSearchParam(IteratorSearchLastBoundKey, fmt.Sprintf("%v", iteratorInfo.GetLastBound()))
resultSets, err := it.client.handleSearchResult(it.schema, req.GetOutputFields(), int(resp.GetResults().GetNumQueries()), resp)
if err != nil {
return err
}
rs = resultSets[0]
if rs.IDs.Len() == 0 {
return io.EOF
}
return nil
})
return rs, err
}
func (it *searchIteratorV2) setupCollectionID(ctx context.Context) error {
opt := it.option.SearchOption()
return it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{
CollectionName: opt.collectionName,
})
if merr.CheckRPCCall(resp, err) != nil {
return err
}
opt.WithSearchParam(CollectionIDKey, fmt.Sprintf("%d", resp.GetCollectionID()))
schema := &entity.Schema{}
it.schema = schema.ReadProto(resp.GetSchema())
return nil
})
}
// probeCompatiblity checks if the server support SearchIteratorV2.
// It checks if the search result contains search iterator v2 results info and token.
func (it *searchIteratorV2) probeCompatiblity(ctx context.Context) error {
opt := it.option.SearchOption()
opt.annRequest.topK = 1 // ok to leave it here, will be overwritten in next iteration
opt.annRequest.WithSearchParam(IteratorSearchBatchSizeKey, "1")
req, err := opt.Request()
if err != nil {
return err
}
return it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Search(ctx, req)
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
if resp.GetResults().GetSearchIteratorV2Results() == nil || resp.GetResults().GetSearchIteratorV2Results().GetToken() == "" {
return ErrServerVersionIncompatible
}
return nil
})
}
// newSearchIteratorV2 creates a new search iterator V2.
//
// It sets up the collection ID and checks if the server supports search iterator V2.
// If the server does not support search iterator V2, it returns an error.
func newSearchIteratorV2(ctx context.Context, client *Client, option SearchIteratorOption) (*searchIteratorV2, error) {
iter := &searchIteratorV2{
client: client,
option: option,
}
if err := iter.setupCollectionID(ctx); err != nil {
return nil, err
}
if err := iter.probeCompatiblity(ctx); err != nil {
return nil, err
}
return iter, nil
}
type searchIteratorV1 struct {
client *Client
}
func (s *searchIteratorV1) Next(_ context.Context) (ResultSet, error) {
return ResultSet{}, errors.New("not implemented")
}
func newSearchIteratorV1(_ *Client) (*searchIteratorV1, error) {
// search iterator v1 is not supported
return nil, ErrServerVersionIncompatible
}
// SearchIterator creates a search iterator from a collection.
//
// If the server supports search iterator V2, it creates a search iterator V2.
func (c *Client) SearchIterator(ctx context.Context, option SearchIteratorOption, callOptions ...grpc.CallOption) (SearchIterator, error) {
iter, err := newSearchIteratorV2(ctx, c, option)
if err == nil {
return iter, nil
}
if !errors.Is(err, ErrServerVersionIncompatible) {
return nil, err
}
return newSearchIteratorV1(c)
}

View File

@ -0,0 +1,119 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package milvusclient
import (
"fmt"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/index"
)
type SearchIteratorOption interface {
SearchOption() *searchOption
}
type searchIteratorOption struct {
*searchOption
batchSize int
}
func (opt *searchIteratorOption) SearchOption() *searchOption {
opt.annRequest.topK = opt.batchSize
opt.WithSearchParam(IteratorSearchBatchSizeKey, fmt.Sprintf("%d", opt.batchSize))
return opt.searchOption
}
func (opt *searchIteratorOption) WithBatchSize(batchSize int) *searchIteratorOption {
opt.batchSize = batchSize
return opt
}
func (opt *searchIteratorOption) WithPartitions(partitionNames ...string) *searchIteratorOption {
opt.partitionNames = partitionNames
return opt
}
func (opt *searchIteratorOption) WithFilter(expr string) *searchIteratorOption {
opt.annRequest.WithFilter(expr)
return opt
}
func (opt *searchIteratorOption) WithTemplateParam(key string, val any) *searchIteratorOption {
opt.annRequest.WithTemplateParam(key, val)
return opt
}
func (opt *searchIteratorOption) WithOffset(offset int) *searchIteratorOption {
opt.annRequest.WithOffset(offset)
return opt
}
func (opt *searchIteratorOption) WithOutputFields(fieldNames ...string) *searchIteratorOption {
opt.outputFields = fieldNames
return opt
}
func (opt *searchIteratorOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *searchIteratorOption {
opt.consistencyLevel = consistencyLevel
opt.useDefaultConsistencyLevel = false
return opt
}
func (opt *searchIteratorOption) WithANNSField(annsField string) *searchIteratorOption {
opt.annRequest.WithANNSField(annsField)
return opt
}
func (opt *searchIteratorOption) WithGroupByField(groupByField string) *searchIteratorOption {
opt.annRequest.WithGroupByField(groupByField)
return opt
}
func (opt *searchIteratorOption) WithGroupSize(groupSize int) *searchIteratorOption {
opt.annRequest.WithGroupSize(groupSize)
return opt
}
func (opt *searchIteratorOption) WithStrictGroupSize(strictGroupSize bool) *searchIteratorOption {
opt.annRequest.WithStrictGroupSize(strictGroupSize)
return opt
}
func (opt *searchIteratorOption) WithIgnoreGrowing(ignoreGrowing bool) *searchIteratorOption {
opt.annRequest.WithIgnoreGrowing(ignoreGrowing)
return opt
}
func (opt *searchIteratorOption) WithAnnParam(ap index.AnnParam) *searchIteratorOption {
opt.annRequest.WithAnnParam(ap)
return opt
}
func (opt *searchIteratorOption) WithSearchParam(key, value string) *searchIteratorOption {
opt.annRequest.WithSearchParam(key, value)
return opt
}
func NewSearchIteratorOption(collectionName string, vector entity.Vector) *searchIteratorOption {
return &searchIteratorOption{
searchOption: NewSearchOption(collectionName, 100, []entity.Vector{vector}).
WithSearchParam(IteratorKey, "true").
WithSearchParam(IteratorSearchV2Key, "true"),
batchSize: 1000,
}
}

View File

@ -0,0 +1,303 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package milvusclient
import (
"context"
"fmt"
"io"
"math/rand"
"testing"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
type SearchIteratorSuite struct {
MockSuiteBase
schema *entity.Schema
}
func (s *SearchIteratorSuite) SetupSuite() {
s.MockSuiteBase.SetupSuite()
s.schema = entity.NewSchema().
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
}
func (s *SearchIteratorSuite) TestSearchIteratorInit() {
ctx := context.Background()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
CollectionID: 1,
Schema: s.schema.ProtoMessage(),
}, nil).Once()
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
s.Equal(collectionName, sr.GetCollectionName())
checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool {
for _, kv := range kvs {
if kv.GetKey() == key && kv.GetValue() == value {
return true
}
}
return false
}
s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true"))
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true"))
return &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: 1,
TopK: 1,
FieldsData: []*schemapb.FieldData{
s.getInt64FieldData("ID", []int64{1}),
},
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1},
},
},
},
Scores: make([]float32, 1),
Topks: []int64{1},
Recalls: []float32{1},
SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{
Token: s.randString(16),
},
},
}, nil
}).Once()
iter, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
}))))
s.NoError(err)
_, ok := iter.(*searchIteratorV2)
s.True(ok)
})
s.Run("failure", func() {
s.Run("describe_fail", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("mock error")).Once()
_, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
}))))
s.Error(err)
})
s.Run("not_v2_result", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
CollectionID: 1,
Schema: s.schema.ProtoMessage(),
}, nil).Once()
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
s.Equal(collectionName, sr.GetCollectionName())
return &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: 1,
TopK: 1,
FieldsData: []*schemapb.FieldData{
s.getInt64FieldData("ID", []int64{1}),
},
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1},
},
},
},
Scores: make([]float32, 1),
Topks: []int64{1},
Recalls: []float32{1},
SearchIteratorV2Results: nil, // nil v2 results
},
}, nil
}).Once()
_, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
}))))
s.Error(err)
s.ErrorIs(err, ErrServerVersionIncompatible)
})
})
}
func (s *SearchIteratorSuite) TestNext() {
ctx := context.Background()
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
token := fmt.Sprintf("iter_token_%s", s.randString(8))
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
CollectionID: 1,
Schema: s.schema.ProtoMessage(),
}, nil).Once()
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
s.Equal(collectionName, sr.GetCollectionName())
checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool {
for _, kv := range kvs {
if kv.GetKey() == key && kv.GetValue() == value {
return true
}
}
return false
}
s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true"))
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true"))
return &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: 1,
TopK: 1,
FieldsData: []*schemapb.FieldData{
s.getInt64FieldData("ID", []int64{1}),
},
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1},
},
},
},
Scores: make([]float32, 1),
Topks: []int64{1},
Recalls: []float32{1},
SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{
Token: token,
},
},
}, nil
}).Once()
iter, err := s.client.SearchIterator(ctx, NewSearchIteratorOption(collectionName, entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
}))))
s.Require().NoError(err)
s.Require().NotNil(iter)
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
s.Equal(collectionName, sr.GetCollectionName())
checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool {
for _, kv := range kvs {
if kv.GetKey() == key && kv.GetValue() == value {
return true
}
}
return false
}
s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true"))
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true"))
return &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: 1,
TopK: 1,
FieldsData: []*schemapb.FieldData{
s.getInt64FieldData("ID", []int64{1}),
},
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1},
},
},
},
Scores: []float32{0.5},
Topks: []int64{1},
Recalls: []float32{1},
SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{
Token: token,
LastBound: 0.5,
},
},
}, nil
}).Once()
rs, err := iter.Next(ctx)
s.NoError(err)
s.EqualValues(1, rs.IDs.Len())
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
s.Equal(collectionName, sr.GetCollectionName())
checkSearchParam := func(kvs []*commonpb.KeyValuePair, key string, value string) bool {
for _, kv := range kvs {
if kv.GetKey() == key && kv.GetValue() == value {
return true
}
}
return false
}
s.True(checkSearchParam(sr.GetSearchParams(), IteratorKey, "true"))
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchV2Key, "true"))
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchIDKey, token))
s.True(checkSearchParam(sr.GetSearchParams(), IteratorSearchLastBoundKey, "0.5"))
return &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: 1,
TopK: 1,
FieldsData: []*schemapb.FieldData{
s.getInt64FieldData("ID", []int64{}),
},
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{},
},
},
},
Scores: []float32{},
Topks: []int64{0},
Recalls: []float32{1.0},
SearchIteratorV2Results: &schemapb.SearchIteratorV2Results{
Token: token,
LastBound: 0.5,
},
},
}, nil
}).Once()
_, err = iter.Next(ctx)
s.Error(err)
s.ErrorIs(err, io.EOF)
}
func TestSearchIterator(t *testing.T) {
suite.Run(t, new(SearchIteratorSuite))
}