enhance: Update latest sdk update to client pkg (#33105)

Related to #31293
See also milvus-io/milvus-sdk-go#704 milvus-io/milvus-sdk-go#711 
milvus-io/milvus-sdk-go#713 milvus-io/milvus-sdk-go#721
milvus-io/milvus-sdk-go#732 milvus-io/milvus-sdk-go#739 
milvus-io/milvus-sdk-go#748

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
congqixia 2024-05-17 10:39:37 +08:00 committed by GitHub
parent f1c9986974
commit 1ef975d327
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 555 additions and 125 deletions

View File

@ -1,5 +1,6 @@
reviewers: reviewers:
- congqixia - congqixia
- ThreadDao
approvers: approvers:
- maintainers - maintainers

View File

@ -18,14 +18,20 @@ package client
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"math"
"os" "os"
"strconv" "strconv"
"sync"
"time" "time"
"github.com/gogo/status" "github.com/gogo/status"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
@ -39,6 +45,11 @@ type Client struct {
service milvuspb.MilvusServiceClient service milvuspb.MilvusServiceClient
config *ClientConfig config *ClientConfig
// mutable status
stateMut sync.RWMutex
currentDB string
identifier string
collCache *CollectionCache collCache *CollectionCache
} }
@ -54,8 +65,10 @@ func New(ctx context.Context, config *ClientConfig) (*Client, error) {
// Parse remote address. // Parse remote address.
addr := c.config.getParsedAddress() addr := c.config.getParsedAddress()
// parse authentication parameters
c.config.parseAuthentication()
// Parse grpc options // Parse grpc options
options := c.config.getDialOption() options := c.dialOptions()
// Connect the grpc server. // Connect the grpc server.
if err := c.connect(ctx, addr, options...); err != nil { if err := c.connect(ctx, addr, options...); err != nil {
@ -69,6 +82,40 @@ func New(ctx context.Context, config *ClientConfig) (*Client, error) {
return c, nil return c, nil
} }
func (c *Client) dialOptions() []grpc.DialOption {
var options []grpc.DialOption
// Construct dial option.
if c.config.EnableTLSAuth {
options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})))
} else {
options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
if c.config.DialOptions == nil {
// Add default connection options.
options = append(options, DefaultGrpcOpts...)
} else {
options = append(options, c.config.DialOptions...)
}
options = append(options,
grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor(
grpc_retry.WithMax(6),
grpc_retry.WithBackoff(func(attempt uint) time.Duration {
return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
}),
grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)),
// c.getRetryOnRateLimitInterceptor(),
))
options = append(options, grpc.WithChainUnaryInterceptor(
c.MetadataUnaryInterceptor(),
))
return options
}
func (c *Client) Close(ctx context.Context) error { func (c *Client) Close(ctx context.Context) error {
if c.conn == nil { if c.conn == nil {
return nil return nil
@ -82,6 +129,18 @@ func (c *Client) Close(ctx context.Context) error {
return nil return nil
} }
func (c *Client) usingDatabase(dbName string) {
c.stateMut.Lock()
defer c.stateMut.Unlock()
c.currentDB = dbName
}
func (c *Client) setIdentifier(identifier string) {
c.stateMut.Lock()
defer c.stateMut.Unlock()
c.identifier = identifier
}
func (c *Client) connect(ctx context.Context, addr string, options ...grpc.DialOption) error { func (c *Client) connect(ctx context.Context, addr string, options ...grpc.DialOption) error {
if addr == "" { if addr == "" {
return fmt.Errorf("address is empty") return fmt.Errorf("address is empty")
@ -112,7 +171,7 @@ func (c *Client) connectInternal(ctx context.Context) error {
req := &milvuspb.ConnectRequest{ req := &milvuspb.ConnectRequest{
ClientInfo: &commonpb.ClientInfo{ ClientInfo: &commonpb.ClientInfo{
SdkType: "Golang", SdkType: "GoMilvusClient",
SdkVersion: common.SDKVersion, SdkVersion: common.SDKVersion,
LocalTime: time.Now().String(), LocalTime: time.Now().String(),
User: c.config.Username, User: c.config.Username,
@ -131,8 +190,8 @@ func (c *Client) connectInternal(ctx context.Context) error {
disableJSON | disableJSON |
disableParitionKey | disableParitionKey |
disableDynamicSchema) disableDynamicSchema)
return nil
} }
return nil
} }
return err return err
} }
@ -142,7 +201,7 @@ func (c *Client) connectInternal(ctx context.Context) error {
} }
c.config.setServerInfo(resp.GetServerInfo().GetBuildTags()) c.config.setServerInfo(resp.GetServerInfo().GetBuildTags())
c.config.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10)) c.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10))
return nil return nil
} }

