mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 17:48:29 +08:00
enhance: Update latest sdk update to client pkg (#33105)
Related to #31293 See also milvus-io/milvus-sdk-go#704 milvus-io/milvus-sdk-go#711 milvus-io/milvus-sdk-go#713 milvus-io/milvus-sdk-go#721 milvus-io/milvus-sdk-go#732 milvus-io/milvus-sdk-go#739 milvus-io/milvus-sdk-go#748 --------- Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
f1c9986974
commit
1ef975d327
@ -1,5 +1,6 @@
|
|||||||
reviewers:
|
reviewers:
|
||||||
- congqixia
|
- congqixia
|
||||||
|
- ThreadDao
|
||||||
|
|
||||||
approvers:
|
approvers:
|
||||||
- maintainers
|
- maintainers
|
||||||
|
|||||||
@ -18,14 +18,20 @@ package client
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gogo/status"
|
"github.com/gogo/status"
|
||||||
|
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
@ -39,6 +45,11 @@ type Client struct {
|
|||||||
service milvuspb.MilvusServiceClient
|
service milvuspb.MilvusServiceClient
|
||||||
config *ClientConfig
|
config *ClientConfig
|
||||||
|
|
||||||
|
// mutable status
|
||||||
|
stateMut sync.RWMutex
|
||||||
|
currentDB string
|
||||||
|
identifier string
|
||||||
|
|
||||||
collCache *CollectionCache
|
collCache *CollectionCache
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -54,8 +65,10 @@ func New(ctx context.Context, config *ClientConfig) (*Client, error) {
|
|||||||
// Parse remote address.
|
// Parse remote address.
|
||||||
addr := c.config.getParsedAddress()
|
addr := c.config.getParsedAddress()
|
||||||
|
|
||||||
|
// parse authentication parameters
|
||||||
|
c.config.parseAuthentication()
|
||||||
// Parse grpc options
|
// Parse grpc options
|
||||||
options := c.config.getDialOption()
|
options := c.dialOptions()
|
||||||
|
|
||||||
// Connect the grpc server.
|
// Connect the grpc server.
|
||||||
if err := c.connect(ctx, addr, options...); err != nil {
|
if err := c.connect(ctx, addr, options...); err != nil {
|
||||||
@ -69,6 +82,40 @@ func New(ctx context.Context, config *ClientConfig) (*Client, error) {
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) dialOptions() []grpc.DialOption {
|
||||||
|
var options []grpc.DialOption
|
||||||
|
// Construct dial option.
|
||||||
|
if c.config.EnableTLSAuth {
|
||||||
|
options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})))
|
||||||
|
} else {
|
||||||
|
options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.DialOptions == nil {
|
||||||
|
// Add default connection options.
|
||||||
|
options = append(options, DefaultGrpcOpts...)
|
||||||
|
} else {
|
||||||
|
options = append(options, c.config.DialOptions...)
|
||||||
|
}
|
||||||
|
|
||||||
|
options = append(options,
|
||||||
|
grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor(
|
||||||
|
grpc_retry.WithMax(6),
|
||||||
|
grpc_retry.WithBackoff(func(attempt uint) time.Duration {
|
||||||
|
return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
|
||||||
|
}),
|
||||||
|
grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)),
|
||||||
|
|
||||||
|
// c.getRetryOnRateLimitInterceptor(),
|
||||||
|
))
|
||||||
|
|
||||||
|
options = append(options, grpc.WithChainUnaryInterceptor(
|
||||||
|
c.MetadataUnaryInterceptor(),
|
||||||
|
))
|
||||||
|
|
||||||
|
return options
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) Close(ctx context.Context) error {
|
func (c *Client) Close(ctx context.Context) error {
|
||||||
if c.conn == nil {
|
if c.conn == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -82,6 +129,18 @@ func (c *Client) Close(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) usingDatabase(dbName string) {
|
||||||
|
c.stateMut.Lock()
|
||||||
|
defer c.stateMut.Unlock()
|
||||||
|
c.currentDB = dbName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) setIdentifier(identifier string) {
|
||||||
|
c.stateMut.Lock()
|
||||||
|
defer c.stateMut.Unlock()
|
||||||
|
c.identifier = identifier
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) connect(ctx context.Context, addr string, options ...grpc.DialOption) error {
|
func (c *Client) connect(ctx context.Context, addr string, options ...grpc.DialOption) error {
|
||||||
if addr == "" {
|
if addr == "" {
|
||||||
return fmt.Errorf("address is empty")
|
return fmt.Errorf("address is empty")
|
||||||
@ -112,7 +171,7 @@ func (c *Client) connectInternal(ctx context.Context) error {
|
|||||||
|
|
||||||
req := &milvuspb.ConnectRequest{
|
req := &milvuspb.ConnectRequest{
|
||||||
ClientInfo: &commonpb.ClientInfo{
|
ClientInfo: &commonpb.ClientInfo{
|
||||||
SdkType: "Golang",
|
SdkType: "GoMilvusClient",
|
||||||
SdkVersion: common.SDKVersion,
|
SdkVersion: common.SDKVersion,
|
||||||
LocalTime: time.Now().String(),
|
LocalTime: time.Now().String(),
|
||||||
User: c.config.Username,
|
User: c.config.Username,
|
||||||
@ -131,9 +190,9 @@ func (c *Client) connectInternal(ctx context.Context) error {
|
|||||||
disableJSON |
|
disableJSON |
|
||||||
disableParitionKey |
|
disableParitionKey |
|
||||||
disableDynamicSchema)
|
disableDynamicSchema)
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,7 +201,7 @@ func (c *Client) connectInternal(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.config.setServerInfo(resp.GetServerInfo().GetBuildTags())
|
c.config.setServerInfo(resp.GetServerInfo().GetBuildTags())
|
||||||
c.config.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10))
|
c.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -10,12 +10,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
|
"github.com/milvus-io/milvus/pkg/util/crypto"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/backoff"
|
"google.golang.org/grpc/backoff"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/credentials"
|
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -59,16 +56,23 @@ type ClientConfig struct {
|
|||||||
|
|
||||||
DialOptions []grpc.DialOption // Dial options for GRPC.
|
DialOptions []grpc.DialOption // Dial options for GRPC.
|
||||||
|
|
||||||
// RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor
|
RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor
|
||||||
|
|
||||||
DisableConn bool
|
DisableConn bool
|
||||||
|
|
||||||
|
metadataHeaders map[string]string
|
||||||
|
|
||||||
identifier string // Identifier for this connection
|
identifier string // Identifier for this connection
|
||||||
ServerVersion string // ServerVersion
|
ServerVersion string // ServerVersion
|
||||||
parsedAddress *url.URL
|
parsedAddress *url.URL
|
||||||
flags uint64 // internal flags
|
flags uint64 // internal flags
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RetryRateLimitOption struct {
|
||||||
|
MaxRetry uint
|
||||||
|
MaxBackoff time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
func (cfg *ClientConfig) parse() error {
|
func (cfg *ClientConfig) parse() error {
|
||||||
// Prepend default fake tcp:// scheme for remote address.
|
// Prepend default fake tcp:// scheme for remote address.
|
||||||
address := cfg.Address
|
address := cfg.Address
|
||||||
@ -118,54 +122,36 @@ func (c *ClientConfig) setServerInfo(serverInfo string) {
|
|||||||
c.ServerVersion = serverInfo
|
c.ServerVersion = serverInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get parsed grpc dial options, should be called after parse was called.
|
// parseAuthentication prepares authentication headers for grpc inteceptors based on the provided username, password or API key.
|
||||||
func (c *ClientConfig) getDialOption() []grpc.DialOption {
|
func (c *ClientConfig) parseAuthentication() {
|
||||||
options := c.DialOptions
|
c.metadataHeaders = make(map[string]string)
|
||||||
if c.DialOptions == nil {
|
if c.Username != "" || c.Password != "" {
|
||||||
// Add default connection options.
|
value := crypto.Base64Encode(fmt.Sprintf("%s:%s", c.Username, c.Password))
|
||||||
options = make([]grpc.DialOption, len(DefaultGrpcOpts))
|
c.metadataHeaders[authorizationHeader] = value
|
||||||
copy(options, DefaultGrpcOpts)
|
|
||||||
}
|
}
|
||||||
|
// API overwrites username & passwd
|
||||||
// Construct dial option.
|
if c.APIKey != "" {
|
||||||
if c.EnableTLSAuth {
|
value := crypto.Base64Encode(c.APIKey)
|
||||||
options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})))
|
c.metadataHeaders[authorizationHeader] = value
|
||||||
} else {
|
|
||||||
options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
options = append(options,
|
|
||||||
grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor(
|
|
||||||
grpc_retry.WithMax(6),
|
|
||||||
grpc_retry.WithBackoff(func(attempt uint) time.Duration {
|
|
||||||
return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
|
|
||||||
}),
|
|
||||||
grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)),
|
|
||||||
// c.getRetryOnRateLimitInterceptor(),
|
|
||||||
))
|
|
||||||
|
|
||||||
// options = append(options, grpc.WithChainUnaryInterceptor(
|
|
||||||
// createMetaDataUnaryInterceptor(c),
|
|
||||||
// ))
|
|
||||||
return options
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor {
|
func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor {
|
||||||
// if c.RetryRateLimit == nil {
|
if c.RetryRateLimit == nil {
|
||||||
// c.RetryRateLimit = c.defaultRetryRateLimitOption()
|
c.RetryRateLimit = c.defaultRetryRateLimitOption()
|
||||||
// }
|
}
|
||||||
|
|
||||||
// return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration {
|
return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration {
|
||||||
// return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
|
return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
|
||||||
// })
|
})
|
||||||
// }
|
}
|
||||||
|
|
||||||
// func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption {
|
func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption {
|
||||||
// return &RetryRateLimitOption{
|
return &RetryRateLimitOption{
|
||||||
// MaxRetry: 75,
|
MaxRetry: 75,
|
||||||
// MaxBackoff: 3 * time.Second,
|
MaxBackoff: 3 * time.Second,
|
||||||
// }
|
}
|
||||||
// }
|
}
|
||||||
|
|
||||||
// addFlags set internal flags
|
// addFlags set internal flags
|
||||||
func (c *ClientConfig) addFlags(flags uint64) {
|
func (c *ClientConfig) addFlags(flags uint64) {
|
||||||
|
|||||||
@ -98,6 +98,7 @@ func (c *Client) DescribeCollection(ctx context.Context, option *describeCollect
|
|||||||
VirtualChannels: resp.GetVirtualChannelNames(),
|
VirtualChannels: resp.GetVirtualChannelNames(),
|
||||||
ConsistencyLevel: entity.ConsistencyLevel(resp.ConsistencyLevel),
|
ConsistencyLevel: entity.ConsistencyLevel(resp.ConsistencyLevel),
|
||||||
ShardNum: resp.GetShardsNum(),
|
ShardNum: resp.GetShardsNum(),
|
||||||
|
Properties: entity.KvPairsMap(resp.GetProperties()),
|
||||||
}
|
}
|
||||||
collection.Name = collection.Schema.CollectionName
|
collection.Name = collection.Schema.CollectionName
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -140,6 +140,7 @@ func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOpti
|
|||||||
autoID: true,
|
autoID: true,
|
||||||
dim: dim,
|
dim: dim,
|
||||||
enabledDynamicSchema: true,
|
enabledDynamicSchema: true,
|
||||||
|
consistencyLevel: entity.DefaultConsistencyLevel,
|
||||||
|
|
||||||
isFast: true,
|
isFast: true,
|
||||||
metricType: entity.COSINE,
|
metricType: entity.COSINE,
|
||||||
@ -152,6 +153,7 @@ func NewCreateCollectionOption(name string, collectionSchema *entity.Schema) *cr
|
|||||||
name: name,
|
name: name,
|
||||||
shardNum: 1,
|
shardNum: 1,
|
||||||
schema: collectionSchema,
|
schema: collectionSchema,
|
||||||
|
consistencyLevel: entity.DefaultConsistencyLevel,
|
||||||
|
|
||||||
metricType: entity.COSINE,
|
metricType: entity.COSINE,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -64,26 +64,38 @@ var errFieldDataTypeNotMatch = errors.New("FieldData type not matched")
|
|||||||
|
|
||||||
// IDColumns converts schemapb.IDs to corresponding column
|
// IDColumns converts schemapb.IDs to corresponding column
|
||||||
// currently Int64 / string may be in IDs
|
// currently Int64 / string may be in IDs
|
||||||
func IDColumns(idField *schemapb.IDs, begin, end int) (Column, error) {
|
func IDColumns(schema *entity.Schema, ids *schemapb.IDs, begin, end int) (Column, error) {
|
||||||
var idColumn Column
|
var idColumn Column
|
||||||
if idField == nil {
|
pkField := schema.PKField()
|
||||||
|
if pkField == nil {
|
||||||
|
return nil, errors.New("PK Field not found")
|
||||||
|
}
|
||||||
|
if ids == nil {
|
||||||
return nil, errors.New("nil Ids from response")
|
return nil, errors.New("nil Ids from response")
|
||||||
}
|
}
|
||||||
switch field := idField.GetIdField().(type) {
|
switch pkField.DataType {
|
||||||
case *schemapb.IDs_IntId:
|
case entity.FieldTypeInt64:
|
||||||
if end >= 0 {
|
data := ids.GetIntId().GetData()
|
||||||
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:end])
|
if data == nil {
|
||||||
} else {
|
return NewColumnInt64(pkField.Name, nil), nil
|
||||||
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:])
|
|
||||||
}
|
}
|
||||||
case *schemapb.IDs_StrId:
|
|
||||||
if end >= 0 {
|
if end >= 0 {
|
||||||
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:end])
|
idColumn = NewColumnInt64(pkField.Name, data[begin:end])
|
||||||
} else {
|
} else {
|
||||||
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:])
|
idColumn = NewColumnInt64(pkField.Name, data[begin:])
|
||||||
|
}
|
||||||
|
case entity.FieldTypeVarChar, entity.FieldTypeString:
|
||||||
|
data := ids.GetStrId().GetData()
|
||||||
|
if data == nil {
|
||||||
|
return NewColumnVarChar(pkField.Name, nil), nil
|
||||||
|
}
|
||||||
|
if end >= 0 {
|
||||||
|
idColumn = NewColumnVarChar(pkField.Name, data[begin:end])
|
||||||
|
} else {
|
||||||
|
idColumn = NewColumnVarChar(pkField.Name, data[begin:])
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported id type %v", field)
|
return nil, fmt.Errorf("unsupported id type %v", pkField.DataType)
|
||||||
}
|
}
|
||||||
return idColumn, nil
|
return idColumn, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -24,18 +24,34 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/client/v2/entity"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIDColumns(t *testing.T) {
|
func TestIDColumns(t *testing.T) {
|
||||||
dataLen := rand.Intn(100) + 1
|
dataLen := rand.Intn(100) + 1
|
||||||
base := rand.Intn(5000) // id start point
|
base := rand.Intn(5000) // id start point
|
||||||
|
|
||||||
|
intPKCol := entity.NewSchema().WithField(
|
||||||
|
entity.NewField().WithName("pk").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64),
|
||||||
|
)
|
||||||
|
strPKCol := entity.NewSchema().WithField(
|
||||||
|
entity.NewField().WithName("pk").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeVarChar),
|
||||||
|
)
|
||||||
|
|
||||||
t.Run("nil id", func(t *testing.T) {
|
t.Run("nil id", func(t *testing.T) {
|
||||||
_, err := IDColumns(nil, 0, -1)
|
col, err := IDColumns(intPKCol, nil, 0, -1)
|
||||||
assert.NotNil(t, err)
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 0, col.Len())
|
||||||
|
col, err = IDColumns(strPKCol, nil, 0, -1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 0, col.Len())
|
||||||
idField := &schemapb.IDs{}
|
idField := &schemapb.IDs{}
|
||||||
_, err = IDColumns(idField, 0, -1)
|
col, err = IDColumns(intPKCol, idField, 0, -1)
|
||||||
assert.NotNil(t, err)
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 0, col.Len())
|
||||||
|
col, err = IDColumns(strPKCol, idField, 0, -1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 0, col.Len())
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("int ids", func(t *testing.T) {
|
t.Run("int ids", func(t *testing.T) {
|
||||||
@ -50,12 +66,12 @@ func TestIDColumns(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
column, err := IDColumns(idField, 0, dataLen)
|
column, err := IDColumns(intPKCol, idField, 0, dataLen)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.NotNil(t, column)
|
assert.NotNil(t, column)
|
||||||
assert.Equal(t, dataLen, column.Len())
|
assert.Equal(t, dataLen, column.Len())
|
||||||
|
|
||||||
column, err = IDColumns(idField, 0, -1) // test -1 method
|
column, err = IDColumns(intPKCol, idField, 0, -1) // test -1 method
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.NotNil(t, column)
|
assert.NotNil(t, column)
|
||||||
assert.Equal(t, dataLen, column.Len())
|
assert.Equal(t, dataLen, column.Len())
|
||||||
@ -72,12 +88,12 @@ func TestIDColumns(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
column, err := IDColumns(idField, 0, dataLen)
|
column, err := IDColumns(strPKCol, idField, 0, dataLen)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.NotNil(t, column)
|
assert.NotNil(t, column)
|
||||||
assert.Equal(t, dataLen, column.Len())
|
assert.Equal(t, dataLen, column.Len())
|
||||||
|
|
||||||
column, err = IDColumns(idField, 0, -1) // test -1 method
|
column, err = IDColumns(strPKCol, idField, 0, -1) // test -1 method
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.NotNil(t, column)
|
assert.NotNil(t, column)
|
||||||
assert.Equal(t, dataLen, column.Len())
|
assert.Equal(t, dataLen, column.Len())
|
||||||
|
|||||||
@ -25,6 +25,12 @@ import (
|
|||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (c *Client) UsingDatabase(ctx context.Context, option UsingDatabaseOption) error {
|
||||||
|
dbName := option.DbName()
|
||||||
|
c.usingDatabase(dbName)
|
||||||
|
return c.connectInternal(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) ListDatabase(ctx context.Context, option ListDatabaseOption, callOptions ...grpc.CallOption) (databaseNames []string, err error) {
|
func (c *Client) ListDatabase(ctx context.Context, option ListDatabaseOption, callOptions ...grpc.CallOption) (databaseNames []string, err error) {
|
||||||
req := option.Request()
|
req := option.Request()
|
||||||
|
|
||||||
|
|||||||
@ -18,6 +18,24 @@ package client
|
|||||||
|
|
||||||
import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
|
|
||||||
|
type UsingDatabaseOption interface {
|
||||||
|
DbName() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type usingDatabaseNameOpt struct {
|
||||||
|
dbName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (opt *usingDatabaseNameOpt) DbName() string {
|
||||||
|
return opt.dbName
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUsingDatabaseOption(dbName string) *usingDatabaseNameOpt {
|
||||||
|
return &usingDatabaseNameOpt{
|
||||||
|
dbName: dbName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ListDatabaseOption is a builder interface for ListDatabase request.
|
// ListDatabaseOption is a builder interface for ListDatabase request.
|
||||||
type ListDatabaseOption interface {
|
type ListDatabaseOption interface {
|
||||||
Request() *milvuspb.ListDatabasesRequest
|
Request() *milvuspb.ListDatabasesRequest
|
||||||
|
|||||||
@ -32,6 +32,7 @@ type Collection struct {
|
|||||||
Loaded bool
|
Loaded bool
|
||||||
ConsistencyLevel ConsistencyLevel
|
ConsistencyLevel ConsistencyLevel
|
||||||
ShardNum int32
|
ShardNum int32
|
||||||
|
Properties map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Partition represent partition meta in Milvus
|
// Partition represent partition meta in Milvus
|
||||||
|
|||||||
@ -60,6 +60,8 @@ type Schema struct {
|
|||||||
AutoID bool
|
AutoID bool
|
||||||
Fields []*Field
|
Fields []*Field
|
||||||
EnableDynamicField bool
|
EnableDynamicField bool
|
||||||
|
|
||||||
|
pkField *Field
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSchema creates an empty schema object.
|
// NewSchema creates an empty schema object.
|
||||||
@ -91,6 +93,9 @@ func (s *Schema) WithDynamicFieldEnabled(dynamicEnabled bool) *Schema {
|
|||||||
|
|
||||||
// WithField adds a field into schema and returns schema itself.
|
// WithField adds a field into schema and returns schema itself.
|
||||||
func (s *Schema) WithField(f *Field) *Schema {
|
func (s *Schema) WithField(f *Field) *Schema {
|
||||||
|
if f.PrimaryKey {
|
||||||
|
s.pkField = f
|
||||||
|
}
|
||||||
s.Fields = append(s.Fields, f)
|
s.Fields = append(s.Fields, f)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
@ -116,10 +121,14 @@ func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema {
|
|||||||
s.CollectionName = p.GetName()
|
s.CollectionName = p.GetName()
|
||||||
s.Fields = make([]*Field, 0, len(p.GetFields()))
|
s.Fields = make([]*Field, 0, len(p.GetFields()))
|
||||||
for _, fp := range p.GetFields() {
|
for _, fp := range p.GetFields() {
|
||||||
|
field := NewField().ReadProto(fp)
|
||||||
if fp.GetAutoID() {
|
if fp.GetAutoID() {
|
||||||
s.AutoID = true
|
s.AutoID = true
|
||||||
}
|
}
|
||||||
s.Fields = append(s.Fields, NewField().ReadProto(fp))
|
if field.PrimaryKey {
|
||||||
|
s.pkField = field
|
||||||
|
}
|
||||||
|
s.Fields = append(s.Fields, field)
|
||||||
}
|
}
|
||||||
s.EnableDynamicField = p.GetEnableDynamicField()
|
s.EnableDynamicField = p.GetEnableDynamicField()
|
||||||
return s
|
return s
|
||||||
@ -127,12 +136,15 @@ func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema {
|
|||||||
|
|
||||||
// PKFieldName returns pk field name for this schemapb.
|
// PKFieldName returns pk field name for this schemapb.
|
||||||
func (s *Schema) PKFieldName() string {
|
func (s *Schema) PKFieldName() string {
|
||||||
for _, field := range s.Fields {
|
if s.pkField == nil {
|
||||||
if field.PrimaryKey {
|
|
||||||
return field.Name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
return ""
|
||||||
|
}
|
||||||
|
return s.pkField.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// PKField returns PK Field schema for this schema.
|
||||||
|
func (s *Schema) PKField() *Field {
|
||||||
|
return s.pkField
|
||||||
}
|
}
|
||||||
|
|
||||||
// Field represent field schema in milvus
|
// Field represent field schema in milvus
|
||||||
|
|||||||
@ -17,6 +17,8 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
"github.com/milvus-io/milvus/client/v2/entity"
|
"github.com/milvus-io/milvus/client/v2/entity"
|
||||||
"github.com/milvus-io/milvus/client/v2/index"
|
"github.com/milvus-io/milvus/client/v2/index"
|
||||||
@ -31,15 +33,27 @@ type createIndexOption struct {
|
|||||||
fieldName string
|
fieldName string
|
||||||
indexName string
|
indexName string
|
||||||
indexDef index.Index
|
indexDef index.Index
|
||||||
|
|
||||||
|
extraParams map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (opt *createIndexOption) WithExtraParam(key string, value any) {
|
||||||
|
opt.extraParams[key] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (opt *createIndexOption) Request() *milvuspb.CreateIndexRequest {
|
func (opt *createIndexOption) Request() *milvuspb.CreateIndexRequest {
|
||||||
return &milvuspb.CreateIndexRequest{
|
params := opt.indexDef.Params()
|
||||||
|
for key, value := range opt.extraParams {
|
||||||
|
params[key] = fmt.Sprintf("%v", value)
|
||||||
|
}
|
||||||
|
req := &milvuspb.CreateIndexRequest{
|
||||||
CollectionName: opt.collectionName,
|
CollectionName: opt.collectionName,
|
||||||
FieldName: opt.fieldName,
|
FieldName: opt.fieldName,
|
||||||
IndexName: opt.indexName,
|
IndexName: opt.indexName,
|
||||||
ExtraParams: entity.MapKvPairs(opt.indexDef.Params()),
|
ExtraParams: entity.MapKvPairs(params),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return req
|
||||||
}
|
}
|
||||||
|
|
||||||
func (opt *createIndexOption) WithIndexName(indexName string) *createIndexOption {
|
func (opt *createIndexOption) WithIndexName(indexName string) *createIndexOption {
|
||||||
@ -52,6 +66,7 @@ func NewCreateIndexOption(collectionName string, fieldName string, index index.I
|
|||||||
collectionName: collectionName,
|
collectionName: collectionName,
|
||||||
fieldName: fieldName,
|
fieldName: fieldName,
|
||||||
indexDef: index,
|
indexDef: index,
|
||||||
|
extraParams: make(map[string]any),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
159
client/interceptors.go
Normal file
159
client/interceptors.go
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
// Licensed to the LF AI & Data foundation under one
|
||||||
|
// or more contributor license agreements. See the NOTICE file
|
||||||
|
// distributed with this work for additional information
|
||||||
|
// regarding copyright ownership. The ASF licenses this file
|
||||||
|
// to you under the Apache License, Version 2.0 (the
|
||||||
|
// "License"); you may not use this file except in compliance
|
||||||
|
// with the License. You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
authorizationHeader = `authorization`
|
||||||
|
|
||||||
|
identifierHeader = `identifier`
|
||||||
|
|
||||||
|
databaseHeader = `dbname`
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *Client) MetadataUnaryInterceptor() grpc.UnaryClientInterceptor {
|
||||||
|
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||||
|
ctx = c.metadata(ctx)
|
||||||
|
ctx = c.state(ctx)
|
||||||
|
|
||||||
|
return invoker(ctx, method, req, reply, cc, opts...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) metadata(ctx context.Context) context.Context {
|
||||||
|
for k, v := range c.config.metadataHeaders {
|
||||||
|
ctx = metadata.AppendToOutgoingContext(ctx, k, v)
|
||||||
|
}
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) state(ctx context.Context) context.Context {
|
||||||
|
c.stateMut.RLock()
|
||||||
|
defer c.stateMut.RUnlock()
|
||||||
|
|
||||||
|
if c.currentDB != "" {
|
||||||
|
ctx = metadata.AppendToOutgoingContext(ctx, databaseHeader, c.currentDB)
|
||||||
|
}
|
||||||
|
if c.identifier != "" {
|
||||||
|
ctx = metadata.AppendToOutgoingContext(ctx, identifierHeader, c.identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
// ref: https://github.com/grpc-ecosystem/go-grpc-middleware
|
||||||
|
|
||||||
|
type ctxKey int
|
||||||
|
|
||||||
|
const (
|
||||||
|
RetryOnRateLimit ctxKey = iota
|
||||||
|
)
|
||||||
|
|
||||||
|
// RetryOnRateLimitInterceptor returns a new retrying unary client interceptor.
|
||||||
|
func RetryOnRateLimitInterceptor(maxRetry uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) grpc.UnaryClientInterceptor {
|
||||||
|
return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||||
|
if maxRetry == 0 {
|
||||||
|
return invoker(parentCtx, method, req, reply, cc, opts...)
|
||||||
|
}
|
||||||
|
var lastErr error
|
||||||
|
for attempt := uint(0); attempt < maxRetry; attempt++ {
|
||||||
|
_, err := waitRetryBackoff(parentCtx, attempt, maxBackoff, backoffFunc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
lastErr = invoker(parentCtx, method, req, reply, cc, opts...)
|
||||||
|
rspStatus := getResultStatus(reply)
|
||||||
|
if retryOnRateLimit(parentCtx) && rspStatus.GetErrorCode() == commonpb.ErrorCode_RateLimit {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func retryOnRateLimit(ctx context.Context) bool {
|
||||||
|
retry, ok := ctx.Value(RetryOnRateLimit).(bool)
|
||||||
|
if !ok {
|
||||||
|
return true // default true
|
||||||
|
}
|
||||||
|
return retry
|
||||||
|
}
|
||||||
|
|
||||||
|
// getResultStatus returns status of response.
|
||||||
|
func getResultStatus(reply interface{}) *commonpb.Status {
|
||||||
|
switch r := reply.(type) {
|
||||||
|
case *commonpb.Status:
|
||||||
|
return r
|
||||||
|
case *milvuspb.MutationResult:
|
||||||
|
return r.GetStatus()
|
||||||
|
case *milvuspb.BoolResponse:
|
||||||
|
return r.GetStatus()
|
||||||
|
case *milvuspb.SearchResults:
|
||||||
|
return r.GetStatus()
|
||||||
|
case *milvuspb.QueryResults:
|
||||||
|
return r.GetStatus()
|
||||||
|
case *milvuspb.FlushResponse:
|
||||||
|
return r.GetStatus()
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextErrToGrpcErr(err error) error {
|
||||||
|
switch err {
|
||||||
|
case context.DeadlineExceeded:
|
||||||
|
return status.Error(codes.DeadlineExceeded, err.Error())
|
||||||
|
case context.Canceled:
|
||||||
|
return status.Error(codes.Canceled, err.Error())
|
||||||
|
default:
|
||||||
|
return status.Error(codes.Unknown, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitRetryBackoff(parentCtx context.Context, attempt uint, maxBackoff time.Duration, backoffFunc grpc_retry.BackoffFuncContext) (time.Duration, error) {
|
||||||
|
var waitTime time.Duration
|
||||||
|
if attempt > 0 {
|
||||||
|
waitTime = backoffFunc(parentCtx, attempt)
|
||||||
|
}
|
||||||
|
if waitTime > 0 {
|
||||||
|
if waitTime > maxBackoff {
|
||||||
|
waitTime = maxBackoff
|
||||||
|
}
|
||||||
|
timer := time.NewTimer(waitTime)
|
||||||
|
select {
|
||||||
|
case <-parentCtx.Done():
|
||||||
|
timer.Stop()
|
||||||
|
return waitTime, contextErrToGrpcErr(parentCtx.Err())
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return waitTime, nil
|
||||||
|
}
|
||||||
66
client/interceptors_test.go
Normal file
66
client/interceptors_test.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
// Licensed to the LF AI & Data foundation under one
|
||||||
|
// or more contributor license agreements. See the NOTICE file
|
||||||
|
// distributed with this work for additional information
|
||||||
|
// regarding copyright ownership. The ASF licenses this file
|
||||||
|
// to you under the Apache License, Version 2.0 (the
|
||||||
|
// "License"); you may not use this file except in compliance
|
||||||
|
// with the License. You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
|
)
|
||||||
|
|
||||||
|
var mockInvokerError error
|
||||||
|
var mockInvokerReply interface{}
|
||||||
|
var mockInvokeTimes = 0
|
||||||
|
|
||||||
|
var mockInvoker grpc.UnaryInvoker = func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
|
||||||
|
mockInvokeTimes++
|
||||||
|
return mockInvokerError
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetMockInvokeTimes() {
|
||||||
|
mockInvokeTimes = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimitInterceptor(t *testing.T) {
|
||||||
|
maxRetry := uint(3)
|
||||||
|
maxBackoff := 3 * time.Second
|
||||||
|
inter := RetryOnRateLimitInterceptor(maxRetry, maxBackoff, func(ctx context.Context, attempt uint) time.Duration {
|
||||||
|
return 60 * time.Millisecond * time.Duration(math.Pow(2, float64(attempt)))
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// with retry
|
||||||
|
mockInvokerReply = &commonpb.Status{ErrorCode: commonpb.ErrorCode_RateLimit}
|
||||||
|
resetMockInvokeTimes()
|
||||||
|
err := inter(ctx, "", nil, mockInvokerReply, nil, mockInvoker)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, maxRetry, uint(mockInvokeTimes))
|
||||||
|
|
||||||
|
// without retry
|
||||||
|
ctx1 := context.WithValue(ctx, RetryOnRateLimit, false)
|
||||||
|
resetMockInvokeTimes()
|
||||||
|
err = inter(ctx1, "", nil, mockInvokerReply, nil, mockInvoker)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, uint(1), uint(mockInvokeTimes))
|
||||||
|
}
|
||||||
@ -33,7 +33,7 @@ type ResultSets struct{}
|
|||||||
|
|
||||||
type ResultSet struct {
|
type ResultSet struct {
|
||||||
ResultCount int // the returning entry count
|
ResultCount int // the returning entry count
|
||||||
GroupByValue any
|
GroupByValue column.Column
|
||||||
IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API
|
IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API
|
||||||
Fields DataSet // output field data
|
Fields DataSet // output field data
|
||||||
Scores []float32 // distance to the target vector
|
Scores []float32 // distance to the target vector
|
||||||
@ -67,35 +67,32 @@ func (c *Client) Search(ctx context.Context, option SearchOption, callOptions ..
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) handleSearchResult(schema *entity.Schema, outputFields []string, nq int, resp *milvuspb.SearchResults) ([]ResultSet, error) {
|
func (c *Client) handleSearchResult(schema *entity.Schema, outputFields []string, nq int, resp *milvuspb.SearchResults) ([]ResultSet, error) {
|
||||||
var err error
|
|
||||||
sr := make([]ResultSet, 0, nq)
|
sr := make([]ResultSet, 0, nq)
|
||||||
results := resp.GetResults()
|
results := resp.GetResults()
|
||||||
offset := 0
|
offset := 0
|
||||||
fieldDataList := results.GetFieldsData()
|
fieldDataList := results.GetFieldsData()
|
||||||
gb := results.GetGroupByFieldValue()
|
gb := results.GetGroupByFieldValue()
|
||||||
var gbc column.Column
|
|
||||||
if gb != nil {
|
|
||||||
gbc, err = column.FieldDataColumn(gb, 0, -1)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i := 0; i < int(results.GetNumQueries()); i++ {
|
for i := 0; i < int(results.GetNumQueries()); i++ {
|
||||||
rc := int(results.GetTopks()[i]) // result entry count for current query
|
rc := int(results.GetTopks()[i]) // result entry count for current query
|
||||||
entry := ResultSet{
|
entry := ResultSet{
|
||||||
ResultCount: rc,
|
ResultCount: rc,
|
||||||
Scores: results.GetScores()[offset : offset+rc],
|
Scores: results.GetScores()[offset : offset+rc],
|
||||||
}
|
}
|
||||||
if gbc != nil {
|
|
||||||
entry.GroupByValue, _ = gbc.Get(i)
|
|
||||||
}
|
|
||||||
// parse result set if current nq is not empty
|
// parse result set if current nq is not empty
|
||||||
if rc > 0 {
|
if rc > 0 {
|
||||||
entry.IDs, entry.Err = column.IDColumns(results.GetIds(), offset, offset+rc)
|
entry.IDs, entry.Err = column.IDColumns(schema, results.GetIds(), offset, offset+rc)
|
||||||
if entry.Err != nil {
|
if entry.Err != nil {
|
||||||
offset += rc
|
offset += rc
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// parse group-by values
|
||||||
|
if gb != nil {
|
||||||
|
entry.GroupByValue, entry.Err = column.FieldDataColumn(gb, offset, offset+rc)
|
||||||
|
if entry.Err != nil {
|
||||||
|
offset += rc
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc)
|
entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc)
|
||||||
sr = append(sr, entry)
|
sr = append(sr, entry)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -87,7 +87,7 @@ func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.
|
|||||||
|
|
||||||
// search param
|
// search param
|
||||||
bs, _ := json.Marshal(annRequest.searchParam)
|
bs, _ := json.Marshal(annRequest.searchParam)
|
||||||
request.SearchParams = entity.MapKvPairs(map[string]string{
|
params := map[string]string{
|
||||||
spAnnsField: annRequest.annField,
|
spAnnsField: annRequest.annField,
|
||||||
spTopK: strconv.Itoa(opt.topK),
|
spTopK: strconv.Itoa(opt.topK),
|
||||||
spOffset: strconv.Itoa(opt.offset),
|
spOffset: strconv.Itoa(opt.offset),
|
||||||
@ -95,8 +95,11 @@ func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.
|
|||||||
spMetricsType: string(annRequest.metricsType),
|
spMetricsType: string(annRequest.metricsType),
|
||||||
spRoundDecimal: "-1",
|
spRoundDecimal: "-1",
|
||||||
spIgnoreGrowing: strconv.FormatBool(opt.ignoreGrowing),
|
spIgnoreGrowing: strconv.FormatBool(opt.ignoreGrowing),
|
||||||
spGroupBy: annRequest.groupByField,
|
}
|
||||||
})
|
if annRequest.groupByField != "" {
|
||||||
|
params[spGroupBy] = annRequest.groupByField
|
||||||
|
}
|
||||||
|
request.SearchParams = entity.MapKvPairs(params)
|
||||||
|
|
||||||
// placeholder group
|
// placeholder group
|
||||||
request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors)
|
request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors)
|
||||||
|
|||||||
@ -22,53 +22,90 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
|
"github.com/milvus-io/milvus/client/v2/column"
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ...grpc.CallOption) error {
|
type InsertResult struct {
|
||||||
|
InsertCount int64
|
||||||
|
IDs column.Column
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ...grpc.CallOption) (InsertResult, error) {
|
||||||
|
result := InsertResult{}
|
||||||
collection, err := c.getCollection(ctx, option.CollectionName())
|
collection, err := c.getCollection(ctx, option.CollectionName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return result, err
|
||||||
}
|
}
|
||||||
req, err := option.InsertRequest(collection)
|
req, err := option.InsertRequest(collection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||||
resp, err := milvusService.Insert(ctx, req, callOptions...)
|
resp, err := milvusService.Insert(ctx, req, callOptions...)
|
||||||
|
|
||||||
return merr.CheckRPCCall(resp, err)
|
err = merr.CheckRPCCall(resp, err)
|
||||||
})
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
result.InsertCount = resp.GetInsertCnt()
|
||||||
|
result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Delete(ctx context.Context, option DeleteOption, callOptions ...grpc.CallOption) error {
|
type DeleteResult struct {
|
||||||
|
DeleteCount int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Delete(ctx context.Context, option DeleteOption, callOptions ...grpc.CallOption) (DeleteResult, error) {
|
||||||
req := option.Request()
|
req := option.Request()
|
||||||
|
|
||||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
result := DeleteResult{}
|
||||||
|
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||||
resp, err := milvusService.Delete(ctx, req, callOptions...)
|
resp, err := milvusService.Delete(ctx, req, callOptions...)
|
||||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
result.DeleteCount = resp.GetDeleteCnt()
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Upsert(ctx context.Context, option UpsertOption, callOptions ...grpc.CallOption) error {
|
type UpsertResult struct {
|
||||||
|
UpsertCount int64
|
||||||
|
IDs column.Column
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Upsert(ctx context.Context, option UpsertOption, callOptions ...grpc.CallOption) (UpsertResult, error) {
|
||||||
|
result := UpsertResult{}
|
||||||
collection, err := c.getCollection(ctx, option.CollectionName())
|
collection, err := c.getCollection(ctx, option.CollectionName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return result, err
|
||||||
}
|
}
|
||||||
req, err := option.UpsertRequest(collection)
|
req, err := option.UpsertRequest(collection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return result, err
|
||||||
}
|
}
|
||||||
|
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
|
||||||
resp, err := milvusService.Upsert(ctx, req, callOptions...)
|
resp, err := milvusService.Upsert(ctx, req, callOptions...)
|
||||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
result.UpsertCount = resp.GetUpsertCnt()
|
||||||
|
result.IDs, err = column.IDColumns(collection.Schema, resp.GetIDs(), 0, -1)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
return result, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,6 +23,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/client/v2/entity"
|
"github.com/milvus-io/milvus/client/v2/entity"
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
@ -64,15 +65,24 @@ func (s *WriteSuite) TestInsert() {
|
|||||||
s.EqualValues(3, ir.GetNumRows())
|
s.EqualValues(3, ir.GetNumRows())
|
||||||
return &milvuspb.MutationResult{
|
return &milvuspb.MutationResult{
|
||||||
Status: merr.Success(),
|
Status: merr.Success(),
|
||||||
|
InsertCnt: 3,
|
||||||
|
IDs: &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: []int64{1, 2, 3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}).Once()
|
}).Once()
|
||||||
|
|
||||||
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
||||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||||
})).
|
})).
|
||||||
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
|
s.EqualValues(3, result.InsertCount)
|
||||||
})
|
})
|
||||||
|
|
||||||
s.Run("dynamic_schema", func() {
|
s.Run("dynamic_schema", func() {
|
||||||
@ -87,16 +97,25 @@ func (s *WriteSuite) TestInsert() {
|
|||||||
s.EqualValues(3, ir.GetNumRows())
|
s.EqualValues(3, ir.GetNumRows())
|
||||||
return &milvuspb.MutationResult{
|
return &milvuspb.MutationResult{
|
||||||
Status: merr.Success(),
|
Status: merr.Success(),
|
||||||
|
InsertCnt: 3,
|
||||||
|
IDs: &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: []int64{1, 2, 3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}).Once()
|
}).Once()
|
||||||
|
|
||||||
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
result, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
||||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||||
})).
|
})).
|
||||||
WithVarcharColumn("extra", []string{"a", "b", "c"}).
|
WithVarcharColumn("extra", []string{"a", "b", "c"}).
|
||||||
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
|
s.EqualValues(3, result.InsertCount)
|
||||||
})
|
})
|
||||||
|
|
||||||
s.Run("bad_input", func() {
|
s.Run("bad_input", func() {
|
||||||
@ -141,7 +160,7 @@ func (s *WriteSuite) TestInsert() {
|
|||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
s.Run(tc.tag, func() {
|
s.Run(tc.tag, func() {
|
||||||
err := s.client.Insert(ctx, tc.input)
|
_, err := s.client.Insert(ctx, tc.input)
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -153,7 +172,7 @@ func (s *WriteSuite) TestInsert() {
|
|||||||
|
|
||||||
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||||
|
|
||||||
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
_, err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
||||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||||
})).
|
})).
|
||||||
@ -178,15 +197,24 @@ func (s *WriteSuite) TestUpsert() {
|
|||||||
s.EqualValues(3, ur.GetNumRows())
|
s.EqualValues(3, ur.GetNumRows())
|
||||||
return &milvuspb.MutationResult{
|
return &milvuspb.MutationResult{
|
||||||
Status: merr.Success(),
|
Status: merr.Success(),
|
||||||
|
UpsertCnt: 3,
|
||||||
|
IDs: &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: []int64{1, 2, 3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}).Once()
|
}).Once()
|
||||||
|
|
||||||
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
||||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||||
})).
|
})).
|
||||||
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
|
s.EqualValues(3, result.UpsertCount)
|
||||||
})
|
})
|
||||||
|
|
||||||
s.Run("dynamic_schema", func() {
|
s.Run("dynamic_schema", func() {
|
||||||
@ -201,16 +229,25 @@ func (s *WriteSuite) TestUpsert() {
|
|||||||
s.EqualValues(3, ur.GetNumRows())
|
s.EqualValues(3, ur.GetNumRows())
|
||||||
return &milvuspb.MutationResult{
|
return &milvuspb.MutationResult{
|
||||||
Status: merr.Success(),
|
Status: merr.Success(),
|
||||||
|
UpsertCnt: 3,
|
||||||
|
IDs: &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: []int64{1, 2, 3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}).Once()
|
}).Once()
|
||||||
|
|
||||||
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
result, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
||||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||||
})).
|
})).
|
||||||
WithVarcharColumn("extra", []string{"a", "b", "c"}).
|
WithVarcharColumn("extra", []string{"a", "b", "c"}).
|
||||||
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
|
s.EqualValues(3, result.UpsertCount)
|
||||||
})
|
})
|
||||||
|
|
||||||
s.Run("bad_input", func() {
|
s.Run("bad_input", func() {
|
||||||
@ -255,7 +292,7 @@ func (s *WriteSuite) TestUpsert() {
|
|||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
s.Run(tc.tag, func() {
|
s.Run(tc.tag, func() {
|
||||||
err := s.client.Upsert(ctx, tc.input)
|
_, err := s.client.Upsert(ctx, tc.input)
|
||||||
s.Error(err)
|
s.Error(err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -267,7 +304,7 @@ func (s *WriteSuite) TestUpsert() {
|
|||||||
|
|
||||||
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||||
|
|
||||||
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
_, err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
||||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||||
})).
|
})).
|
||||||
@ -316,10 +353,12 @@ func (s *WriteSuite) TestDelete() {
|
|||||||
s.Equal(tc.expectExpr, dr.GetExpr())
|
s.Equal(tc.expectExpr, dr.GetExpr())
|
||||||
return &milvuspb.MutationResult{
|
return &milvuspb.MutationResult{
|
||||||
Status: merr.Success(),
|
Status: merr.Success(),
|
||||||
|
DeleteCnt: 100,
|
||||||
}, nil
|
}, nil
|
||||||
}).Once()
|
}).Once()
|
||||||
err := s.client.Delete(ctx, tc.input)
|
result, err := s.client.Delete(ctx, tc.input)
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
|
s.EqualValues(100, result.DeleteCount)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user