mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-28 14:35:27 +08:00
Related to #31293 Implement QueryIterator for the Go SDK to enable efficient iteration over large query result sets using PK-based pagination. Key changes: - Add QueryIterator interface and implementation with PK-based pagination - Support Int64 and VarChar primary key types for pagination filtering - Add QueryIteratorOption with batchSize, limit, filter, outputFields config - Fix ResultSet.Slice to handle Query results without IDs/Scores - Add comprehensive unit tests and integration tests <!-- This is an auto-generated comment: release notes by coderabbit.ai --> - Core invariant: the iterator requires the collection primary key (PK) to be present in outputFields so PK-based pagination and accurate row counting work. The constructor enforces this by appending the PK to outputFields when absent, and all pagination (lastPK tracking, PK-range filters) and ResultCount calculations depend on that guaranteed PK column. - New capability: adds a public QueryIterator API (Client.QueryIterator, QueryIterator interface, QueryIteratorOption) that issues server-side Query RPCs in configurable batches and implements PK-based pagination supporting Int64 and VarChar PKs, with options for batchSize, limit, filter, outputFields and an upfront first-batch validation to fail fast on invalid params. - Removed/simplified logic: ResultSet.Slice no longer assumes IDs and Scores are always present — it branches on presence of IDs (use IDs length when non-nil; otherwise derive row count from Fields[0]) and guards Scores slicing. This eliminates redundant/unsafe assumptions and centralizes correct row-count logic based on actual returned fields. - No data loss or behavior regression: pagination composes the user filter with a PK-range filter and always requests the PK field, so lastPK is extracted from a real column and fetchNextBatch only advances when rows are returned; EOF is returned only when the server returns no rows or iterator limit is reached. 
ResultSet.Slice guards prevent panics for queries that lack IDs/Scores; Query RPC → ResultSet.Fields remains the authoritative path for row data, so rows are not dropped and existing query behavior is preserved. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
445 lines
12 KiB
Go
445 lines
12 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package milvusclient
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"google.golang.org/grpc"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus/client/v2/entity"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
|
)
|
|
|
|
// Search-parameter keys and sentinel values shared by the iterator implementations.
const (
	// IteratorKey is the search param key that indicates the iterator is enabled.
	IteratorKey = "iterator"
	// IteratorSessionTsKey is a search param key; it is not referenced in this file.
	// NOTE(review): presumably carries the iterator session timestamp - confirm.
	IteratorSessionTsKey = "iterator_session_ts"
	// IteratorSearchV2Key is a search param key; it is not referenced in this file.
	// NOTE(review): presumably flags a search iterator V2 request - confirm.
	IteratorSearchV2Key = "search_iter_v2"
	// IteratorSearchBatchSizeKey is the search param key for the iterator batch size.
	IteratorSearchBatchSizeKey = "search_iter_batch_size"
	// IteratorSearchLastBoundKey is the search param key that carries the last
	// bound reported by the previous search iterator response.
	IteratorSearchLastBoundKey = "search_iter_last_bound"
	// IteratorSearchIDKey is the search param key that carries the token
	// reported by the previous search iterator response.
	IteratorSearchIDKey = "search_iter_id"
	// CollectionIDKey is the search param key that carries the collection ID.
	CollectionIDKey = "collection_id"

	// Unlimited is the limit value meaning the iterator has no overall row cap.
	Unlimited int64 = -1
)
|
|
|
|
// ErrServerVersionIncompatible is returned when the server does not support the
// iterator protocol this client needs; match it with errors.Is.
var ErrServerVersionIncompatible = errors.New("server version incompatible")
|
|
|
|
// SearchIterator is the interface for search iterator.
|
|
type SearchIterator interface {
|
|
// Next returns next batch of iterator
|
|
// when iterator reaches the end, return `io.EOF`.
|
|
Next(ctx context.Context) (ResultSet, error)
|
|
}
|
|
|
|
type searchIteratorV2 struct {
|
|
client *Client
|
|
option SearchIteratorOption
|
|
schema *entity.Schema
|
|
limit int64
|
|
}
|
|
|
|
func (it *searchIteratorV2) Next(ctx context.Context) (ResultSet, error) {
|
|
// limit reached, return EOF
|
|
if it.limit == 0 {
|
|
return ResultSet{}, io.EOF
|
|
}
|
|
|
|
rs, err := it.next(ctx)
|
|
if err != nil {
|
|
return rs, err
|
|
}
|
|
|
|
if it.limit == Unlimited {
|
|
return rs, err
|
|
}
|
|
|
|
if int64(rs.Len()) > it.limit {
|
|
rs = rs.Slice(0, int(it.limit))
|
|
}
|
|
it.limit -= int64(rs.Len())
|
|
return rs, nil
|
|
}
|
|
|
|
func (it *searchIteratorV2) next(ctx context.Context) (ResultSet, error) {
|
|
opt := it.option.SearchOption()
|
|
req, err := opt.Request()
|
|
if err != nil {
|
|
return ResultSet{}, err
|
|
}
|
|
|
|
var rs ResultSet
|
|
|
|
err = it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
|
resp, err := milvusService.Search(ctx, req)
|
|
err = merr.CheckRPCCall(resp, err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
iteratorInfo := resp.GetResults().GetSearchIteratorV2Results()
|
|
opt.annRequest.WithSearchParam(IteratorSearchIDKey, iteratorInfo.GetToken())
|
|
opt.annRequest.WithSearchParam(IteratorSearchLastBoundKey, fmt.Sprintf("%v", iteratorInfo.GetLastBound()))
|
|
|
|
resultSets, err := it.client.handleSearchResult(it.schema, req.GetOutputFields(), int(resp.GetResults().GetNumQueries()), resp)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rs = resultSets[0]
|
|
|
|
if rs.IDs.Len() == 0 {
|
|
return io.EOF
|
|
}
|
|
|
|
return nil
|
|
})
|
|
return rs, err
|
|
}
|
|
|
|
func (it *searchIteratorV2) setupCollectionID(ctx context.Context) error {
|
|
opt := it.option.SearchOption()
|
|
|
|
return it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
|
resp, err := milvusService.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{
|
|
CollectionName: opt.collectionName,
|
|
})
|
|
if merr.CheckRPCCall(resp, err) != nil {
|
|
return err
|
|
}
|
|
|
|
opt.WithSearchParam(CollectionIDKey, fmt.Sprintf("%d", resp.GetCollectionID()))
|
|
schema := &entity.Schema{}
|
|
it.schema = schema.ReadProto(resp.GetSchema())
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// probeCompatiblity checks if the server support SearchIteratorV2.
|
|
// It checks if the search result contains search iterator v2 results info and token.
|
|
func (it *searchIteratorV2) probeCompatiblity(ctx context.Context) error {
|
|
opt := it.option.SearchOption()
|
|
opt.annRequest.topK = 1 // ok to leave it here, will be overwritten in next iteration
|
|
opt.annRequest.WithSearchParam(IteratorSearchBatchSizeKey, "1")
|
|
req, err := opt.Request()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
|
resp, err := milvusService.Search(ctx, req)
|
|
err = merr.CheckRPCCall(resp, err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if resp.GetResults().GetSearchIteratorV2Results() == nil || resp.GetResults().GetSearchIteratorV2Results().GetToken() == "" {
|
|
return ErrServerVersionIncompatible
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// newSearchIteratorV2 creates a new search iterator V2.
|
|
//
|
|
// It sets up the collection ID and checks if the server supports search iterator V2.
|
|
// If the server does not support search iterator V2, it returns an error.
|
|
func newSearchIteratorV2(ctx context.Context, client *Client, option SearchIteratorOption) (*searchIteratorV2, error) {
|
|
iter := &searchIteratorV2{
|
|
client: client,
|
|
option: option,
|
|
limit: option.Limit(),
|
|
}
|
|
if err := iter.setupCollectionID(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := iter.probeCompatiblity(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return iter, nil
|
|
}
|
|
|
|
type searchIteratorV1 struct {
|
|
client *Client
|
|
}
|
|
|
|
func (s *searchIteratorV1) Next(_ context.Context) (ResultSet, error) {
|
|
return ResultSet{}, errors.New("not implemented")
|
|
}
|
|
|
|
func newSearchIteratorV1(_ *Client) (*searchIteratorV1, error) {
|
|
// search iterator v1 is not supported
|
|
return nil, ErrServerVersionIncompatible
|
|
}
|
|
|
|
// SearchIterator creates a search iterator from a collection.
|
|
//
|
|
// If the server supports search iterator V2, it creates a search iterator V2.
|
|
func (c *Client) SearchIterator(ctx context.Context, option SearchIteratorOption, callOptions ...grpc.CallOption) (SearchIterator, error) {
|
|
if err := option.ValidateParams(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
iter, err := newSearchIteratorV2(ctx, c, option)
|
|
if err == nil {
|
|
return iter, nil
|
|
}
|
|
|
|
if !errors.Is(err, ErrServerVersionIncompatible) {
|
|
return nil, err
|
|
}
|
|
|
|
return newSearchIteratorV1(c)
|
|
}
|
|
|
|
// QueryIterator is the interface for query iterator.
|
|
type QueryIterator interface {
|
|
// Next returns next batch of iterator
|
|
// when iterator reaches the end, return `io.EOF`.
|
|
Next(ctx context.Context) (ResultSet, error)
|
|
}
|
|
|
|
type queryIterator struct {
|
|
client *Client
|
|
option QueryIteratorOption
|
|
schema *entity.Schema
|
|
|
|
// pagination state
|
|
expr string // base expression from option
|
|
outputFields []string // override output fields(force include pk field)
|
|
pkField *entity.Field
|
|
lastPK any
|
|
batchSize int
|
|
limit int64
|
|
|
|
// cached results
|
|
cached ResultSet
|
|
}
|
|
|
|
// composeIteratorExpr builds the filter expression for pagination.
|
|
// It combines the user's original expression with a PK range filter.
|
|
func (it *queryIterator) composeIteratorExpr() string {
|
|
if it.lastPK == nil {
|
|
return it.expr
|
|
}
|
|
|
|
expr := strings.TrimSpace(it.expr)
|
|
pkName := it.pkField.Name
|
|
|
|
switch it.pkField.DataType {
|
|
case entity.FieldTypeInt64:
|
|
pkFilter := fmt.Sprintf("%s > %d", pkName, it.lastPK)
|
|
if len(expr) == 0 {
|
|
return pkFilter
|
|
}
|
|
return fmt.Sprintf("(%s) and %s", expr, pkFilter)
|
|
case entity.FieldTypeVarChar:
|
|
pkFilter := fmt.Sprintf(`%s > "%s"`, pkName, it.lastPK)
|
|
if len(expr) == 0 {
|
|
return pkFilter
|
|
}
|
|
return fmt.Sprintf(`(%s) and %s`, expr, pkFilter)
|
|
default:
|
|
return it.expr
|
|
}
|
|
}
|
|
|
|
// fetchNextBatch fetches the next batch of data from the server.
|
|
func (it *queryIterator) fetchNextBatch(ctx context.Context) (ResultSet, error) {
|
|
req, err := it.option.Request()
|
|
if err != nil {
|
|
return ResultSet{}, err
|
|
}
|
|
|
|
// override expression and limit for pagination
|
|
req.Expr = it.composeIteratorExpr()
|
|
req.OutputFields = it.outputFields
|
|
req.QueryParams = append(req.QueryParams,
|
|
&commonpb.KeyValuePair{Key: spLimit, Value: strconv.Itoa(it.batchSize)},
|
|
)
|
|
|
|
var resultSet ResultSet
|
|
err = it.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
|
resp, err := milvusService.Query(ctx, req)
|
|
err = merr.CheckRPCCall(resp, err)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
columns, err := it.client.parseSearchResult(it.schema, resp.GetOutputFields(), resp.GetFieldsData(), 0, 0, -1)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
resultSet = ResultSet{
|
|
sch: it.schema,
|
|
Fields: columns,
|
|
}
|
|
if len(columns) > 0 {
|
|
resultSet.ResultCount = columns[0].Len()
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
return resultSet, err
|
|
}
|
|
|
|
// cacheNextBatch returns the next batch and updates the cache.
|
|
func (it *queryIterator) cacheNextBatch(rs ResultSet) (ResultSet, error) {
|
|
var result ResultSet
|
|
if rs.ResultCount > it.batchSize {
|
|
result = rs.Slice(0, it.batchSize)
|
|
it.cached = rs.Slice(it.batchSize, rs.ResultCount)
|
|
} else {
|
|
result = rs
|
|
it.cached = ResultSet{}
|
|
}
|
|
|
|
if result.ResultCount == 0 {
|
|
return ResultSet{}, io.EOF
|
|
}
|
|
|
|
// extract and update the last PK for pagination
|
|
pkColumn := result.GetColumn(it.pkField.Name)
|
|
if pkColumn == nil {
|
|
// try to find PK in Fields
|
|
for _, col := range result.Fields {
|
|
if col.Name() == it.pkField.Name {
|
|
pkColumn = col
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if pkColumn != nil && pkColumn.Len() > 0 {
|
|
pk, err := pkColumn.Get(pkColumn.Len() - 1)
|
|
if err != nil {
|
|
return ResultSet{}, errors.Wrapf(err, "failed to get last pk value")
|
|
}
|
|
it.lastPK = pk
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// Next returns the next batch of results.
|
|
func (it *queryIterator) Next(ctx context.Context) (ResultSet, error) {
|
|
// limit reached, return EOF
|
|
if it.limit == 0 {
|
|
return ResultSet{}, io.EOF
|
|
}
|
|
|
|
// if cache is empty, fetch new data
|
|
if it.cached.ResultCount == 0 {
|
|
rs, err := it.fetchNextBatch(ctx)
|
|
if err != nil {
|
|
return ResultSet{}, err
|
|
}
|
|
it.cached = rs
|
|
}
|
|
|
|
// if still no data, return EOF
|
|
if it.cached.ResultCount == 0 {
|
|
return ResultSet{}, io.EOF
|
|
}
|
|
|
|
result, err := it.cacheNextBatch(it.cached)
|
|
if err != nil {
|
|
return ResultSet{}, err
|
|
}
|
|
|
|
// handle overall limit
|
|
if it.limit != Unlimited {
|
|
if int64(result.ResultCount) > it.limit {
|
|
result = result.Slice(0, int(it.limit))
|
|
}
|
|
it.limit -= int64(result.ResultCount)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// newQueryIterator creates a new query iterator.
|
|
func newQueryIterator(ctx context.Context, client *Client, option QueryIteratorOption) (*queryIterator, error) {
|
|
req, err := option.Request()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
collection, err := client.getCollection(ctx, req.GetCollectionName())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
pkField := collection.Schema.PKField()
|
|
if pkField == nil {
|
|
return nil, errors.New("primary key field not found in schema")
|
|
}
|
|
|
|
// ensure PK field is included in output fields for pagination
|
|
outputFields := req.GetOutputFields()
|
|
hasPK := false
|
|
for _, f := range outputFields {
|
|
if f == pkField.Name {
|
|
hasPK = true
|
|
break
|
|
}
|
|
}
|
|
if !hasPK && len(outputFields) > 0 {
|
|
// modify the underlying option to include PK field
|
|
outputFields = append(outputFields, pkField.Name)
|
|
}
|
|
|
|
iter := &queryIterator{
|
|
client: client,
|
|
option: option,
|
|
schema: collection.Schema,
|
|
expr: req.GetExpr(),
|
|
outputFields: outputFields,
|
|
pkField: pkField,
|
|
batchSize: option.BatchSize(),
|
|
limit: option.Limit(),
|
|
}
|
|
|
|
// init: fetch the first batch to validate parameters
|
|
rs, err := iter.fetchNextBatch(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
iter.cached = rs
|
|
|
|
return iter, nil
|
|
}
|
|
|
|
// QueryIterator creates a query iterator from a collection.
|
|
func (c *Client) QueryIterator(ctx context.Context, option QueryIteratorOption, callOptions ...grpc.CallOption) (QueryIterator, error) {
|
|
if err := option.ValidateParams(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return newQueryIterator(ctx, c, option)
|
|
}
|