View File

@ -1,7 +1,7 @@
package client package client
import ( import (
"crypto/tls" "context"
"fmt" "fmt"
"math" "math"
"net/url" "net/url"
@ -10,12 +10,9 @@ import (
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" "github.com/milvus-io/milvus/pkg/util/crypto"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/backoff" "google.golang.org/grpc/backoff"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
) )
@ -59,16 +56,23 @@ type ClientConfig struct {
DialOptions []grpc.DialOption // Dial options for GRPC. DialOptions []grpc.DialOption // Dial options for GRPC.
// RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor
DisableConn bool DisableConn bool
metadataHeaders map[string]string
identifier string // Identifier for this connection identifier string // Identifier for this connection
ServerVersion string // ServerVersion ServerVersion string // ServerVersion
parsedAddress *url.URL parsedAddress *url.URL
flags uint64 // internal flags flags uint64 // internal flags
} }
type RetryRateLimitOption struct {
MaxRetry uint
MaxBackoff time.Duration
}
func (cfg *ClientConfig) parse() error { func (cfg *ClientConfig) parse() error {
// Prepend default fake tcp:// scheme for remote address. // Prepend default fake tcp:// scheme for remote address.
address := cfg.Address address := cfg.Address
@ -118,54 +122,36 @@ func (c *ClientConfig) setServerInfo(serverInfo string) {
c.ServerVersion = serverInfo c.ServerVersion = serverInfo
} }
// Get parsed grpc dial options, should be called after parse was called. // parseAuthentication prepares authentication headers for grpc inteceptors based on the provided username, password or API key.
func (c *ClientConfig) getDialOption() []grpc.DialOption { func (c *ClientConfig) parseAuthentication() {
options := c.DialOptions c.metadataHeaders = make(map[string]string)
if c.DialOptions == nil { if c.Username != "" || c.Password != "" {
// Add default connection options. value := crypto.Base64Encode(fmt.Sprintf("%s:%s", c.Username, c.Password))
options = make([]grpc.DialOption, len(DefaultGrpcOpts)) c.metadataHeaders[authorizationHeader] = value
copy(options, DefaultGrpcOpts)
} }
// API overwrites username & passwd
// Construct dial option. if c.APIKey != "" {
if c.EnableTLSAuth { value := crypto.Base64Encode(c.APIKey)
options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))) c.metadataHeaders[authorizationHeader] = value
} else {
options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials()))
} }
options = append(options,
grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor(
grpc_retry.WithMax(6),
grpc_retry.WithBackoff(func(attempt uint) time.Duration {
return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
}),
grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)),
// c.getRetryOnRateLimitInterceptor(),
))
// options = append(options, grpc.WithChainUnaryInterceptor(
// createMetaDataUnaryInterceptor(c),
// ))
return options
} }
// func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor { func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor {
// if c.RetryRateLimit == nil { if c.RetryRateLimit == nil {
// c.RetryRateLimit = c.defaultRetryRateLimitOption() c.RetryRateLimit = c.defaultRetryRateLimitOption()
// } }
// return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration { return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration {
// return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt))) return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
// }) })
// } }
// func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption { func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption {
// return &RetryRateLimitOption{ return &RetryRateLimitOption{
// MaxRetry: 75, MaxRetry: 75,
// MaxBackoff: 3 * time.Second, MaxBackoff: 3 * time.Second,
// } }
// } }
// addFlags set internal flags // addFlags set internal flags
func (c *ClientConfig) addFlags(flags uint64) { func (c *ClientConfig) addFlags(flags uint64) {

View File

@ -98,6 +98,7 @@ func (c *Client) DescribeCollection(ctx context.Context, option *describeCollect
VirtualChannels: resp.GetVirtualChannelNames(), VirtualChannels: resp.GetVirtualChannelNames(),
ConsistencyLevel: entity.ConsistencyLevel(resp.ConsistencyLevel), ConsistencyLevel: entity.ConsistencyLevel(resp.ConsistencyLevel),
ShardNum: resp.GetShardsNum(), ShardNum: resp.GetShardsNum(),
Properties: entity.KvPairsMap(resp.GetProperties()),
} }
collection.Name = collection.Schema.CollectionName collection.Name = collection.Schema.CollectionName
return nil return nil

View File

@ -140,6 +140,7 @@ func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOpti
autoID: true, autoID: true,
dim: dim, dim: dim,
enabledDynamicSchema: true, enabledDynamicSchema: true,
consistencyLevel: entity.DefaultConsistencyLevel,
isFast: true, isFast: true,
metricType: entity.COSINE, metricType: entity.COSINE,
@ -149,9 +150,10 @@ func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOpti
// NewCreateCollectionOption returns a CreateCollectionOption with customized collection schema // NewCreateCollectionOption returns a CreateCollectionOption with customized collection schema
func NewCreateCollectionOption(name string, collectionSchema *entity.Schema) *createCollectionOption { func NewCreateCollectionOption(name string, collectionSchema *entity.Schema) *createCollectionOption {
return &createCollectionOption{ return &createCollectionOption{
name: name, name: name,
shardNum: 1, shardNum: 1,
schema: collectionSchema, schema: collectionSchema,
consistencyLevel: entity.DefaultConsistencyLevel,
metricType: entity.COSINE, metricType: entity.COSINE,
} }

View File

@ -64,26 +64,38 @@ var errFieldDataTypeNotMatch = errors.New("FieldData type not matched")
// IDColumns converts schemapb.IDs to corresponding column // IDColumns converts schemapb.IDs to corresponding column
// currently Int64 / string may be in IDs // currently Int64 / string may be in IDs
func IDColumns(idField *schemapb.IDs, begin, end int) (Column, error) { func IDColumns(schema *entity.Schema, ids *schemapb.IDs, begin, end int) (Column, error) {
var idColumn Column var idColumn Column
if idField == nil { pkField := schema.PKField()
if pkField == nil {
return nil, errors.New("PK Field not found")
}
if ids == nil {
return nil, errors.New("nil Ids from response") return nil, errors.New("nil Ids from response")
} }
switch field := idField.GetIdField().(type) { switch pkField.DataType {
case *schemapb.IDs_IntId: case entity.FieldTypeInt64:
if end >= 0 { data := ids.GetIntId().GetData()
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:end]) if data == nil {
} else { return NewColumnInt64(pkField.Name, nil), nil
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:])
} }
case *schemapb.IDs_StrId:
if end >= 0 { if end >= 0 {
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:end]) idColumn = NewColumnInt64(pkField.Name, data[begin:end])
} else { } else {
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:]) idColumn = NewColumnInt64(pkField.Name, data[begin:])
}
case entity.FieldTypeVarChar, entity.FieldTypeString:
data := ids.GetStrId().GetData()
if data == nil {
return NewColumnVarChar(pkField.Name, nil), nil
}
if end >= 0 {
idColumn = NewColumnVarChar(pkField.Name, data[begin:end])
} else {
idColumn = NewColumnVarChar(pkField.Name, data[begin:])
} }
default: default:
return nil, fmt.Errorf("unsupported id type %v", field) return nil, fmt.Errorf("unsupported id type %v", pkField.DataType)
} }
return idColumn, nil return idColumn, nil
} }

