milvus/client/milvusclient/iterator.go
congqixia 6c8e11da4f
feat: [GoSDK] add QueryIterator support for Go client (#46633)
Related to #31293

Implement QueryIterator for the Go SDK to enable efficient iteration
over large query result sets using PK-based pagination.

Key changes:
- Add QueryIterator interface and implementation with PK-based
pagination
- Support Int64 and VarChar primary key types for pagination filtering
- Add QueryIteratorOption with batchSize, limit, filter, outputFields
config
- Fix ResultSet.Slice to handle Query results without IDs/Scores
- Add comprehensive unit tests and integration tests

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
- Core invariant: the iterator requires the collection primary key (PK)
to be present in outputFields so PK-based pagination and accurate row
counting work. The constructor enforces this by appending the PK to
outputFields when absent, and all pagination (lastPK tracking, PK-range
filters) and ResultCount calculations depend on that guaranteed PK
column.

- New capability: adds a public QueryIterator API (Client.QueryIterator,
QueryIterator interface, QueryIteratorOption) that issues server-side
Query RPCs in configurable batches and implements PK-based pagination
supporting Int64 and VarChar PKs, with options for batchSize, limit,
filter, outputFields and an upfront first-batch validation to fail fast
on invalid params.

- Removed/simplified logic: ResultSet.Slice no longer assumes IDs and
Scores are always present — it branches on presence of IDs (use IDs
length when non-nil; otherwise derive row count from Fields[0]) and
guards Scores slicing. This eliminates redundant/unsafe assumptions and
centralizes correct row-count logic based on actual returned fields.

- No data loss or behavior regression: pagination composes the user
filter with a PK-range filter and always requests the PK field, so
lastPK is extracted from a real column and fetchNextBatch only advances
when rows are returned; EOF is returned only when the server returns no
rows or iterator limit is reached. ResultSet.Slice guards prevent panics
for queries that lack IDs/Scores; Query RPC → ResultSet.Fields remains
the authoritative path for row data, so rows are not dropped and
existing query behavior is preserved.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
2025-12-27 01:43:20 +08:00

445 lines
12 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package milvusclient
import (
"context"
"fmt"
"io"
"strconv"
"strings"
"github.com/cockroachdb/errors"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
// Parameter keys and sentinel values shared by the iterators in this file.
const (
// IteratorKey is the search param key indicating that iterator mode is enabled.
// NOTE(review): not referenced in this file — confirm usage at the call sites.
IteratorKey = "iterator"
// IteratorSessionTsKey is not referenced in this file; presumably the param key
// for an iterator session timestamp — TODO confirm.
IteratorSessionTsKey = "iterator_session_ts"
// IteratorSearchV2Key is not referenced in this file; presumably the param key
// that selects the search iterator v2 protocol — TODO confirm.
IteratorSearchV2Key = "search_iter_v2"
// IteratorSearchBatchSizeKey is the search param key carrying the iterator
// batch size (set to "1" by the compatibility probe).
IteratorSearchBatchSizeKey = "search_iter_batch_size"
// IteratorSearchLastBoundKey is the search param key that carries the last
// bound reported by the previous search iterator v2 response.
IteratorSearchLastBoundKey = "search_iter_last_bound"
// IteratorSearchIDKey is the search param key that carries the token reported
// by the previous search iterator v2 response.
IteratorSearchIDKey = "search_iter_id"
// CollectionIDKey is the search param key that carries the collection ID
// obtained from DescribeCollection.
CollectionIDKey = `collection_id`
// Unlimited is the limit value meaning "no cap on the total number of rows":
// iterators skip all limit bookkeeping while their limit equals it.
Unlimited int64 = -1
)
// ErrServerVersionIncompatible is returned when the server cannot serve the
// iterator protocol required by this client, e.g. when a probe search response
// carries no search iterator v2 info or an empty token.
var ErrServerVersionIncompatible = errors.New("server version incompatible")
// SearchIterator is the interface for search iterator.
// It hands out search results one batch at a time.
type SearchIterator interface {
// Next returns the next batch of the iterator.
// When the iterator reaches the end (no more results, or the configured
// limit is exhausted), it returns `io.EOF`.
Next(ctx context.Context) (ResultSet, error)
}
// searchIteratorV2 is the SearchIterator implementation built on the server
// side search iterator v2 results (token + last bound).
type searchIteratorV2 struct {
// client issues the underlying RPCs.
client *Client
// option is the user supplied option; its SearchOption() is re-read on every call.
option SearchIteratorOption
// schema is the collection schema, populated by setupCollectionID.
schema *entity.Schema
// limit is the number of rows still allowed to be returned;
// Unlimited disables the cap, 0 means the iterator is exhausted.
limit int64
}
// Next returns the next batch of search results.
//
// It returns io.EOF once the overall limit has been used up or when the
// underlying search yields no more rows. When a finite limit is configured the
// batch is truncated so the total number of returned rows never exceeds it.
func (it *searchIteratorV2) Next(ctx context.Context) (ResultSet, error) {
	// nothing left of the overall budget
	if it.limit == 0 {
		return ResultSet{}, io.EOF
	}

	batch, err := it.next(ctx)
	switch {
	case err != nil:
		return batch, err
	case it.limit == Unlimited:
		// no bookkeeping needed without a cap
		return batch, nil
	}

	// trim the batch to the remaining budget, then charge it
	if remaining := it.limit; int64(batch.Len()) > remaining {
		batch = batch.Slice(0, int(remaining))
	}
	it.limit -= int64(batch.Len())
	return batch, nil
}
// next issues one Search RPC and returns its first result set.
//
// After a successful call the token and last bound reported by the server are
// written back into the search params so the following call continues where
// this one stopped. It returns io.EOF when the response carries no result set
// or no IDs.
func (it *searchIteratorV2) next(ctx context.Context) (ResultSet, error) {
	opt := it.option.SearchOption()
	req, err := opt.Request()
	if err != nil {
		return ResultSet{}, err
	}

	var rs ResultSet
	err = it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
		resp, err := milvusService.Search(ctx, req)
		if err = merr.CheckRPCCall(resp, err); err != nil {
			return err
		}

		// carry the iteration cursor over to the next request
		iteratorInfo := resp.GetResults().GetSearchIteratorV2Results()
		opt.annRequest.WithSearchParam(IteratorSearchIDKey, iteratorInfo.GetToken())
		opt.annRequest.WithSearchParam(IteratorSearchLastBoundKey, fmt.Sprintf("%v", iteratorInfo.GetLastBound()))

		resultSets, err := it.client.handleSearchResult(it.schema, req.GetOutputFields(), int(resp.GetResults().GetNumQueries()), resp)
		if err != nil {
			return err
		}
		// guard against an empty result list instead of panicking on index 0
		if len(resultSets) == 0 {
			return io.EOF
		}

		rs = resultSets[0]
		// a missing or empty ID column means there is nothing left to iterate
		if rs.IDs == nil || rs.IDs.Len() == 0 {
			return io.EOF
		}
		return nil
	})
	return rs, err
}
// setupCollectionID describes the target collection, stores its ID in the
// search params under CollectionIDKey and caches the collection schema on the
// iterator.
//
// Any failure of the DescribeCollection call — transport error or non-success
// status in the response — is returned to the caller.
func (it *searchIteratorV2) setupCollectionID(ctx context.Context) error {
	opt := it.option.SearchOption()

	return it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
		resp, err := milvusService.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{
			CollectionName: opt.collectionName,
		})
		// keep the checked error: the raw RPC err is nil when only the
		// response status signals a failure
		if err = merr.CheckRPCCall(resp, err); err != nil {
			return err
		}

		opt.WithSearchParam(CollectionIDKey, fmt.Sprintf("%d", resp.GetCollectionID()))
		schema := &entity.Schema{}
		it.schema = schema.ReadProto(resp.GetSchema())
		return nil
	})
}
// probeCompatiblity checks whether the server supports SearchIteratorV2.
//
// It sends a minimal search (topK and batch size of 1) and verifies that the
// response contains search iterator v2 info with a non-empty token; otherwise
// ErrServerVersionIncompatible is returned.
func (it *searchIteratorV2) probeCompatiblity(ctx context.Context) error {
	opt := it.option.SearchOption()
	// ok to leave it here, will be overwritten in next iteration
	opt.annRequest.topK = 1
	opt.annRequest.WithSearchParam(IteratorSearchBatchSizeKey, "1")

	req, err := opt.Request()
	if err != nil {
		return err
	}

	probe := func(svc milvuspb.MilvusServiceClient) error {
		resp, rpcErr := svc.Search(ctx, req)
		if checked := merr.CheckRPCCall(resp, rpcErr); checked != nil {
			return checked
		}

		info := resp.GetResults().GetSearchIteratorV2Results()
		if info == nil || info.GetToken() == "" {
			return ErrServerVersionIncompatible
		}
		return nil
	}
	return it.client.callService(probe)
}
// newSearchIteratorV2 builds a search iterator V2 for the given option.
//
// Construction runs two steps in order: resolving the collection ID/schema and
// probing the server for search iterator V2 support. The first failing step
// aborts construction and its error is returned.
func newSearchIteratorV2(ctx context.Context, client *Client, option SearchIteratorOption) (*searchIteratorV2, error) {
	iter := &searchIteratorV2{client: client, option: option, limit: option.Limit()}

	steps := []func(context.Context) error{
		iter.setupCollectionID,
		iter.probeCompatiblity,
	}
	for _, step := range steps {
		if err := step(ctx); err != nil {
			return nil, err
		}
	}
	return iter, nil
}
// searchIteratorV1 is a placeholder for the legacy search iterator; its Next
// method is not implemented.
type searchIteratorV1 struct {
// client is the client the iterator would use for RPCs.
client *Client
}
// Next always fails: the V1 search iterator has no implementation.
func (s *searchIteratorV1) Next(_ context.Context) (ResultSet, error) {
	var empty ResultSet
	return empty, errors.New("not implemented")
}
// newSearchIteratorV1 never succeeds: search iterator v1 is not supported, so
// it always reports ErrServerVersionIncompatible together with a nil iterator.
func newSearchIteratorV1(_ *Client) (*searchIteratorV1, error) {
	var unsupported *searchIteratorV1 // intentionally left nil
	return unsupported, ErrServerVersionIncompatible
}
// SearchIterator creates a search iterator from a collection.
//
// The option is validated first. If the server supports search iterator V2, a
// V2 iterator is returned; if the server is reported as incompatible, the V1
// fallback is attempted. Any other construction error is returned as is.
//
// On failure the returned SearchIterator is a true nil interface value.
// NOTE(review): callOptions is currently not forwarded to the underlying RPCs.
func (c *Client) SearchIterator(ctx context.Context, option SearchIteratorOption, callOptions ...grpc.CallOption) (SearchIterator, error) {
	if err := option.ValidateParams(); err != nil {
		return nil, err
	}

	v2, err := newSearchIteratorV2(ctx, c, option)
	if err == nil {
		return v2, nil
	}
	if !errors.Is(err, ErrServerVersionIncompatible) {
		return nil, err
	}

	// Do not return the (*searchIteratorV1, error) pair directly: a nil
	// pointer wrapped into the interface would compare as non-nil.
	v1, err := newSearchIteratorV1(c)
	if err != nil {
		return nil, err
	}
	return v1, nil
}
// QueryIterator is the interface for query iterator.
// It hands out query results one batch at a time.
type QueryIterator interface {
// Next returns the next batch of the iterator.
// When the iterator reaches the end (no more rows, or the configured
// limit is exhausted), it returns `io.EOF`.
Next(ctx context.Context) (ResultSet, error)
}
// queryIterator is the QueryIterator implementation. It pages through query
// results by repeatedly issuing Query RPCs filtered on "pk > last returned pk".
type queryIterator struct {
// client issues the underlying RPCs.
client *Client
// option is the user supplied option; a fresh request is built from it per fetch.
option QueryIteratorOption
// schema is the collection schema used to parse the returned columns.
schema *entity.Schema
// pagination state
expr string // base filter expression taken from the option's request
outputFields []string // output fields sent with every request (PK appended when fields were given without it)
pkField *entity.Field // primary key field of the collection
lastPK any // PK of the last row handed out; nil until a batch has been returned
batchSize int // per-request row limit and maximum rows returned by one Next call
limit int64 // rows still allowed to be returned; Unlimited disables the cap
// cached results: rows already fetched but not yet returned by Next
cached ResultSet
}
// composeIteratorExpr builds the filter expression for pagination.
//
// Before any row has been returned (lastPK is nil) the user's expression is
// used unchanged. Afterwards a "pk > lastPK" range filter is built for Int64
// and VarChar primary keys and AND-ed with the (parenthesized) user
// expression; when the user expression is blank, the range filter is returned
// on its own. For any other PK type the user's expression is returned as is,
// i.e. no range filter is applied.
//
// VarChar keys are emitted as a quoted, escaped string literal so that keys
// containing `"` or `\` cannot break or alter the expression.
// NOTE(review): strconv.Quote produces Go-style escapes, which are assumed to
// match the string literal syntax of the server side expression parser.
func (it *queryIterator) composeIteratorExpr() string {
	if it.lastPK == nil {
		return it.expr
	}

	pkName := it.pkField.Name
	var pkFilter string
	switch it.pkField.DataType {
	case entity.FieldTypeInt64:
		pkFilter = fmt.Sprintf("%s > %d", pkName, it.lastPK)
	case entity.FieldTypeVarChar:
		pkFilter = fmt.Sprintf("%s > %s", pkName, strconv.Quote(fmt.Sprint(it.lastPK)))
	default:
		return it.expr
	}

	expr := strings.TrimSpace(it.expr)
	if expr == "" {
		return pkFilter
	}
	return fmt.Sprintf("(%s) and %s", expr, pkFilter)
}
// fetchNextBatch runs one Query RPC for the page following lastPK.
//
// A fresh request is built from the option, then its expression, output fields
// and limit are overridden with the iterator's pagination state. The returned
// ResultSet holds the parsed columns; ResultCount is the length of the first
// column, or 0 when no column came back.
func (it *queryIterator) fetchNextBatch(ctx context.Context) (ResultSet, error) {
	req, err := it.option.Request()
	if err != nil {
		return ResultSet{}, err
	}

	// pagination overrides: PK-range expression, forced output fields, page size
	req.Expr = it.composeIteratorExpr()
	req.OutputFields = it.outputFields
	pageLimit := &commonpb.KeyValuePair{Key: spLimit, Value: strconv.Itoa(it.batchSize)}
	req.QueryParams = append(req.QueryParams, pageLimit)

	var batch ResultSet
	runQuery := func(svc milvuspb.MilvusServiceClient) error {
		resp, rpcErr := svc.Query(ctx, req)
		if rpcErr = merr.CheckRPCCall(resp, rpcErr); rpcErr != nil {
			return rpcErr
		}

		cols, parseErr := it.client.parseSearchResult(it.schema, resp.GetOutputFields(), resp.GetFieldsData(), 0, 0, -1)
		if parseErr != nil {
			return parseErr
		}

		batch = ResultSet{sch: it.schema, Fields: cols}
		if len(cols) != 0 {
			batch.ResultCount = cols[0].Len()
		}
		return nil
	}

	// run the call first, then read batch: the closure assigns it
	err = it.client.callService(runQuery)
	return batch, err
}
// cacheNextBatch returns the next batch and updates the cache.
//
// At most batchSize rows of rs are returned; the remainder (if any) becomes
// the new cache. It returns io.EOF when rs holds no rows. On success lastPK is
// advanced to the primary key of the last returned row.
//
// An error is returned when the primary key column cannot be found in a
// non-empty batch: without it the pagination cursor cannot advance and the
// following fetch would return the same rows again.
func (it *queryIterator) cacheNextBatch(rs ResultSet) (ResultSet, error) {
	result, rest := rs, ResultSet{}
	if rs.ResultCount > it.batchSize {
		result = rs.Slice(0, it.batchSize)
		rest = rs.Slice(it.batchSize, rs.ResultCount)
	}
	it.cached = rest

	if result.ResultCount == 0 {
		return ResultSet{}, io.EOF
	}

	// extract and update the last PK for pagination
	pkColumn := result.GetColumn(it.pkField.Name)
	if pkColumn == nil {
		// try to find PK in Fields
		for _, col := range result.Fields {
			if col.Name() == it.pkField.Name {
				pkColumn = col
				break
			}
		}
	}
	if pkColumn == nil || pkColumn.Len() == 0 {
		return ResultSet{}, errors.Newf("primary key column %q not found in query result", it.pkField.Name)
	}

	pk, err := pkColumn.Get(pkColumn.Len() - 1)
	if err != nil {
		return ResultSet{}, errors.Wrapf(err, "failed to get last pk value")
	}
	it.lastPK = pk

	return result, nil
}
// Next returns the next batch of query results.
//
// Rows are served from the cache; when the cache is empty a new page is
// fetched from the server first. io.EOF is returned once the overall limit is
// used up or the server has no more rows. With a finite limit the batch is
// truncated so the total number of returned rows never exceeds it.
func (it *queryIterator) Next(ctx context.Context) (ResultSet, error) {
	// overall budget exhausted
	if it.limit == 0 {
		return ResultSet{}, io.EOF
	}

	// refill the cache when it has been drained
	if it.cached.ResultCount == 0 {
		fetched, err := it.fetchNextBatch(ctx)
		if err != nil {
			return ResultSet{}, err
		}
		it.cached = fetched
		// an empty page means the data set is exhausted
		if fetched.ResultCount == 0 {
			return ResultSet{}, io.EOF
		}
	}

	batch, err := it.cacheNextBatch(it.cached)
	if err != nil {
		return ResultSet{}, err
	}

	// no bookkeeping needed without a cap
	if it.limit == Unlimited {
		return batch, nil
	}

	// trim to the remaining budget, then charge it
	if int64(batch.ResultCount) > it.limit {
		batch = batch.Slice(0, int(it.limit))
	}
	it.limit -= int64(batch.ResultCount)
	return batch, nil
}
// newQueryIterator creates a new query iterator.
//
// It resolves the collection schema and its primary key field, makes sure the
// PK is part of an explicitly given output field list (needed to track the
// pagination cursor), and eagerly fetches the first batch so that invalid
// parameters surface at construction time. The fetched batch is kept in the
// cache and served by the first Next call.
//
// It returns an error when the request cannot be built, the collection cannot
// be resolved, the schema has no primary key field, or the first fetch fails.
func newQueryIterator(ctx context.Context, client *Client, option QueryIteratorOption) (*queryIterator, error) {
	req, err := option.Request()
	if err != nil {
		return nil, err
	}

	collection, err := client.getCollection(ctx, req.GetCollectionName())
	if err != nil {
		return nil, err
	}

	pkField := collection.Schema.PKField()
	if pkField == nil {
		return nil, errors.New("primary key field not found in schema")
	}

	// ensure PK field is included in output fields for pagination
	outputFields := req.GetOutputFields()
	hasPK := false
	for _, f := range outputFields {
		if f == pkField.Name {
			hasPK = true
			break
		}
	}
	// an empty list is left untouched
	if !hasPK && len(outputFields) > 0 {
		// The full slice expression caps capacity at the current length, so
		// append always copies and never writes into a backing array that may
		// be shared with the option or the caller.
		outputFields = append(outputFields[:len(outputFields):len(outputFields)], pkField.Name)
	}

	iter := &queryIterator{
		client:       client,
		option:       option,
		schema:       collection.Schema,
		expr:         req.GetExpr(),
		outputFields: outputFields,
		pkField:      pkField,
		batchSize:    option.BatchSize(),
		limit:        option.Limit(),
	}

	// init: fetch the first batch to validate parameters
	rs, err := iter.fetchNextBatch(ctx)
	if err != nil {
		return nil, err
	}
	iter.cached = rs

	return iter, nil
}
// QueryIterator creates a query iterator from a collection.
//
// The option is validated first; the iterator is then constructed, which
// already fetches the first batch from the server.
//
// On failure the returned QueryIterator is a true nil interface value.
// NOTE(review): callOptions is currently not forwarded to the underlying RPCs.
func (c *Client) QueryIterator(ctx context.Context, option QueryIteratorOption, callOptions ...grpc.CallOption) (QueryIterator, error) {
	if err := option.ValidateParams(); err != nil {
		return nil, err
	}

	// Do not return the (*queryIterator, error) pair directly: a nil pointer
	// wrapped into the interface would compare as non-nil.
	iter, err := newQueryIterator(ctx, c, option)
	if err != nil {
		return nil, err
	}
	return iter, nil
}