View File

@ -24,18 +24,34 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
) )
func TestIDColumns(t *testing.T) { func TestIDColumns(t *testing.T) {
dataLen := rand.Intn(100) + 1 dataLen := rand.Intn(100) + 1
base := rand.Intn(5000) // id start point base := rand.Intn(5000) // id start point
intPKCol := entity.NewSchema().WithField(
entity.NewField().WithName("pk").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64),
)
strPKCol := entity.NewSchema().WithField(
entity.NewField().WithName("pk").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeVarChar),
)
t.Run("nil id", func(t *testing.T) { t.Run("nil id", func(t *testing.T) {
_, err := IDColumns(nil, 0, -1) col, err := IDColumns(intPKCol, nil, 0, -1)
assert.NotNil(t, err) assert.NoError(t, err)
assert.EqualValues(t, 0, col.Len())
col, err = IDColumns(strPKCol, nil, 0, -1)
assert.NoError(t, err)
assert.EqualValues(t, 0, col.Len())
idField := &schemapb.IDs{} idField := &schemapb.IDs{}
_, err = IDColumns(idField, 0, -1) col, err = IDColumns(intPKCol, idField, 0, -1)
assert.NotNil(t, err) assert.NoError(t, err)
assert.EqualValues(t, 0, col.Len())
col, err = IDColumns(strPKCol, idField, 0, -1)
assert.NoError(t, err)
assert.EqualValues(t, 0, col.Len())
}) })
t.Run("int ids", func(t *testing.T) { t.Run("int ids", func(t *testing.T) {
@ -50,12 +66,12 @@ func TestIDColumns(t *testing.T) {
}, },
}, },
} }
column, err := IDColumns(idField, 0, dataLen) column, err := IDColumns(intPKCol, idField, 0, dataLen)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, column) assert.NotNil(t, column)
assert.Equal(t, dataLen, column.Len()) assert.Equal(t, dataLen, column.Len())
column, err = IDColumns(idField, 0, -1) // test -1 method column, err = IDColumns(intPKCol, idField, 0, -1) // test -1 method
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, column) assert.NotNil(t, column)
assert.Equal(t, dataLen, column.Len()) assert.Equal(t, dataLen, column.Len())
@ -72,12 +88,12 @@ func TestIDColumns(t *testing.T) {
}, },
}, },
} }
column, err := IDColumns(idField, 0, dataLen) column, err := IDColumns(strPKCol, idField, 0, dataLen)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, column) assert.NotNil(t, column)
assert.Equal(t, dataLen, column.Len()) assert.Equal(t, dataLen, column.Len())
column, err = IDColumns(idField, 0, -1) // test -1 method column, err = IDColumns(strPKCol, idField, 0, -1) // test -1 method
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, column) assert.NotNil(t, column)
assert.Equal(t, dataLen, column.Len()) assert.Equal(t, dataLen, column.Len())

View File

@ -25,6 +25,12 @@ import (
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
) )
func (c *Client) UsingDatabase(ctx context.Context, option UsingDatabaseOption) error {
dbName := option.DbName()
c.usingDatabase(dbName)
return c.connectInternal(ctx)
}
func (c *Client) ListDatabase(ctx context.Context, option ListDatabaseOption, callOptions ...grpc.CallOption) (databaseNames []string, err error) { func (c *Client) ListDatabase(ctx context.Context, option ListDatabaseOption, callOptions ...grpc.CallOption) (databaseNames []string, err error) {
req := option.Request() req := option.Request()

View File

@ -18,6 +18,24 @@ package client
import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
type UsingDatabaseOption interface {
DbName() string
}
type usingDatabaseNameOpt struct {
dbName string
}
func (opt *usingDatabaseNameOpt) DbName() string {
return opt.dbName
}
func NewUsingDatabaseOption(dbName string) *usingDatabaseNameOpt {
return &usingDatabaseNameOpt{
dbName: dbName,
}
}
// ListDatabaseOption is a builder interface for ListDatabase request. // ListDatabaseOption is a builder interface for ListDatabase request.
type ListDatabaseOption interface { type ListDatabaseOption interface {
Request() *milvuspb.ListDatabasesRequest Request() *milvuspb.ListDatabasesRequest

View File

@ -32,6 +32,7 @@ type Collection struct {
Loaded bool Loaded bool
ConsistencyLevel ConsistencyLevel ConsistencyLevel ConsistencyLevel
ShardNum int32 ShardNum int32
Properties map[string]string
} }
// Partition represent partition meta in Milvus // Partition represent partition meta in Milvus

View File

@ -60,6 +60,8 @@ type Schema struct {
AutoID bool AutoID bool
Fields []*Field Fields []*Field
EnableDynamicField bool EnableDynamicField bool
pkField *Field
} }
// NewSchema creates an empty schema object. // NewSchema creates an empty schema object.
@ -91,6 +93,9 @@ func (s *Schema) WithDynamicFieldEnabled(dynamicEnabled bool) *Schema {
// WithField adds a field into schema and returns schema itself. // WithField adds a field into schema and returns schema itself.
func (s *Schema) WithField(f *Field) *Schema { func (s *Schema) WithField(f *Field) *Schema {
if f.PrimaryKey {
s.pkField = f
}
s.Fields = append(s.Fields, f) s.Fields = append(s.Fields, f)
return s return s
} }
@ -116,10 +121,14 @@ func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema {
s.CollectionName = p.GetName() s.CollectionName = p.GetName()
s.Fields = make([]*Field, 0, len(p.GetFields())) s.Fields = make([]*Field, 0, len(p.GetFields()))
for _, fp := range p.GetFields() { for _, fp := range p.GetFields() {
field := NewField().ReadProto(fp)
if fp.GetAutoID() { if fp.GetAutoID() {
s.AutoID = true s.AutoID = true
} }
s.Fields = append(s.Fields, NewField().ReadProto(fp)) if field.PrimaryKey {
s.pkField = field
}
s.Fields = append(s.Fields, field)
} }
s.EnableDynamicField = p.GetEnableDynamicField() s.EnableDynamicField = p.GetEnableDynamicField()
return s return s
@ -127,12 +136,15 @@ func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema {
// PKFieldName returns pk field name for this schemapb. // PKFieldName returns pk field name for this schemapb.
func (s *Schema) PKFieldName() string { func (s *Schema) PKFieldName() string {
for _, field := range s.Fields { if s.pkField == nil {
if field.PrimaryKey { return ""
return field.Name
}
} }
return "" return s.pkField.Name
}
// PKField returns PK Field schema for this schema.
func (s *Schema) PKField() *Field {
return s.pkField
} }
// Field represent field schema in milvus // Field represent field schema in milvus

View File

@ -17,6 +17,8 @@
package client package client
import ( import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity" "github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/index" "github.com/milvus-io/milvus/client/v2/index"
@ -31,15 +33,27 @@ type createIndexOption struct {
fieldName string fieldName string
indexName string indexName string
indexDef index.Index indexDef index.Index
extraParams map[string]any
}
func (opt *createIndexOption) WithExtraParam(key string, value any) {
opt.extraParams[key] = value
} }
func (opt *createIndexOption) Request() *milvuspb.CreateIndexRequest { func (opt *createIndexOption) Request() *milvuspb.CreateIndexRequest {
return &milvuspb.CreateIndexRequest{ params := opt.indexDef.Params()
for key, value := range opt.extraParams {
params[key] = fmt.Sprintf("%v", value)
}
req := &milvuspb.CreateIndexRequest{
CollectionName: opt.collectionName, CollectionName: opt.collectionName,
FieldName: opt.fieldName, FieldName: opt.fieldName,
IndexName: opt.indexName, IndexName: opt.indexName,
ExtraParams: entity.MapKvPairs(opt.indexDef.Params()), ExtraParams: entity.MapKvPairs(params),
} }
return req
} }
func (opt *createIndexOption) WithIndexName(indexName string) *createIndexOption { func (opt *createIndexOption) WithIndexName(indexName string) *createIndexOption {
@ -52,6 +66,7 @@ func NewCreateIndexOption(collectionName string, fieldName string, index index.I
collectionName: collectionName, collectionName: collectionName,
fieldName: fieldName, fieldName: fieldName,
indexDef: index, indexDef: index,
extraParams: make(map[string]any),
} }
} }

159
client/interceptors.go Normal file
View File

@ -0,0 +1,159 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)
const (
authorizationHeader = `authorization`
identifierHeader = `identifier`
databaseHeader = `dbname`
)
func (c *Client) MetadataUnaryInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx = c.metadata(ctx)
ctx = c.state(ctx)
return invoker(ctx, method, req, reply, cc, opts...)
}
}
func (c *Client) metadata(ctx context.Context) context.Context {
for k, v := range c.config.metadataHeaders {
ctx = metadata.AppendToOutgoingContext(ctx, k, v)
}
return ctx
}
func (c *Client) state(ctx context.Context) context.Context {
c.stateMut.RLock()
defer c.stateMut.RUnlock()
if c.currentDB != "" {
ctx = metadata.AppendToOutgoingContext(ctx, databaseHeader, c.currentDB)
}
if c.identifier != "" {
ctx = metadata.AppendToOutgoingContext(ctx, identifierHeader, c.identifier)
}
return ctx
}
// ref: https://github.com/grpc-ecosystem/go-grpc-middleware
type ctxKey int
const (
RetryOnRateLimit ctxKey = iota
)
// RetryOnRateLimitInterceptor returns a new retrying unary client interceptor.
func RetryOnRateLimitInterceptor(maxRetry uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) grpc.UnaryClientInterceptor {
return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if maxRetry == 0 {
return invoker(parentCtx, method, req, reply, cc, opts...)
}
var lastErr error
for attempt := uint(0); attempt < maxRetry; attempt++ {
_, err := waitRetryBackoff(parentCtx, attempt, maxBackoff, backoffFunc)
if err != nil {
return err
}
lastErr = invoker(parentCtx, method, req, reply, cc, opts...)
rspStatus := getResultStatus(reply)
if retryOnRateLimit(parentCtx) && rspStatus.GetErrorCode() == commonpb.ErrorCode_RateLimit {
continue
}
return lastErr
}
return lastErr
}
}
func retryOnRateLimit(ctx context.Context) bool {
retry, ok := ctx.Value(RetryOnRateLimit).(bool)
if !ok {
return true // default true
}
return retry
}
// getResultStatus returns status of response.
func getResultStatus(reply interface{}) *commonpb.Status {
switch r := reply.(type) {
case *commonpb.Status:
return r
case *milvuspb.MutationResult:
return r.GetStatus()
case *milvuspb.BoolResponse:
return r.GetStatus()
case *milvuspb.SearchResults:
return r.GetStatus()
case *milvuspb.QueryResults:
return r.GetStatus()
case *milvuspb.FlushResponse:
return r.GetStatus()
default:
return nil
}
}
func contextErrToGrpcErr(err error) error {
switch err {
case context.DeadlineExceeded:
return status.Error(codes.DeadlineExceeded, err.Error())
case context.Canceled:
return status.Error(codes.Canceled, err.Error())
default:
return status.Error(codes.Unknown, err.Error())
}
}
func waitRetryBackoff(parentCtx context.Context, attempt uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) (time.Duration, error) {
var waitTime time.Duration
if attempt > 0 {
waitTime = backoffFunc(parentCtx, attempt)
}
if waitTime > 0 {
if waitTime > maxBackoff {
waitTime = maxBackoff
}
timer := time.NewTimer(waitTime)
select {
case <-parentCtx.Done():
timer.Stop()
return waitTime, contextErrToGrpcErr(parentCtx.Err())
case <-timer.C:
}
}
return waitTime, nil
}

View File

@ -0,0 +1,66 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"math"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)
var mockInvokerError error
var mockInvokerReply interface{}
var mockInvokeTimes = 0
var mockInvoker grpc.UnaryInvoker = func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
mockInvokeTimes++
return mockInvokerError
}
func resetMockInvokeTimes() {
mockInvokeTimes = 0
}
func TestRateLimitInterceptor(t *testing.T) {
maxRetry := uint(3)
maxBackoff := 3 * time.Second
inter := RetryOnRateLimitInterceptor(maxRetry, maxBackoff, func(ctx context.Context, attempt uint) time.Duration {
return 60 * time.Millisecond * time.Duration(math.Pow(2, float64(attempt)))
})
ctx := context.Background()
// with retry
mockInvokerReply = &commonpb.Status{ErrorCode: commonpb.ErrorCode_RateLimit}
resetMockInvokeTimes()
err := inter(ctx, "", nil, mockInvokerReply, nil, mockInvoker)
assert.NoError(t, err)
assert.Equal(t, maxRetry, uint(mockInvokeTimes))
// without retry
ctx1 := context.WithValue(ctx, RetryOnRateLimit, false)
resetMockInvokeTimes()
err = inter(ctx1, "", nil, mockInvokerReply, nil, mockInvoker)
assert.NoError(t, err)
assert.Equal(t, uint(1), uint(mockInvokeTimes))
}

View File

@ -33,7 +33,7 @@ type ResultSets struct{}
type ResultSet struct { type ResultSet struct {
ResultCount int // the returning entry count ResultCount int // the returning entry count
GroupByValue any GroupByValue column.Column
IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API
Fields DataSet // output field data Fields DataSet // output field data
Scores []float32 // distance to the target vector Scores []float32 // distance to the target vector
@ -67,35 +67,32 @@ func (c *Client) Search(ctx context.Context, option SearchOption, callOptions ..
} }
func (c *Client) handleSearchResult(schema *entity.Schema, outputFields []string, nq int, resp *milvuspb.SearchResults) ([]ResultSet, error) { func (c *Client) handleSearchResult(schema *entity.Schema, outputFields []string, nq int, resp *milvuspb.SearchResults) ([]ResultSet, error) {
var err error
sr := make([]ResultSet, 0, nq) sr := make([]ResultSet, 0, nq)
results := resp.GetResults() results := resp.GetResults()
offset := 0 offset := 0
fieldDataList := results.GetFieldsData() fieldDataList := results.GetFieldsData()
gb := results.GetGroupByFieldValue() gb := results.GetGroupByFieldValue()
var gbc column.Column
if gb != nil {
gbc, err = column.FieldDataColumn(gb, 0, -1)
if err != nil {
return nil, err
}
}
for i := 0; i < int(results.GetNumQueries()); i++ { for i := 0; i < int(results.GetNumQueries()); i++ {
rc := int(results.GetTopks()[i]) // result entry count for current query rc := int(results.GetTopks()[i]) // result entry count for current query
entry := ResultSet{ entry := ResultSet{
ResultCount: rc, ResultCount: rc,
Scores: results.GetScores()[offset : offset+rc], Scores: results.GetScores()[offset : offset+rc],
} }
if gbc != nil {
entry.GroupByValue, _ = gbc.Get(i)
}
// parse result set if current nq is not empty // parse result set if current nq is not empty
if rc > 0 { if rc > 0 {
entry.IDs, entry.Err = column.IDColumns(results.GetIds(), offset, offset+rc) entry.IDs, entry.Err = column.IDColumns(schema, results.GetIds(), offset, offset+rc)
if entry.Err != nil { if entry.Err != nil {
offset += rc offset += rc
continue continue
} }
// parse group-by values
if gb != nil {
entry.GroupByValue, entry.Err = column.FieldDataColumn(gb, offset, offset+rc)
if entry.Err != nil {
offset += rc
continue
}
}
entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc) entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc)
sr = append(sr, entry) sr = append(sr, entry)
} }

View File

@ -87,7 +87,7 @@ func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.
// search param // search param
bs, _ := json.Marshal(annRequest.searchParam) bs, _ := json.Marshal(annRequest.searchParam)
request.SearchParams = entity.MapKvPairs(map[string]string{ params := map[string]string{
spAnnsField: annRequest.annField, spAnnsField: annRequest.annField,
spTopK: strconv.Itoa(opt.topK), spTopK: strconv.Itoa(opt.topK),
spOffset: strconv.Itoa(opt.offset), spOffset: strconv.Itoa(opt.offset),
@ -95,8 +95,11 @@ func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.
spMetricsType: string(annRequest.metricsType), spMetricsType: string(annRequest.metricsType),
spRoundDecimal: "-1", spRoundDecimal: "-1",
spIgnoreGrowing: strconv.FormatBool(opt.ignoreGrowing), spIgnoreGrowing: strconv.FormatBool(opt.ignoreGrowing),
spGroupBy: annRequest.groupByField, }
}) if annRequest.groupByField != "" {
params[spGroupBy] = annRequest.groupByField
}
request.SearchParams = entity.MapKvPairs(params)
// placeholder group // placeholder group
request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors) request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors)

View File

@ -22,53 +22,90 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/column"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
) )
func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ...grpc.CallOption) error { type InsertResult struct {
InsertCount int64
IDs column.Column
}
func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ...grpc.CallOption) (InsertResult, error) {
result := InsertResult{}
collection, err := c.getCollection(ctx, option.CollectionName()) collection, err := c.getCollection(ctx, option.CollectionName())
if err != nil { if err != nil {
return err return result, err
} }
req, err := option.InsertRequest(collection) req, err := option.InsertRequest(collection)
if err != nil { if err != nil {
return err return result, err
} }
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error { err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Insert(ctx, req, callOptions...) resp, err := milvusService.Insert(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err) err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
result.InsertCount = resp.GetInsertCnt()
result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1)
if err != nil {
return err
}
return nil
}) })
return err return result, err
} }
func (c *Client) Delete(ctx context.Context, option DeleteOption, callOptions ...grpc.CallOption) error { type DeleteResult struct {
DeleteCount int64
}
func (c *Client) Delete(ctx context.Context, option DeleteOption, callOptions ...grpc.CallOption) (DeleteResult, error) {
req := option.Request() req := option.Request()
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error { result := DeleteResult{}
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Delete(ctx, req, callOptions...) resp, err := milvusService.Delete(ctx, req, callOptions...)
if err = merr.CheckRPCCall(resp, err); err != nil { if err = merr.CheckRPCCall(resp, err); err != nil {
return err return err
} }
result.DeleteCount = resp.GetDeleteCnt()
return nil return nil
}) })
return result, err
} }
func (c *Client) Upsert(ctx context.Context, option UpsertOption, callOptions ...grpc.CallOption) error { type UpsertResult struct {
UpsertCount int64
IDs column.Column
}
func (c *Client) Upsert(ctx context.Context, option UpsertOption, callOptions ...grpc.CallOption) (UpsertResult, error) {
result := UpsertResult{}
collection, err := c.getCollection(ctx, option.CollectionName()) collection, err := c.getCollection(ctx, option.CollectionName())
if err != nil { if err != nil {
return err return result, err
} }
req, err := option.UpsertRequest(collection) req, err := option.UpsertRequest(collection)
if err != nil { if err != nil {
return err return result, err
} }
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Upsert(ctx, req, callOptions...) resp, err := milvusService.Upsert(ctx, req, callOptions...)
if err = merr.CheckRPCCall(resp, err); err != nil { if err = merr.CheckRPCCall(resp, err); err != nil {
return err return err
} }
result.UpsertCount = resp.GetUpsertCnt()
result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1)
if err != nil {
return err
}
return nil return nil
}) })
return result, err
} }

View File

@ -23,6 +23,7 @@ import (
"testing" "testing"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity" "github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/samber/lo" "github.com/samber/lo"
@ -63,16 +64,25 @@ func (s *WriteSuite) TestInsert() {
s.Require().Len(ir.GetFieldsData(), 2) s.Require().Len(ir.GetFieldsData(), 2)
s.EqualValues(3, ir.GetNumRows()) s.EqualValues(3, ir.GetNumRows())
return &milvuspb.MutationResult{ return &milvuspb.MutationResult{
Status: merr.Success(), Status: merr.Success(),
InsertCnt: 3,
IDs: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3},
},
},
},
}, nil }, nil
}).Once() }).Once()
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName). result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 { WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() }) return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})). })).
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName)) WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
s.NoError(err) s.NoError(err)
s.EqualValues(3, result.InsertCount)
}) })
s.Run("dynamic_schema", func() { s.Run("dynamic_schema", func() {
@ -86,17 +96,26 @@ func (s *WriteSuite) TestInsert() {
s.Require().Len(ir.GetFieldsData(), 3) s.Require().Len(ir.GetFieldsData(), 3)
s.EqualValues(3, ir.GetNumRows()) s.EqualValues(3, ir.GetNumRows())
return &milvuspb.MutationResult{ return &milvuspb.MutationResult{
Status: merr.Success(), Status: merr.Success(),
InsertCnt: 3,
IDs: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3},
},
},
},
}, nil }, nil
}).Once() }).Once()
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName). result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 { WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() }) return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})). })).
WithVarcharColumn("extra", []string{"a", "b", "c"}). WithVarcharColumn("extra", []string{"a", "b", "c"}).
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName)) WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
s.NoError(err) s.NoError(err)
s.EqualValues(3, result.InsertCount)
}) })
s.Run("bad_input", func() { s.Run("bad_input", func() {
@ -141,7 +160,7 @@ func (s *WriteSuite) TestInsert() {
for _, tc := range cases { for _, tc := range cases {
s.Run(tc.tag, func() { s.Run(tc.tag, func() {
err := s.client.Insert(ctx, tc.input) _, err := s.client.Insert(ctx, tc.input)
s.Error(err) s.Error(err)
}) })
} }
@ -153,7 +172,7 @@ func (s *WriteSuite) TestInsert() {
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() s.mock.EXPECT().Insert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName). _, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 { WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() }) return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})). })).
@ -177,16 +196,25 @@ func (s *WriteSuite) TestUpsert() {
s.Require().Len(ur.GetFieldsData(), 2) s.Require().Len(ur.GetFieldsData(), 2)
s.EqualValues(3, ur.GetNumRows()) s.EqualValues(3, ur.GetNumRows())
return &milvuspb.MutationResult{ return &milvuspb.MutationResult{
Status: merr.Success(), Status: merr.Success(),
UpsertCnt: 3,
IDs: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3},
},
},
},
}, nil }, nil
}).Once() }).Once()
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName). result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 { WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() }) return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})). })).
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName)) WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
s.NoError(err) s.NoError(err)
s.EqualValues(3, result.UpsertCount)
}) })
s.Run("dynamic_schema", func() { s.Run("dynamic_schema", func() {
@ -200,17 +228,26 @@ func (s *WriteSuite) TestUpsert() {
s.Require().Len(ur.GetFieldsData(), 3) s.Require().Len(ur.GetFieldsData(), 3)
s.EqualValues(3, ur.GetNumRows()) s.EqualValues(3, ur.GetNumRows())
return &milvuspb.MutationResult{ return &milvuspb.MutationResult{
Status: merr.Success(), Status: merr.Success(),
UpsertCnt: 3,
IDs: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3},
},
},
},
}, nil }, nil
}).Once() }).Once()
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName). result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 { WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() }) return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})). })).
WithVarcharColumn("extra", []string{"a", "b", "c"}). WithVarcharColumn("extra", []string{"a", "b", "c"}).
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName)) WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
s.NoError(err) s.NoError(err)
s.EqualValues(3, result.UpsertCount)
}) })
s.Run("bad_input", func() { s.Run("bad_input", func() {
@ -255,7 +292,7 @@ func (s *WriteSuite) TestUpsert() {
for _, tc := range cases { for _, tc := range cases {
s.Run(tc.tag, func() { s.Run(tc.tag, func() {
err := s.client.Upsert(ctx, tc.input) _, err := s.client.Upsert(ctx, tc.input)
s.Error(err) s.Error(err)
}) })
} }
@ -267,7 +304,7 @@ func (s *WriteSuite) TestUpsert() {
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once() s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName). _, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 { WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() }) return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})). })).
@ -315,11 +352,13 @@ func (s *WriteSuite) TestDelete() {
s.Equal(partName, dr.GetPartitionName()) s.Equal(partName, dr.GetPartitionName())
s.Equal(tc.expectExpr, dr.GetExpr()) s.Equal(tc.expectExpr, dr.GetExpr())
return &milvuspb.MutationResult{ return &milvuspb.MutationResult{
Status: merr.Success(), Status: merr.Success(),
DeleteCnt: 100,
}, nil }, nil
}).Once() }).Once()
err := s.client.Delete(ctx, tc.input) result, err := s.client.Delete(ctx, tc.input)
s.NoError(err) s.NoError(err)
s.EqualValues(100, result.DeleteCount)
}) })
} }
}) })