feat: Add milvusclient package and migrate GoSDK (#32907)

Related to #31293

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
congqixia 2024-05-10 18:01:30 +08:00 committed by GitHub
parent 855192eb3d
commit 244d2c04f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
73 changed files with 16840 additions and 0 deletions

View File

@ -7,6 +7,7 @@ on:
paths:
- 'scripts/**'
- 'internal/**'
- 'client/**'
- 'pkg/**'
- 'cmd/**'
- 'build/**'
@ -24,6 +25,7 @@ on:
- 'scripts/**'
- 'internal/**'
- 'pkg/**'
- 'client/**'
- 'cmd/**'
- 'build/**'
- 'tests/integration/**' # run integration test

33
client/Makefile Normal file
View File

@ -0,0 +1,33 @@
# Licensed to the LF AI & Data foundation under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
GO ?= go
PWD := $(shell pwd)
GOPATH := $(shell $(GO) env GOPATH)
SHELL := /bin/bash
OBJPREFIX := "github.com/milvus-io/milvus/cmd/milvus/v2"
# TODO pass golangci-lint path
lint:
@echo "Running lint checks..."
unittest:
@echo "Running unittests..."
@(env bash $(PWD)/scripts/run_unittest.sh)
generate-mockery:
@echo "Generating mockery Milvus service server"
@../bin/mockery --srcpkg=github.com/milvus-io/milvus-proto/go-api/v2/milvuspb --name=MilvusServiceServer --filename=mock_milvus_server_test.go --output=. --outpkg=client --with-expecter

6
client/OWNERS Normal file
View File

@ -0,0 +1,6 @@
reviewers:
- congqixia
approvers:
- maintainers

157
client/client.go Normal file
View File

@ -0,0 +1,157 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"fmt"
"os"
"strconv"
"time"
"github.com/gogo/status"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/common"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/util/merr"
)
type Client struct {
conn *grpc.ClientConn
service milvuspb.MilvusServiceClient
config *ClientConfig
collCache *CollectionCache
}
func New(ctx context.Context, config *ClientConfig) (*Client, error) {
if err := config.parse(); err != nil {
return nil, err
}
c := &Client{
config: config,
}
// Parse remote address.
addr := c.config.getParsedAddress()
// Parse grpc options
options := c.config.getDialOption()
// Connect the grpc server.
if err := c.connect(ctx, addr, options...); err != nil {
return nil, err
}
c.collCache = NewCollectionCache(func(ctx context.Context, collName string) (*entity.Collection, error) {
return c.DescribeCollection(ctx, NewDescribeCollectionOption(collName))
})
return c, nil
}
func (c *Client) Close(ctx context.Context) error {
if c.conn == nil {
return nil
}
err := c.conn.Close()
if err != nil {
return err
}
c.conn = nil
c.service = nil
return nil
}
func (c *Client) connect(ctx context.Context, addr string, options ...grpc.DialOption) error {
if addr == "" {
return fmt.Errorf("address is empty")
}
conn, err := grpc.DialContext(ctx, addr, options...)
if err != nil {
return err
}
c.conn = conn
c.service = milvuspb.NewMilvusServiceClient(c.conn)
if !c.config.DisableConn {
err = c.connectInternal(ctx)
if err != nil {
return err
}
}
return nil
}
func (c *Client) connectInternal(ctx context.Context) error {
hostName, err := os.Hostname()
if err != nil {
return err
}
req := &milvuspb.ConnectRequest{
ClientInfo: &commonpb.ClientInfo{
SdkType: "Golang",
SdkVersion: common.SDKVersion,
LocalTime: time.Now().String(),
User: c.config.Username,
Host: hostName,
},
}
resp, err := c.service.Connect(ctx, req)
if err != nil {
status, ok := status.FromError(err)
if ok {
if status.Code() == codes.Unimplemented {
// disable unsupported feature
c.config.addFlags(
disableDatabase |
disableJSON |
disableParitionKey |
disableDynamicSchema)
}
return nil
}
return err
}
if !merr.Ok(resp.GetStatus()) {
return merr.Error(resp.GetStatus())
}
c.config.setServerInfo(resp.GetServerInfo().GetBuildTags())
c.config.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10))
return nil
}
func (c *Client) callService(fn func(milvusService milvuspb.MilvusServiceClient) error) error {
service := c.service
if service == nil {
return merr.WrapErrServiceNotReady("SDK", 0, "not connected")
}
return fn(c.service)
}

182
client/client_config.go Normal file
View File

@ -0,0 +1,182 @@
package client
import (
"crypto/tls"
"fmt"
"math"
"net/url"
"regexp"
"strings"
"time"
"github.com/cockroachdb/errors"
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
)
const (
disableDatabase uint64 = 1 << iota
disableJSON
disableDynamicSchema
disableParitionKey
)
var regexValidScheme = regexp.MustCompile(`^https?:\/\/`)
// DefaultGrpcOpts is GRPC options for milvus client.
var DefaultGrpcOpts = []grpc.DialOption{
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 5 * time.Second,
Timeout: 10 * time.Second,
PermitWithoutStream: true,
}),
grpc.WithConnectParams(grpc.ConnectParams{
Backoff: backoff.Config{
BaseDelay: 100 * time.Millisecond,
Multiplier: 1.6,
Jitter: 0.2,
MaxDelay: 3 * time.Second,
},
MinConnectTimeout: 3 * time.Second,
}),
}
// ClientConfig for milvus client.
type ClientConfig struct {
Address string // Remote address, "localhost:19530".
Username string // Username for auth.
Password string // Password for auth.
DBName string // DBName for this client.
EnableTLSAuth bool // Enable TLS Auth for transport security.
APIKey string // API key
DialOptions []grpc.DialOption // Dial options for GRPC.
// RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor
DisableConn bool
identifier string // Identifier for this connection
ServerVersion string // ServerVersion
parsedAddress *url.URL
flags uint64 // internal flags
}
func (cfg *ClientConfig) parse() error {
// Prepend default fake tcp:// scheme for remote address.
address := cfg.Address
if !regexValidScheme.MatchString(address) {
address = fmt.Sprintf("tcp://%s", address)
}
remoteURL, err := url.Parse(address)
if err != nil {
return errors.Wrap(err, "milvus address parse fail")
}
// Remote Host should never be empty.
if remoteURL.Host == "" {
return errors.New("empty remote host of milvus address")
}
// Use DBName in remote url path.
if cfg.DBName == "" {
cfg.DBName = strings.TrimLeft(remoteURL.Path, "/")
}
// Always enable tls auth for https remote url.
if remoteURL.Scheme == "https" {
cfg.EnableTLSAuth = true
}
if remoteURL.Port() == "" && cfg.EnableTLSAuth {
remoteURL.Host += ":443"
}
cfg.parsedAddress = remoteURL
return nil
}
// Get parsed remote milvus address, should be called after parse was called.
func (c *ClientConfig) getParsedAddress() string {
return c.parsedAddress.Host
}
// useDatabase change the inner db name.
func (c *ClientConfig) useDatabase(dbName string) {
c.DBName = dbName
}
// useDatabase change the inner db name.
func (c *ClientConfig) setIdentifier(identifier string) {
c.identifier = identifier
}
func (c *ClientConfig) setServerInfo(serverInfo string) {
c.ServerVersion = serverInfo
}
// Get parsed grpc dial options, should be called after parse was called.
func (c *ClientConfig) getDialOption() []grpc.DialOption {
options := c.DialOptions
if c.DialOptions == nil {
// Add default connection options.
options = make([]grpc.DialOption, len(DefaultGrpcOpts))
copy(options, DefaultGrpcOpts)
}
// Construct dial option.
if c.EnableTLSAuth {
options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})))
} else {
options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
options = append(options,
grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor(
grpc_retry.WithMax(6),
grpc_retry.WithBackoff(func(attempt uint) time.Duration {
return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
}),
grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)),
// c.getRetryOnRateLimitInterceptor(),
))
// options = append(options, grpc.WithChainUnaryInterceptor(
// createMetaDataUnaryInterceptor(c),
// ))
return options
}
// func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor {
// if c.RetryRateLimit == nil {
// c.RetryRateLimit = c.defaultRetryRateLimitOption()
// }
// return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration {
// return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
// })
// }
// func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption {
// return &RetryRateLimitOption{
// MaxRetry: 75,
// MaxBackoff: 3 * time.Second,
// }
// }
// addFlags set internal flags
func (c *ClientConfig) addFlags(flags uint64) {
c.flags |= flags
}
// hasFlags check flags is set
func (c *ClientConfig) hasFlags(flags uint64) bool {
return (c.flags & flags) > 0
}
func (c *ClientConfig) resetFlags(flags uint64) {
c.flags &= ^flags
}

251
client/client_suite_test.go Normal file
View File

@ -0,0 +1,251 @@
package client
import (
"context"
"math/rand"
"net"
"strings"
mock "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
)
const (
bufSize = 1024 * 1024
)
type MockSuiteBase struct {
suite.Suite
lis *bufconn.Listener
svr *grpc.Server
mock *MilvusServiceServer
client *Client
}
func (s *MockSuiteBase) SetupSuite() {
s.lis = bufconn.Listen(bufSize)
s.svr = grpc.NewServer()
s.mock = &MilvusServiceServer{}
milvuspb.RegisterMilvusServiceServer(s.svr, s.mock)
go func() {
s.T().Log("start mock server")
if err := s.svr.Serve(s.lis); err != nil {
s.Fail("failed to start mock server", err.Error())
}
}()
s.setupConnect()
}
func (s *MockSuiteBase) TearDownSuite() {
s.svr.Stop()
s.lis.Close()
}
func (s *MockSuiteBase) mockDialer(context.Context, string) (net.Conn, error) {
return s.lis.Dial()
}
func (s *MockSuiteBase) SetupTest() {
c, err := New(context.Background(), &ClientConfig{
Address: "bufnet",
DialOptions: []grpc.DialOption{
grpc.WithBlock(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(s.mockDialer),
},
})
s.Require().NoError(err)
s.setupConnect()
s.client = c
}
func (s *MockSuiteBase) TearDownTest() {
s.client.Close(context.Background())
s.client = nil
}
func (s *MockSuiteBase) resetMock() {
// MetaCache.reset()
if s.mock != nil {
s.mock.Calls = nil
s.mock.ExpectedCalls = nil
s.setupConnect()
}
}
func (s *MockSuiteBase) setupConnect() {
s.mock.EXPECT().Connect(mock.Anything, mock.AnythingOfType("*milvuspb.ConnectRequest")).
Return(&milvuspb.ConnectResponse{
Status: &commonpb.Status{},
Identifier: 1,
}, nil).Maybe()
}
func (s *MockSuiteBase) setupCache(collName string, schema *entity.Schema) {
s.client.collCache.collections.Insert(collName, &entity.Collection{
Name: collName,
Schema: schema,
})
}
func (s *MockSuiteBase) setupHasCollection(collNames ...string) {
s.mock.EXPECT().HasCollection(mock.Anything, mock.AnythingOfType("*milvuspb.HasCollectionRequest")).
Call.Return(func(ctx context.Context, req *milvuspb.HasCollectionRequest) *milvuspb.BoolResponse {
resp := &milvuspb.BoolResponse{Status: &commonpb.Status{}}
for _, collName := range collNames {
if req.GetCollectionName() == collName {
resp.Value = true
break
}
}
return resp
}, nil)
}
func (s *MockSuiteBase) setupHasCollectionError(errorCode commonpb.ErrorCode, err error) {
s.mock.EXPECT().HasCollection(mock.Anything, mock.AnythingOfType("*milvuspb.HasCollectionRequest")).
Return(&milvuspb.BoolResponse{
Status: &commonpb.Status{ErrorCode: errorCode},
}, err)
}
func (s *MockSuiteBase) setupHasPartition(collName string, partNames ...string) {
s.mock.EXPECT().HasPartition(mock.Anything, mock.AnythingOfType("*milvuspb.HasPartitionRequest")).
Call.Return(func(ctx context.Context, req *milvuspb.HasPartitionRequest) *milvuspb.BoolResponse {
resp := &milvuspb.BoolResponse{Status: &commonpb.Status{}}
if req.GetCollectionName() == collName {
for _, partName := range partNames {
if req.GetPartitionName() == partName {
resp.Value = true
break
}
}
}
return resp
}, nil)
}
func (s *MockSuiteBase) setupHasPartitionError(errorCode commonpb.ErrorCode, err error) {
s.mock.EXPECT().HasPartition(mock.Anything, mock.AnythingOfType("*milvuspb.HasPartitionRequest")).
Return(&milvuspb.BoolResponse{
Status: &commonpb.Status{ErrorCode: errorCode},
}, err)
}
func (s *MockSuiteBase) setupDescribeCollection(_ string, schema *entity.Schema) {
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.AnythingOfType("*milvuspb.DescribeCollectionRequest")).
Call.Return(func(ctx context.Context, req *milvuspb.DescribeCollectionRequest) *milvuspb.DescribeCollectionResponse {
return &milvuspb.DescribeCollectionResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
Schema: schema.ProtoMessage(),
}
}, nil)
}
func (s *MockSuiteBase) setupDescribeCollectionError(errorCode commonpb.ErrorCode, err error) {
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.AnythingOfType("*milvuspb.DescribeCollectionRequest")).
Return(&milvuspb.DescribeCollectionResponse{
Status: &commonpb.Status{ErrorCode: errorCode},
}, err)
}
func (s *MockSuiteBase) getInt64FieldData(name string, data []int64) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldName: name,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: data,
},
},
},
},
}
}
func (s *MockSuiteBase) getVarcharFieldData(name string, data []string) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldName: name,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: data,
},
},
},
},
}
}
func (s *MockSuiteBase) getJSONBytesFieldData(name string, data [][]byte, isDynamic bool) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_JSON,
FieldName: name,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_JsonData{
JsonData: &schemapb.JSONArray{
Data: data,
},
},
},
},
IsDynamic: isDynamic,
}
}
func (s *MockSuiteBase) getFloatVectorFieldData(name string, dim int64, data []float32) *schemapb.FieldData {
return &schemapb.FieldData{
Type: schemapb.DataType_FloatVector,
FieldName: name,
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: data,
},
},
},
},
}
}
func (s *MockSuiteBase) getSuccessStatus() *commonpb.Status {
return s.getStatus(commonpb.ErrorCode_Success, "")
}
func (s *MockSuiteBase) getStatus(code commonpb.ErrorCode, reason string) *commonpb.Status {
return &commonpb.Status{
ErrorCode: code,
Reason: reason,
}
}
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
func (s *MockSuiteBase) randString(l int) string {
builder := strings.Builder{}
for i := 0; i < l; i++ {
builder.WriteRune(letters[rand.Intn(len(letters))])
}
return builder.String()
}

43
client/client_test.go Normal file
View File

@ -0,0 +1,43 @@
package client
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
type ClientSuite struct {
MockSuiteBase
}
func (s *ClientSuite) TestNewClient() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("Use bufconn dailer, testing case", func() {
c, err := New(ctx,
&ClientConfig{
Address: "bufnet",
DialOptions: []grpc.DialOption{
grpc.WithBlock(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(s.mockDialer),
},
})
s.NoError(err)
s.NotNil(c)
})
s.Run("emtpy_addr", func() {
_, err := New(ctx, &ClientConfig{})
s.Error(err)
s.T().Log(err)
})
}
func TestClient(t *testing.T) {
suite.Run(t, new(ClientSuite))
}

134
client/collection.go Normal file
View File

@ -0,0 +1,134 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"github.com/cockroachdb/errors"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/util/merr"
)
// CreateCollection is the API for create a collection in Milvus.
func (c *Client) CreateCollection(ctx context.Context, option CreateCollectionOption, callOptions ...grpc.CallOption) error {
req := option.Request()
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.CreateCollection(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
if err != nil {
return err
}
indexes := option.Indexes()
for _, indexOption := range indexes {
task, err := c.CreateIndex(ctx, indexOption, callOptions...)
if err != nil {
return err
}
err = task.Await(ctx)
if err != nil {
return nil
}
}
if option.IsFast() {
task, err := c.LoadCollection(ctx, NewLoadCollectionOption(req.GetCollectionName()))
if err != nil {
return err
}
return task.Await(ctx)
}
return nil
}
type ListCollectionOption interface {
Request() *milvuspb.ShowCollectionsRequest
}
func (c *Client) ListCollections(ctx context.Context, option ListCollectionOption, callOptions ...grpc.CallOption) (collectionNames []string, err error) {
req := option.Request()
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.ShowCollections(ctx, req, callOptions...)
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
collectionNames = resp.GetCollectionNames()
return nil
})
return collectionNames, err
}
func (c *Client) DescribeCollection(ctx context.Context, option *describeCollectionOption, callOptions ...grpc.CallOption) (collection *entity.Collection, err error) {
req := option.Request()
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DescribeCollection(ctx, req, callOptions...)
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
collection = &entity.Collection{
ID: resp.GetCollectionID(),
Schema: entity.NewSchema().ReadProto(resp.GetSchema()),
PhysicalChannels: resp.GetPhysicalChannelNames(),
VirtualChannels: resp.GetVirtualChannelNames(),
ConsistencyLevel: entity.ConsistencyLevel(resp.ConsistencyLevel),
ShardNum: resp.GetShardsNum(),
}
collection.Name = collection.Schema.CollectionName
return nil
})
return collection, err
}
func (c *Client) HasCollection(ctx context.Context, option HasCollectionOption, callOptions ...grpc.CallOption) (has bool, err error) {
req := option.Request()
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DescribeCollection(ctx, req, callOptions...)
err = merr.CheckRPCCall(resp, err)
if err != nil {
// ErrCollectionNotFound for collection not exist
if errors.Is(err, merr.ErrCollectionNotFound) {
return nil
}
return err
}
has = true
return nil
})
return has, err
}
func (c *Client) DropCollection(ctx context.Context, option DropCollectionOption, callOptions ...grpc.CallOption) error {
req := option.Request()
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DropCollection(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
return err
}

View File

@ -0,0 +1,232 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/index"
)
// CreateCollectionOption is the interface builds CreateCollectionRequest.
type CreateCollectionOption interface {
// Request is the method returns the composed request.
Request() *milvuspb.CreateCollectionRequest
// Indexes is the method returns IndexOption to create
Indexes() []CreateIndexOption
IsFast() bool
}
// createCollectionOption contains all the parameters to create collection.
type createCollectionOption struct {
name string
shardNum int32
// fast create collection params
varcharPK bool
varcharPKMaxLength int
pkFieldName string
vectorFieldName string
dim int64
autoID bool
enabledDynamicSchema bool
// advanced create collection params
schema *entity.Schema
consistencyLevel entity.ConsistencyLevel
properties map[string]string
// partition key
numPartitions int64
// is fast create collection
isFast bool
// fast creation with index
metricType entity.MetricType
}
func (opt *createCollectionOption) WithAutoID(autoID bool) *createCollectionOption {
opt.autoID = autoID
return opt
}
func (opt *createCollectionOption) WithShardNum(shardNum int32) *createCollectionOption {
opt.shardNum = shardNum
return opt
}
func (opt *createCollectionOption) WithDynamicSchema(dynamicSchema bool) *createCollectionOption {
opt.enabledDynamicSchema = dynamicSchema
return opt
}
func (opt *createCollectionOption) WithVarcharPK(varcharPK bool, maxLen int) *createCollectionOption {
opt.varcharPK = varcharPK
opt.varcharPKMaxLength = maxLen
return opt
}
func (opt *createCollectionOption) Request() *milvuspb.CreateCollectionRequest {
// fast create collection
if opt.isFast || opt.schema == nil {
var pkField *entity.Field
if opt.varcharPK {
pkField = entity.NewField().WithDataType(entity.FieldTypeVarChar).WithMaxLength(int64(opt.varcharPKMaxLength))
} else {
pkField = entity.NewField().WithDataType(entity.FieldTypeInt64)
}
pkField = pkField.WithName(opt.pkFieldName).WithIsPrimaryKey(true).WithIsAutoID(opt.autoID)
opt.schema = entity.NewSchema().
WithName(opt.name).
WithAutoID(opt.autoID).
WithDynamicFieldEnabled(opt.enabledDynamicSchema).
WithField(pkField).
WithField(entity.NewField().WithName(opt.vectorFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(opt.dim))
}
schemaProto := opt.schema.ProtoMessage()
schemaBytes, _ := proto.Marshal(schemaProto)
return &milvuspb.CreateCollectionRequest{
DbName: "", // reserved fields, not used for now
CollectionName: opt.name,
Schema: schemaBytes,
ShardsNum: opt.shardNum,
ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel),
NumPartitions: opt.numPartitions,
Properties: entity.MapKvPairs(opt.properties),
}
}
func (opt *createCollectionOption) Indexes() []CreateIndexOption {
// fast create
if opt.isFast {
return []CreateIndexOption{
NewCreateIndexOption(opt.name, opt.vectorFieldName, index.NewGenericIndex("", map[string]string{})),
}
}
return nil
}
func (opt *createCollectionOption) IsFast() bool {
return opt.isFast
}
// SimpleCreateCollectionOptions returns a CreateCollectionOption with default fast collection options.
func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOption {
return &createCollectionOption{
name: name,
shardNum: 1,
pkFieldName: "id",
vectorFieldName: "vector",
autoID: true,
dim: dim,
enabledDynamicSchema: true,
isFast: true,
metricType: entity.COSINE,
}
}
// NewCreateCollectionOption returns a CreateCollectionOption with customized collection schema
func NewCreateCollectionOption(name string, collectionSchema *entity.Schema) *createCollectionOption {
return &createCollectionOption{
name: name,
shardNum: 1,
schema: collectionSchema,
metricType: entity.COSINE,
}
}
type listCollectionOption struct{}
func (opt *listCollectionOption) Request() *milvuspb.ShowCollectionsRequest {
return &milvuspb.ShowCollectionsRequest{}
}
func NewListCollectionOption() *listCollectionOption {
return &listCollectionOption{}
}
// DescribeCollectionOption is the interface builds DescribeCollection request.
type DescribeCollectionOption interface {
// Request is the method returns the composed request.
Request() *milvuspb.DescribeCollectionRequest
}
type describeCollectionOption struct {
name string
}
func (opt *describeCollectionOption) Request() *milvuspb.DescribeCollectionRequest {
return &milvuspb.DescribeCollectionRequest{
CollectionName: opt.name,
}
}
// NewDescribeCollectionOption composes a describeCollectionOption with provided collection name.
func NewDescribeCollectionOption(name string) *describeCollectionOption {
return &describeCollectionOption{
name: name,
}
}
// HasCollectionOption is the interface to build DescribeCollectionRequest.
type HasCollectionOption interface {
Request() *milvuspb.DescribeCollectionRequest
}
type hasCollectionOpt struct {
name string
}
func (opt *hasCollectionOpt) Request() *milvuspb.DescribeCollectionRequest {
return &milvuspb.DescribeCollectionRequest{
CollectionName: opt.name,
}
}
func NewHasCollectionOption(name string) HasCollectionOption {
return &hasCollectionOpt{
name: name,
}
}
// The DropCollectionOption interface builds DropCollectionRequest.
type DropCollectionOption interface {
Request() *milvuspb.DropCollectionRequest
}
type dropCollectionOption struct {
name string
}
func (opt *dropCollectionOption) Request() *milvuspb.DropCollectionRequest {
return &milvuspb.DropCollectionRequest{
CollectionName: opt.name,
}
}
func NewDropCollectionOption(name string) *dropCollectionOption {
return &dropCollectionOption{
name: name,
}
}

253
client/collection_test.go Normal file
View File

@ -0,0 +1,253 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"fmt"
"testing"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/util/merr"
)
type CollectionSuite struct {
MockSuiteBase
}
func (s *CollectionSuite) TestListCollection() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
s.mock.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{
CollectionNames: []string{"test1", "test2", "test3"},
}, nil).Once()
names, err := s.client.ListCollections(ctx, NewListCollectionOption())
s.NoError(err)
s.ElementsMatch([]string{"test1", "test2", "test3"}, names)
})
s.Run("failure", func() {
s.mock.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.ListCollections(ctx, NewListCollectionOption())
s.Error(err)
})
}
func (s *CollectionSuite) TestCreateCollection() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
s.mock.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once()
s.mock.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once()
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once()
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(&milvuspb.DescribeIndexResponse{
Status: merr.Success(),
IndexDescriptions: []*milvuspb.IndexDescription{
{FieldName: "vector", State: commonpb.IndexState_Finished},
},
}, nil).Once()
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Return(&milvuspb.GetLoadingProgressResponse{
Status: merr.Success(),
Progress: 100,
}, nil).Once()
err := s.client.CreateCollection(ctx, SimpleCreateCollectionOptions("test_collection", 128))
s.NoError(err)
})
s.Run("failure", func() {
s.mock.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.CreateCollection(ctx, SimpleCreateCollectionOptions("test_collection", 128))
s.Error(err)
})
}
func (s *CollectionSuite) TestCreateCollectionOptions() {
collectionName := fmt.Sprintf("test_collection_%s", s.randString(6))
opt := SimpleCreateCollectionOptions(collectionName, 128)
req := opt.Request()
s.Equal(collectionName, req.GetCollectionName())
s.EqualValues(1, req.GetShardsNum())
collSchema := &schemapb.CollectionSchema{}
err := proto.Unmarshal(req.GetSchema(), collSchema)
s.Require().NoError(err)
s.True(collSchema.GetEnableDynamicField())
collectionName = fmt.Sprintf("test_collection_%s", s.randString(6))
opt = SimpleCreateCollectionOptions(collectionName, 128).WithVarcharPK(true, 64).WithAutoID(false).WithDynamicSchema(false)
req = opt.Request()
s.Equal(collectionName, req.GetCollectionName())
s.EqualValues(1, req.GetShardsNum())
collSchema = &schemapb.CollectionSchema{}
err = proto.Unmarshal(req.GetSchema(), collSchema)
s.Require().NoError(err)
s.False(collSchema.GetEnableDynamicField())
collectionName = fmt.Sprintf("test_collection_%s", s.randString(6))
schema := entity.NewSchema().
WithField(entity.NewField().WithName("int64").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
WithField(entity.NewField().WithName("vector").WithDim(128).WithDataType(entity.FieldTypeFloatVector))
opt = NewCreateCollectionOption(collectionName, schema).WithShardNum(2)
req = opt.Request()
s.Equal(collectionName, req.GetCollectionName())
s.EqualValues(2, req.GetShardsNum())
collSchema = &schemapb.CollectionSchema{}
err = proto.Unmarshal(req.GetSchema(), collSchema)
s.Require().NoError(err)
}
func (s *CollectionSuite) TestDescribeCollection() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Success(),
Schema: &schemapb.CollectionSchema{
Name: "test_collection",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, DataType: schemapb.DataType_Int64, AutoID: true, Name: "ID"},
{
FieldID: 101, DataType: schemapb.DataType_FloatVector, Name: "vector",
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "128"},
},
},
},
},
CollectionID: 1000,
CollectionName: "test_collection",
}, nil).Once()
coll, err := s.client.DescribeCollection(ctx, NewDescribeCollectionOption("test_collection"))
s.NoError(err)
s.EqualValues(1000, coll.ID)
s.Equal("test_collection", coll.Name)
s.Len(coll.Schema.Fields, 2)
idField, ok := lo.Find(coll.Schema.Fields, func(field *entity.Field) bool {
return field.ID == 100
})
s.Require().True(ok)
s.Equal("ID", idField.Name)
s.Equal(entity.FieldTypeInt64, idField.DataType)
s.True(idField.AutoID)
vectorField, ok := lo.Find(coll.Schema.Fields, func(field *entity.Field) bool {
return field.ID == 101
})
s.Require().True(ok)
s.Equal("vector", vectorField.Name)
s.Equal(entity.FieldTypeFloatVector, vectorField.DataType)
})
s.Run("failure", func() {
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.DescribeCollection(ctx, NewDescribeCollectionOption("test_collection"))
s.Error(err)
})
}
func (s *CollectionSuite) TestHasCollection() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Success(),
Schema: &schemapb.CollectionSchema{
Name: "test_collection",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, DataType: schemapb.DataType_Int64, AutoID: true, Name: "ID"},
{
FieldID: 101, DataType: schemapb.DataType_FloatVector, Name: "vector",
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "128"},
},
},
},
},
CollectionID: 1000,
CollectionName: "test_collection",
}, nil).Once()
has, err := s.client.HasCollection(ctx, NewHasCollectionOption("test_collection"))
s.NoError(err)
s.True(has)
})
s.Run("collection_not_exist", func() {
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Status(merr.WrapErrCollectionNotFound("test_collection")),
}, nil).Once()
has, err := s.client.HasCollection(ctx, NewHasCollectionOption("test_collection"))
s.NoError(err)
s.False(has)
})
s.Run("failure", func() {
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.HasCollection(ctx, NewHasCollectionOption("test_collection"))
s.Error(err)
})
}
func (s *CollectionSuite) TestDropCollection() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
s.mock.EXPECT().DropCollection(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once()
err := s.client.DropCollection(ctx, NewDropCollectionOption("test_collection"))
s.NoError(err)
})
s.Run("failure", func() {
s.mock.EXPECT().DropCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.DropCollection(ctx, NewDropCollectionOption("test_collection"))
s.Error(err)
})
}
func TestCollection(t *testing.T) {
suite.Run(t, new(CollectionSuite))
}

125
client/column/array.go Normal file
View File

@ -0,0 +1,125 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package column
import (
"fmt"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
)
// ColumnVarCharArray generated columns type for VarChar
type ColumnVarCharArray struct {
ColumnBase
name string
values [][][]byte
}
// Name returns column name
func (c *ColumnVarCharArray) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnVarCharArray) Type() entity.FieldType {
return entity.FieldTypeArray
}
// Len returns column values length
func (c *ColumnVarCharArray) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnVarCharArray) Get(idx int) (interface{}, error) {
var r []string // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnVarCharArray) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Array,
FieldName: c.name,
}
data := make([]*schemapb.ScalarField, 0, c.Len())
for _, arr := range c.values {
converted := make([]string, 0, c.Len())
for i := 0; i < len(arr); i++ {
converted = append(converted, string(arr[i]))
}
data = append(data, &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: converted,
},
},
})
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
Data: data,
ElementType: schemapb.DataType_VarChar,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnVarCharArray) ValueByIdx(idx int) ([][]byte, error) {
var r [][]byte // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnVarCharArray) AppendValue(i interface{}) error {
v, ok := i.([][]byte)
if !ok {
return fmt.Errorf("invalid type, expected []string, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnVarCharArray) Data() [][][]byte {
return c.values
}
// NewColumnVarChar auto generated constructor
func NewColumnVarCharArray(name string, values [][][]byte) *ColumnVarCharArray {
return &ColumnVarCharArray{
name: name,
values: values,
}
}

705
client/column/array_gen.go Normal file
View File

@ -0,0 +1,705 @@
// Code generated by go generate; DO NOT EDIT
// This file is generated by go generate
package column
import (
"fmt"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
)
// ColumnBoolArray generated columns type for Bool
type ColumnBoolArray struct {
ColumnBase
name string
values [][]bool
}
// Name returns column name
func (c *ColumnBoolArray) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnBoolArray) Type() entity.FieldType {
return entity.FieldTypeArray
}
// Len returns column values length
func (c *ColumnBoolArray) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnBoolArray) Get(idx int) (interface{}, error) {
var r []bool // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnBoolArray) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Array,
FieldName: c.name,
}
data := make([]*schemapb.ScalarField, 0, c.Len())
for _, arr := range c.values {
converted := make([]bool, 0, c.Len())
for i := 0; i < len(arr); i++ {
converted = append(converted, bool(arr[i]))
}
data = append(data, &schemapb.ScalarField{
Data: &schemapb.ScalarField_BoolData{
BoolData: &schemapb.BoolArray{
Data: converted,
},
},
})
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
Data: data,
ElementType: schemapb.DataType_Bool,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnBoolArray) ValueByIdx(idx int) ([]bool, error) {
var r []bool // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnBoolArray) AppendValue(i interface{}) error {
v, ok := i.([]bool)
if !ok {
return fmt.Errorf("invalid type, expected []bool, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnBoolArray) Data() [][]bool {
return c.values
}
// NewColumnBool auto generated constructor
func NewColumnBoolArray(name string, values [][]bool) *ColumnBoolArray {
return &ColumnBoolArray{
name: name,
values: values,
}
}
// ColumnInt8Array generated columns type for Int8
type ColumnInt8Array struct {
ColumnBase
name string
values [][]int8
}
// Name returns column name
func (c *ColumnInt8Array) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnInt8Array) Type() entity.FieldType {
return entity.FieldTypeArray
}
// Len returns column values length
func (c *ColumnInt8Array) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnInt8Array) Get(idx int) (interface{}, error) {
var r []int8 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnInt8Array) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Array,
FieldName: c.name,
}
data := make([]*schemapb.ScalarField, 0, c.Len())
for _, arr := range c.values {
converted := make([]int32, 0, c.Len())
for i := 0; i < len(arr); i++ {
converted = append(converted, int32(arr[i]))
}
data = append(data, &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: converted,
},
},
})
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
Data: data,
ElementType: schemapb.DataType_Int8,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnInt8Array) ValueByIdx(idx int) ([]int8, error) {
var r []int8 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnInt8Array) AppendValue(i interface{}) error {
v, ok := i.([]int8)
if !ok {
return fmt.Errorf("invalid type, expected []int8, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnInt8Array) Data() [][]int8 {
return c.values
}
// NewColumnInt8 auto generated constructor
func NewColumnInt8Array(name string, values [][]int8) *ColumnInt8Array {
return &ColumnInt8Array{
name: name,
values: values,
}
}
// ColumnInt16Array generated columns type for Int16
type ColumnInt16Array struct {
ColumnBase
name string
values [][]int16
}
// Name returns column name
func (c *ColumnInt16Array) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnInt16Array) Type() entity.FieldType {
return entity.FieldTypeArray
}
// Len returns column values length
func (c *ColumnInt16Array) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnInt16Array) Get(idx int) (interface{}, error) {
var r []int16 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnInt16Array) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Array,
FieldName: c.name,
}
data := make([]*schemapb.ScalarField, 0, c.Len())
for _, arr := range c.values {
converted := make([]int32, 0, c.Len())
for i := 0; i < len(arr); i++ {
converted = append(converted, int32(arr[i]))
}
data = append(data, &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: converted,
},
},
})
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
Data: data,
ElementType: schemapb.DataType_Int16,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnInt16Array) ValueByIdx(idx int) ([]int16, error) {
var r []int16 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnInt16Array) AppendValue(i interface{}) error {
v, ok := i.([]int16)
if !ok {
return fmt.Errorf("invalid type, expected []int16, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnInt16Array) Data() [][]int16 {
return c.values
}
// NewColumnInt16 auto generated constructor
func NewColumnInt16Array(name string, values [][]int16) *ColumnInt16Array {
return &ColumnInt16Array{
name: name,
values: values,
}
}
// ColumnInt32Array generated columns type for Int32
type ColumnInt32Array struct {
ColumnBase
name string
values [][]int32
}
// Name returns column name
func (c *ColumnInt32Array) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnInt32Array) Type() entity.FieldType {
return entity.FieldTypeArray
}
// Len returns column values length
func (c *ColumnInt32Array) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnInt32Array) Get(idx int) (interface{}, error) {
var r []int32 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnInt32Array) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Array,
FieldName: c.name,
}
data := make([]*schemapb.ScalarField, 0, c.Len())
for _, arr := range c.values {
converted := make([]int32, 0, c.Len())
for i := 0; i < len(arr); i++ {
converted = append(converted, int32(arr[i]))
}
data = append(data, &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: converted,
},
},
})
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
Data: data,
ElementType: schemapb.DataType_Int32,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnInt32Array) ValueByIdx(idx int) ([]int32, error) {
var r []int32 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnInt32Array) AppendValue(i interface{}) error {
v, ok := i.([]int32)
if !ok {
return fmt.Errorf("invalid type, expected []int32, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnInt32Array) Data() [][]int32 {
return c.values
}
// NewColumnInt32 auto generated constructor
func NewColumnInt32Array(name string, values [][]int32) *ColumnInt32Array {
return &ColumnInt32Array{
name: name,
values: values,
}
}
// ColumnInt64Array generated columns type for Int64
type ColumnInt64Array struct {
ColumnBase
name string
values [][]int64
}
// Name returns column name
func (c *ColumnInt64Array) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnInt64Array) Type() entity.FieldType {
return entity.FieldTypeArray
}
// Len returns column values length
func (c *ColumnInt64Array) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnInt64Array) Get(idx int) (interface{}, error) {
var r []int64 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnInt64Array) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Array,
FieldName: c.name,
}
data := make([]*schemapb.ScalarField, 0, c.Len())
for _, arr := range c.values {
converted := make([]int64, 0, c.Len())
for i := 0; i < len(arr); i++ {
converted = append(converted, int64(arr[i]))
}
data = append(data, &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: converted,
},
},
})
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
Data: data,
ElementType: schemapb.DataType_Int64,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnInt64Array) ValueByIdx(idx int) ([]int64, error) {
var r []int64 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnInt64Array) AppendValue(i interface{}) error {
v, ok := i.([]int64)
if !ok {
return fmt.Errorf("invalid type, expected []int64, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnInt64Array) Data() [][]int64 {
return c.values
}
// NewColumnInt64 auto generated constructor
func NewColumnInt64Array(name string, values [][]int64) *ColumnInt64Array {
return &ColumnInt64Array{
name: name,
values: values,
}
}
// ColumnFloatArray generated columns type for Float
type ColumnFloatArray struct {
ColumnBase
name string
values [][]float32
}
// Name returns column name
func (c *ColumnFloatArray) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnFloatArray) Type() entity.FieldType {
return entity.FieldTypeArray
}
// Len returns column values length
func (c *ColumnFloatArray) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnFloatArray) Get(idx int) (interface{}, error) {
var r []float32 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnFloatArray) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Array,
FieldName: c.name,
}
data := make([]*schemapb.ScalarField, 0, c.Len())
for _, arr := range c.values {
converted := make([]float32, 0, c.Len())
for i := 0; i < len(arr); i++ {
converted = append(converted, float32(arr[i]))
}
data = append(data, &schemapb.ScalarField{
Data: &schemapb.ScalarField_FloatData{
FloatData: &schemapb.FloatArray{
Data: converted,
},
},
})
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
Data: data,
ElementType: schemapb.DataType_Float,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnFloatArray) ValueByIdx(idx int) ([]float32, error) {
var r []float32 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnFloatArray) AppendValue(i interface{}) error {
v, ok := i.([]float32)
if !ok {
return fmt.Errorf("invalid type, expected []float32, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnFloatArray) Data() [][]float32 {
return c.values
}
// NewColumnFloat auto generated constructor
func NewColumnFloatArray(name string, values [][]float32) *ColumnFloatArray {
return &ColumnFloatArray{
name: name,
values: values,
}
}
// ColumnDoubleArray generated columns type for Double
type ColumnDoubleArray struct {
ColumnBase
name string
values [][]float64
}
// Name returns column name
func (c *ColumnDoubleArray) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnDoubleArray) Type() entity.FieldType {
return entity.FieldTypeArray
}
// Len returns column values length
func (c *ColumnDoubleArray) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnDoubleArray) Get(idx int) (interface{}, error) {
var r []float64 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnDoubleArray) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Array,
FieldName: c.name,
}
data := make([]*schemapb.ScalarField, 0, c.Len())
for _, arr := range c.values {
converted := make([]float64, 0, c.Len())
for i := 0; i < len(arr); i++ {
converted = append(converted, float64(arr[i]))
}
data = append(data, &schemapb.ScalarField{
Data: &schemapb.ScalarField_DoubleData{
DoubleData: &schemapb.DoubleArray{
Data: converted,
},
},
})
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
Data: data,
ElementType: schemapb.DataType_Double,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnDoubleArray) ValueByIdx(idx int) ([]float64, error) {
var r []float64 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnDoubleArray) AppendValue(i interface{}) error {
v, ok := i.([]float64)
if !ok {
return fmt.Errorf("invalid type, expected []float64, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnDoubleArray) Data() [][]float64 {
return c.values
}
// NewColumnDouble auto generated constructor
func NewColumnDoubleArray(name string, values [][]float64) *ColumnDoubleArray {
return &ColumnDoubleArray{
name: name,
values: values,
}
}

502
client/column/columns.go Normal file
View File

@ -0,0 +1,502 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package column
import (
"fmt"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
)
//go:generate go run gen/gen.go
// Column interface field type for column-based data frame
type Column interface {
Name() string
Type() entity.FieldType
Len() int
FieldData() *schemapb.FieldData
AppendValue(interface{}) error
Get(int) (interface{}, error)
GetAsInt64(int) (int64, error)
GetAsString(int) (string, error)
GetAsDouble(int) (float64, error)
GetAsBool(int) (bool, error)
}
// ColumnBase adds conversion methods support for fixed-type columns.
type ColumnBase struct{}
func (b ColumnBase) GetAsInt64(_ int) (int64, error) {
return 0, errors.New("conversion between fixed-type column not support")
}
func (b ColumnBase) GetAsString(_ int) (string, error) {
return "", errors.New("conversion between fixed-type column not support")
}
func (b ColumnBase) GetAsDouble(_ int) (float64, error) {
return 0, errors.New("conversion between fixed-type column not support")
}
func (b ColumnBase) GetAsBool(_ int) (bool, error) {
return false, errors.New("conversion between fixed-type column not support")
}
var errFieldDataTypeNotMatch = errors.New("FieldData type not matched")
// IDColumns converts schemapb.IDs to corresponding column
// currently Int64 / string may be in IDs
func IDColumns(idField *schemapb.IDs, begin, end int) (Column, error) {
var idColumn Column
if idField == nil {
return nil, errors.New("nil Ids from response")
}
switch field := idField.GetIdField().(type) {
case *schemapb.IDs_IntId:
if end >= 0 {
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:end])
} else {
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:])
}
case *schemapb.IDs_StrId:
if end >= 0 {
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:end])
} else {
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:])
}
default:
return nil, fmt.Errorf("unsupported id type %v", field)
}
return idColumn, nil
}
// FieldDataColumn converts schemapb.FieldData to Column, used int search result conversion logic
// begin, end specifies the start and end positions
func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
switch fd.GetType() {
case schemapb.DataType_Bool:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_BoolData)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnBool(fd.GetFieldName(), data.BoolData.GetData()[begin:]), nil
}
return NewColumnBool(fd.GetFieldName(), data.BoolData.GetData()[begin:end]), nil
case schemapb.DataType_Int8:
data, ok := getIntData(fd)
if !ok {
return nil, errFieldDataTypeNotMatch
}
values := make([]int8, 0, len(data.IntData.GetData()))
for _, v := range data.IntData.GetData() {
values = append(values, int8(v))
}
if end < 0 {
return NewColumnInt8(fd.GetFieldName(), values[begin:]), nil
}
return NewColumnInt8(fd.GetFieldName(), values[begin:end]), nil
case schemapb.DataType_Int16:
data, ok := getIntData(fd)
if !ok {
return nil, errFieldDataTypeNotMatch
}
values := make([]int16, 0, len(data.IntData.GetData()))
for _, v := range data.IntData.GetData() {
values = append(values, int16(v))
}
if end < 0 {
return NewColumnInt16(fd.GetFieldName(), values[begin:]), nil
}
return NewColumnInt16(fd.GetFieldName(), values[begin:end]), nil
case schemapb.DataType_Int32:
data, ok := getIntData(fd)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnInt32(fd.GetFieldName(), data.IntData.GetData()[begin:]), nil
}
return NewColumnInt32(fd.GetFieldName(), data.IntData.GetData()[begin:end]), nil
case schemapb.DataType_Int64:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_LongData)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnInt64(fd.GetFieldName(), data.LongData.GetData()[begin:]), nil
}
return NewColumnInt64(fd.GetFieldName(), data.LongData.GetData()[begin:end]), nil
case schemapb.DataType_Float:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_FloatData)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnFloat(fd.GetFieldName(), data.FloatData.GetData()[begin:]), nil
}
return NewColumnFloat(fd.GetFieldName(), data.FloatData.GetData()[begin:end]), nil
case schemapb.DataType_Double:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_DoubleData)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnDouble(fd.GetFieldName(), data.DoubleData.GetData()[begin:]), nil
}
return NewColumnDouble(fd.GetFieldName(), data.DoubleData.GetData()[begin:end]), nil
case schemapb.DataType_String:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_StringData)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnString(fd.GetFieldName(), data.StringData.GetData()[begin:]), nil
}
return NewColumnString(fd.GetFieldName(), data.StringData.GetData()[begin:end]), nil
case schemapb.DataType_VarChar:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_StringData)
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnVarChar(fd.GetFieldName(), data.StringData.GetData()[begin:]), nil
}
return NewColumnVarChar(fd.GetFieldName(), data.StringData.GetData()[begin:end]), nil
case schemapb.DataType_Array:
data := fd.GetScalars().GetArrayData()
if data == nil {
return nil, errFieldDataTypeNotMatch
}
var arrayData []*schemapb.ScalarField
if end < 0 {
arrayData = data.GetData()[begin:]
} else {
arrayData = data.GetData()[begin:end]
}
return parseArrayData(fd.GetFieldName(), data.GetElementType(), arrayData)
case schemapb.DataType_JSON:
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_JsonData)
isDynamic := fd.GetIsDynamic()
if !ok {
return nil, errFieldDataTypeNotMatch
}
if end < 0 {
return NewColumnJSONBytes(fd.GetFieldName(), data.JsonData.GetData()[begin:]).WithIsDynamic(isDynamic), nil
}
return NewColumnJSONBytes(fd.GetFieldName(), data.JsonData.GetData()[begin:end]).WithIsDynamic(isDynamic), nil
case schemapb.DataType_FloatVector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schemapb.VectorField_FloatVector)
if !ok {
return nil, errFieldDataTypeNotMatch
}
data := x.FloatVector.GetData()
dim := int(vectors.GetDim())
if end < 0 {
end = int(len(data) / dim)
}
vector := make([][]float32, 0, end-begin) // shall not have remanunt
for i := begin; i < end; i++ {
v := make([]float32, dim)
copy(v, data[i*dim:(i+1)*dim])
vector = append(vector, v)
}
return NewColumnFloatVector(fd.GetFieldName(), dim, vector), nil
case schemapb.DataType_BinaryVector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schemapb.VectorField_BinaryVector)
if !ok {
return nil, errFieldDataTypeNotMatch
}
data := x.BinaryVector
if data == nil {
return nil, errFieldDataTypeNotMatch
}
dim := int(vectors.GetDim())
blen := dim / 8
if end < 0 {
end = int(len(data) / blen)
}
vector := make([][]byte, 0, end-begin)
for i := begin; i < end; i++ {
v := make([]byte, blen)
copy(v, data[i*blen:(i+1)*blen])
vector = append(vector, v)
}
return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil
case schemapb.DataType_Float16Vector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schemapb.VectorField_Float16Vector)
if !ok {
return nil, errFieldDataTypeNotMatch
}
data := x.Float16Vector
dim := int(vectors.GetDim())
if end < 0 {
end = int(len(data) / dim)
}
vector := make([][]byte, 0, end-begin)
for i := begin; i < end; i++ {
v := make([]byte, dim)
copy(v, data[i*dim:(i+1)*dim])
vector = append(vector, v)
}
return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil
case schemapb.DataType_BFloat16Vector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schemapb.VectorField_Bfloat16Vector)
if !ok {
return nil, errFieldDataTypeNotMatch
}
data := x.Bfloat16Vector
dim := int(vectors.GetDim())
if end < 0 {
end = int(len(data) / dim)
}
vector := make([][]byte, 0, end-begin) // shall not have remanunt
for i := begin; i < end; i++ {
v := make([]byte, dim)
copy(v, data[i*dim:(i+1)*dim])
vector = append(vector, v)
}
return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil
default:
return nil, fmt.Errorf("unsupported data type %s", fd.GetType())
}
}
func parseArrayData(fieldName string, elementType schemapb.DataType, fieldDataList []*schemapb.ScalarField) (Column, error) {
switch elementType {
case schemapb.DataType_Bool:
var data [][]bool
for _, fd := range fieldDataList {
data = append(data, fd.GetBoolData().GetData())
}
return NewColumnBoolArray(fieldName, data), nil
case schemapb.DataType_Int8:
var data [][]int8
for _, fd := range fieldDataList {
raw := fd.GetIntData().GetData()
row := make([]int8, 0, len(raw))
for _, item := range raw {
row = append(row, int8(item))
}
data = append(data, row)
}
return NewColumnInt8Array(fieldName, data), nil
case schemapb.DataType_Int16:
var data [][]int16
for _, fd := range fieldDataList {
raw := fd.GetIntData().GetData()
row := make([]int16, 0, len(raw))
for _, item := range raw {
row = append(row, int16(item))
}
data = append(data, row)
}
return NewColumnInt16Array(fieldName, data), nil
case schemapb.DataType_Int32:
var data [][]int32
for _, fd := range fieldDataList {
data = append(data, fd.GetIntData().GetData())
}
return NewColumnInt32Array(fieldName, data), nil
case schemapb.DataType_Int64:
var data [][]int64
for _, fd := range fieldDataList {
data = append(data, fd.GetLongData().GetData())
}
return NewColumnInt64Array(fieldName, data), nil
case schemapb.DataType_Float:
var data [][]float32
for _, fd := range fieldDataList {
data = append(data, fd.GetFloatData().GetData())
}
return NewColumnFloatArray(fieldName, data), nil
case schemapb.DataType_Double:
var data [][]float64
for _, fd := range fieldDataList {
data = append(data, fd.GetDoubleData().GetData())
}
return NewColumnDoubleArray(fieldName, data), nil
case schemapb.DataType_VarChar, schemapb.DataType_String:
var data [][][]byte
for _, fd := range fieldDataList {
strs := fd.GetStringData().GetData()
bytesData := make([][]byte, 0, len(strs))
for _, str := range strs {
bytesData = append(bytesData, []byte(str))
}
data = append(data, bytesData)
}
return NewColumnVarCharArray(fieldName, data), nil
default:
return nil, fmt.Errorf("unsupported element type %s", elementType)
}
}
// getIntData get int32 slice from result field data
// also handles LongData bug (see also https://github.com/milvus-io/milvus/issues/23850)
func getIntData(fd *schemapb.FieldData) (*schemapb.ScalarField_IntData, bool) {
switch data := fd.GetScalars().GetData().(type) {
case *schemapb.ScalarField_IntData:
return data, true
case *schemapb.ScalarField_LongData:
// only alway empty LongData for backward compatibility
if len(data.LongData.GetData()) == 0 {
return &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{},
}, true
}
return nil, false
default:
return nil, false
}
}
// FieldDataColumn converts schemapb.FieldData to vector Column
func FieldDataVector(fd *schemapb.FieldData) (Column, error) {
switch fd.GetType() {
case schemapb.DataType_FloatVector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schemapb.VectorField_FloatVector)
if !ok {
return nil, errFieldDataTypeNotMatch
}
data := x.FloatVector.GetData()
dim := int(vectors.GetDim())
vector := make([][]float32, 0, len(data)/dim) // shall not have remanunt
for i := 0; i < len(data)/dim; i++ {
v := make([]float32, dim)
copy(v, data[i*dim:(i+1)*dim])
vector = append(vector, v)
}
return NewColumnFloatVector(fd.GetFieldName(), dim, vector), nil
case schemapb.DataType_BinaryVector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schemapb.VectorField_BinaryVector)
if !ok {
return nil, errFieldDataTypeNotMatch
}
data := x.BinaryVector
if data == nil {
return nil, errFieldDataTypeNotMatch
}
dim := int(vectors.GetDim())
blen := dim / 8
vector := make([][]byte, 0, len(data)/blen)
for i := 0; i < len(data)/blen; i++ {
v := make([]byte, blen)
copy(v, data[i*blen:(i+1)*blen])
vector = append(vector, v)
}
return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil
case schemapb.DataType_Float16Vector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schemapb.VectorField_Float16Vector)
if !ok {
return nil, errFieldDataTypeNotMatch
}
data := x.Float16Vector
dim := int(vectors.GetDim())
vector := make([][]byte, 0, len(data)/dim) // shall not have remanunt
for i := 0; i < len(data)/dim; i++ {
v := make([]byte, dim)
copy(v, data[i*dim:(i+1)*dim])
vector = append(vector, v)
}
return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil
case schemapb.DataType_BFloat16Vector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schemapb.VectorField_Bfloat16Vector)
if !ok {
return nil, errFieldDataTypeNotMatch
}
data := x.Bfloat16Vector
dim := int(vectors.GetDim())
vector := make([][]byte, 0, len(data)/dim) // shall not have remanunt
for i := 0; i < len(data)/dim; i++ {
v := make([]byte, dim)
copy(v, data[i*dim:(i+1)*dim])
vector = append(vector, v)
}
return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil
default:
return nil, errors.New("unsupported data type")
}
}
// defaultValueColumn will return the empty scalars column which will be fill with default value
func DefaultValueColumn(name string, dataType entity.FieldType) (Column, error) {
switch dataType {
case entity.FieldTypeBool:
return NewColumnBool(name, nil), nil
case entity.FieldTypeInt8:
return NewColumnInt8(name, nil), nil
case entity.FieldTypeInt16:
return NewColumnInt16(name, nil), nil
case entity.FieldTypeInt32:
return NewColumnInt32(name, nil), nil
case entity.FieldTypeInt64:
return NewColumnInt64(name, nil), nil
case entity.FieldTypeFloat:
return NewColumnFloat(name, nil), nil
case entity.FieldTypeDouble:
return NewColumnDouble(name, nil), nil
case entity.FieldTypeString:
return NewColumnString(name, nil), nil
case entity.FieldTypeVarChar:
return NewColumnVarChar(name, nil), nil
case entity.FieldTypeJSON:
return NewColumnJSONBytes(name, nil), nil
default:
return nil, fmt.Errorf("default value unsupported data type %s", dataType)
}
}

View File

@ -0,0 +1,160 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package column
import (
"math/rand"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
func TestIDColumns(t *testing.T) {
dataLen := rand.Intn(100) + 1
base := rand.Intn(5000) // id start point
t.Run("nil id", func(t *testing.T) {
_, err := IDColumns(nil, 0, -1)
assert.NotNil(t, err)
idField := &schemapb.IDs{}
_, err = IDColumns(idField, 0, -1)
assert.NotNil(t, err)
})
t.Run("int ids", func(t *testing.T) {
ids := make([]int64, 0, dataLen)
for i := 0; i < dataLen; i++ {
ids = append(ids, int64(i+base))
}
idField := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: ids,
},
},
}
column, err := IDColumns(idField, 0, dataLen)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, dataLen, column.Len())
column, err = IDColumns(idField, 0, -1) // test -1 method
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, dataLen, column.Len())
})
t.Run("string ids", func(t *testing.T) {
ids := make([]string, 0, dataLen)
for i := 0; i < dataLen; i++ {
ids = append(ids, strconv.FormatInt(int64(i+base), 10))
}
idField := &schemapb.IDs{
IdField: &schemapb.IDs_StrId{
StrId: &schemapb.StringArray{
Data: ids,
},
},
}
column, err := IDColumns(idField, 0, dataLen)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, dataLen, column.Len())
column, err = IDColumns(idField, 0, -1) // test -1 method
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, dataLen, column.Len())
})
}
func TestGetIntData(t *testing.T) {
type testCase struct {
tag string
fd *schemapb.FieldData
expectOK bool
}
cases := []testCase{
{
tag: "normal_IntData",
fd: &schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}},
},
},
},
},
expectOK: true,
},
{
tag: "empty_LongData",
fd: &schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{Data: nil},
},
},
},
},
expectOK: true,
},
{
tag: "nonempty_LongData",
fd: &schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
},
},
},
},
expectOK: false,
},
{
tag: "other_data",
fd: &schemapb.FieldData{
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_BoolData{},
},
},
},
expectOK: false,
},
{
tag: "vector_data",
fd: &schemapb.FieldData{
Field: &schemapb.FieldData_Vectors{},
},
expectOK: false,
},
}
for _, tc := range cases {
t.Run(tc.tag, func(t *testing.T) {
_, ok := getIntData(tc.fd)
assert.Equal(t, tc.expectOK, ok)
})
}
}

View File

@ -0,0 +1,53 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package column
func (c *ColumnInt8) GetAsInt64(idx int) (int64, error) {
v, err := c.ValueByIdx(idx)
return int64(v), err
}
func (c *ColumnInt16) GetAsInt64(idx int) (int64, error) {
v, err := c.ValueByIdx(idx)
return int64(v), err
}
func (c *ColumnInt32) GetAsInt64(idx int) (int64, error) {
v, err := c.ValueByIdx(idx)
return int64(v), err
}
func (c *ColumnInt64) GetAsInt64(idx int) (int64, error) {
return c.ValueByIdx(idx)
}
func (c *ColumnString) GetAsString(idx int) (string, error) {
return c.ValueByIdx(idx)
}
func (c *ColumnFloat) GetAsDouble(idx int) (float64, error) {
v, err := c.ValueByIdx(idx)
return float64(v), err
}
func (c *ColumnDouble) GetAsDouble(idx int) (float64, error) {
return c.ValueByIdx(idx)
}
func (c *ColumnBool) GetAsBool(idx int) (bool, error) {
return c.ValueByIdx(idx)
}

113
client/column/dynamic.go Normal file
View File

@ -0,0 +1,113 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package column
import (
"github.com/cockroachdb/errors"
"github.com/tidwall/gjson"
)
// ColumnDynamic is a logically wrapper for dynamic json field with provided output field.
type ColumnDynamic struct {
*ColumnJSONBytes
outputField string
}
func NewColumnDynamic(column *ColumnJSONBytes, outputField string) *ColumnDynamic {
return &ColumnDynamic{
ColumnJSONBytes: column,
outputField: outputField,
}
}
func (c *ColumnDynamic) Name() string {
return c.outputField
}
// Get returns element at idx as interface{}.
// Overrides internal json column behavior, returns raw json data.
func (c *ColumnDynamic) Get(idx int) (interface{}, error) {
bs, err := c.ColumnJSONBytes.ValueByIdx(idx)
if err != nil {
return 0, err
}
r := gjson.GetBytes(bs, c.outputField)
if !r.Exists() {
return 0, errors.New("column not has value")
}
return r.Raw, nil
}
func (c *ColumnDynamic) GetAsInt64(idx int) (int64, error) {
bs, err := c.ColumnJSONBytes.ValueByIdx(idx)
if err != nil {
return 0, err
}
r := gjson.GetBytes(bs, c.outputField)
if !r.Exists() {
return 0, errors.New("column not has value")
}
if r.Type != gjson.Number {
return 0, errors.New("column not int")
}
return r.Int(), nil
}
func (c *ColumnDynamic) GetAsString(idx int) (string, error) {
bs, err := c.ColumnJSONBytes.ValueByIdx(idx)
if err != nil {
return "", err
}
r := gjson.GetBytes(bs, c.outputField)
if !r.Exists() {
return "", errors.New("column not has value")
}
if r.Type != gjson.String {
return "", errors.New("column not string")
}
return r.String(), nil
}
func (c *ColumnDynamic) GetAsBool(idx int) (bool, error) {
bs, err := c.ColumnJSONBytes.ValueByIdx(idx)
if err != nil {
return false, err
}
r := gjson.GetBytes(bs, c.outputField)
if !r.Exists() {
return false, errors.New("column not has value")
}
if !r.IsBool() {
return false, errors.New("column not string")
}
return r.Bool(), nil
}
func (c *ColumnDynamic) GetAsDouble(idx int) (float64, error) {
bs, err := c.ColumnJSONBytes.ValueByIdx(idx)
if err != nil {
return 0, err
}
r := gjson.GetBytes(bs, c.outputField)
if !r.Exists() {
return 0, errors.New("column not has value")
}
if r.Type != gjson.Number {
return 0, errors.New("column not string")
}
return r.Float(), nil
}

View File

@ -0,0 +1,162 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package column
import (
"testing"
"github.com/stretchr/testify/suite"
)
type ColumnDynamicSuite struct {
suite.Suite
}
func (s *ColumnDynamicSuite) TestGetInt() {
cases := []struct {
input string
expectErr bool
expectValue int64
}{
{`{"field": 1000000000000000001}`, false, 1000000000000000001},
{`{"field": 4418489049307132905}`, false, 4418489049307132905},
{`{"other_field": 4418489049307132905}`, true, 0},
{`{"field": "string"}`, true, 0},
}
for _, c := range cases {
s.Run(c.input, func() {
column := NewColumnDynamic(&ColumnJSONBytes{
values: [][]byte{[]byte(c.input)},
}, "field")
v, err := column.GetAsInt64(0)
if c.expectErr {
s.Error(err)
return
}
s.NoError(err)
s.Equal(c.expectValue, v)
})
}
}
func (s *ColumnDynamicSuite) TestGetString() {
cases := []struct {
input string
expectErr bool
expectValue string
}{
{`{"field": "abc"}`, false, "abc"},
{`{"field": "test"}`, false, "test"},
{`{"other_field": "string"}`, true, ""},
{`{"field": 123}`, true, ""},
}
for _, c := range cases {
s.Run(c.input, func() {
column := NewColumnDynamic(&ColumnJSONBytes{
values: [][]byte{[]byte(c.input)},
}, "field")
v, err := column.GetAsString(0)
if c.expectErr {
s.Error(err)
return
}
s.NoError(err)
s.Equal(c.expectValue, v)
})
}
}
func (s *ColumnDynamicSuite) TestGetBool() {
cases := []struct {
input string
expectErr bool
expectValue bool
}{
{`{"field": true}`, false, true},
{`{"field": false}`, false, false},
{`{"other_field": true}`, true, false},
{`{"field": "test"}`, true, false},
}
for _, c := range cases {
s.Run(c.input, func() {
column := NewColumnDynamic(&ColumnJSONBytes{
values: [][]byte{[]byte(c.input)},
}, "field")
v, err := column.GetAsBool(0)
if c.expectErr {
s.Error(err)
return
}
s.NoError(err)
s.Equal(c.expectValue, v)
})
}
}
func (s *ColumnDynamicSuite) TestGetDouble() {
cases := []struct {
input string
expectErr bool
expectValue float64
}{
{`{"field": 1}`, false, 1.0},
{`{"field": 6231.123}`, false, 6231.123},
{`{"other_field": 1.0}`, true, 0},
{`{"field": "string"}`, true, 0},
}
for _, c := range cases {
s.Run(c.input, func() {
column := NewColumnDynamic(&ColumnJSONBytes{
values: [][]byte{[]byte(c.input)},
}, "field")
v, err := column.GetAsDouble(0)
if c.expectErr {
s.Error(err)
return
}
s.NoError(err)
s.Less(v-c.expectValue, 1e-10)
})
}
}
func (s *ColumnDynamicSuite) TestIndexOutOfRange() {
var err error
column := NewColumnDynamic(&ColumnJSONBytes{}, "field")
s.Equal("field", column.Name())
_, err = column.GetAsInt64(0)
s.Error(err)
_, err = column.GetAsString(0)
s.Error(err)
_, err = column.GetAsBool(0)
s.Error(err)
_, err = column.GetAsDouble(0)
s.Error(err)
}
func TestColumnDynamic(t *testing.T) {
suite.Run(t, new(ColumnDynamicSuite))
}

146
client/column/json.go Normal file
View File

@ -0,0 +1,146 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package column
import (
"encoding/json"
"fmt"
"reflect"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
)
var _ (Column) = (*ColumnJSONBytes)(nil)
// ColumnJSONBytes column type for JSON.
// all items are marshaled json bytes.
type ColumnJSONBytes struct {
ColumnBase
name string
values [][]byte
isDynamic bool
}
// Name returns column name.
func (c *ColumnJSONBytes) Name() string {
return c.name
}
// Type returns column entity.FieldType.
func (c *ColumnJSONBytes) Type() entity.FieldType {
return entity.FieldTypeJSON
}
// Len returns column values length.
func (c *ColumnJSONBytes) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnJSONBytes) Get(idx int) (interface{}, error) {
if idx < 0 || idx > c.Len() {
return nil, errors.New("index out of range")
}
return c.values[idx], nil
}
func (c *ColumnJSONBytes) GetAsString(idx int) (string, error) {
bs, err := c.ValueByIdx(idx)
if err != nil {
return "", err
}
return string(bs), nil
}
// FieldData return column data mapped to schemapb.FieldData.
func (c *ColumnJSONBytes) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_JSON,
FieldName: c.name,
IsDynamic: c.isDynamic,
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_JsonData{
JsonData: &schemapb.JSONArray{
Data: c.values,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index.
func (c *ColumnJSONBytes) ValueByIdx(idx int) ([]byte, error) {
if idx < 0 || idx >= c.Len() {
return nil, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column.
func (c *ColumnJSONBytes) AppendValue(i interface{}) error {
var v []byte
switch raw := i.(type) {
case []byte:
v = raw
default:
k := reflect.TypeOf(i).Kind()
if k == reflect.Ptr {
k = reflect.TypeOf(i).Elem().Kind()
}
switch k {
case reflect.Struct:
fallthrough
case reflect.Map:
bs, err := json.Marshal(raw)
if err != nil {
return err
}
v = bs
default:
return fmt.Errorf("expect json compatible type([]byte, struct[}, map], got %T)", i)
}
}
c.values = append(c.values, v)
return nil
}
// Data returns column data.
func (c *ColumnJSONBytes) Data() [][]byte {
return c.values
}
func (c *ColumnJSONBytes) WithIsDynamic(isDynamic bool) *ColumnJSONBytes {
c.isDynamic = isDynamic
return c
}
// NewColumnJSONBytes composes a Column with json bytes.
func NewColumnJSONBytes(name string, values [][]byte) *ColumnJSONBytes {
return &ColumnJSONBytes{
name: name,
values: values,
}
}

101
client/column/json_test.go Normal file
View File

@ -0,0 +1,101 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package column
import (
"fmt"
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/client/v2/entity"
)
type ColumnJSONBytesSuite struct {
suite.Suite
}
func (s *ColumnJSONBytesSuite) SetupSuite() {
rand.Seed(time.Now().UnixNano())
}
func (s *ColumnJSONBytesSuite) TestAttrMethods() {
columnName := fmt.Sprintf("column_jsonbs_%d", rand.Int())
columnLen := 8 + rand.Intn(10)
v := make([][]byte, columnLen)
column := NewColumnJSONBytes(columnName, v).WithIsDynamic(true)
s.Run("test_meta", func() {
ft := entity.FieldTypeJSON
s.Equal("JSON", ft.Name())
s.Equal("JSON", ft.String())
pbName, pbType := ft.PbFieldType()
s.Equal("JSON", pbName)
s.Equal("JSON", pbType)
})
s.Run("test_column_attribute", func() {
s.Equal(columnName, column.Name())
s.Equal(entity.FieldTypeJSON, column.Type())
s.Equal(columnLen, column.Len())
s.EqualValues(v, column.Data())
})
s.Run("test_column_field_data", func() {
fd := column.FieldData()
s.NotNil(fd)
s.Equal(fd.GetFieldName(), columnName)
})
s.Run("test_column_valuer_by_idx", func() {
_, err := column.ValueByIdx(-1)
s.Error(err)
_, err = column.ValueByIdx(columnLen)
s.Error(err)
for i := 0; i < columnLen; i++ {
v, err := column.ValueByIdx(i)
s.NoError(err)
s.Equal(column.values[i], v)
}
})
s.Run("test_append_value", func() {
item := make([]byte, 10)
err := column.AppendValue(item)
s.NoError(err)
s.Equal(columnLen+1, column.Len())
val, err := column.ValueByIdx(columnLen)
s.NoError(err)
s.Equal(item, val)
err = column.AppendValue(&struct{ Tag string }{Tag: "abc"})
s.NoError(err)
err = column.AppendValue(map[string]interface{}{"Value": 123})
s.NoError(err)
err = column.AppendValue(1)
s.Error(err)
})
}
func TestColumnJSONBytes(t *testing.T) {
suite.Run(t, new(ColumnJSONBytesSuite))
}

708
client/column/scalar_gen.go Normal file
View File

@ -0,0 +1,708 @@
// Code generated by go generate; DO NOT EDIT
// This file is generated by go generate
package column
import (
"errors"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
)
// ColumnBool generated columns type for Bool
type ColumnBool struct {
ColumnBase
name string
values []bool
}
// Name returns column name
func (c *ColumnBool) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnBool) Type() entity.FieldType {
return entity.FieldTypeBool
}
// Len returns column values length
func (c *ColumnBool) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnBool) Get(idx int) (interface{}, error) {
var r bool // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnBool) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Bool,
FieldName: c.name,
}
data := make([]bool, 0, c.Len())
for i := 0; i < c.Len(); i++ {
data = append(data, bool(c.values[i]))
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_BoolData{
BoolData: &schemapb.BoolArray{
Data: data,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnBool) ValueByIdx(idx int) (bool, error) {
var r bool // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnBool) AppendValue(i interface{}) error {
v, ok := i.(bool)
if !ok {
return fmt.Errorf("invalid type, expected bool, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnBool) Data() []bool {
return c.values
}
// NewColumnBool auto generated constructor
func NewColumnBool(name string, values []bool) *ColumnBool {
return &ColumnBool{
name: name,
values: values,
}
}
// ColumnInt8 generated columns type for Int8
type ColumnInt8 struct {
ColumnBase
name string
values []int8
}
// Name returns column name
func (c *ColumnInt8) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnInt8) Type() entity.FieldType {
return entity.FieldTypeInt8
}
// Len returns column values length
func (c *ColumnInt8) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnInt8) Get(idx int) (interface{}, error) {
var r int8 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnInt8) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Int8,
FieldName: c.name,
}
data := make([]int32, 0, c.Len())
for i := 0; i < c.Len(); i++ {
data = append(data, int32(c.values[i]))
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: data,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnInt8) ValueByIdx(idx int) (int8, error) {
var r int8 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnInt8) AppendValue(i interface{}) error {
v, ok := i.(int8)
if !ok {
return fmt.Errorf("invalid type, expected int8, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnInt8) Data() []int8 {
return c.values
}
// NewColumnInt8 auto generated constructor
func NewColumnInt8(name string, values []int8) *ColumnInt8 {
return &ColumnInt8{
name: name,
values: values,
}
}
// ColumnInt16 generated columns type for Int16
type ColumnInt16 struct {
ColumnBase
name string
values []int16
}
// Name returns column name
func (c *ColumnInt16) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnInt16) Type() entity.FieldType {
return entity.FieldTypeInt16
}
// Len returns column values length
func (c *ColumnInt16) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnInt16) Get(idx int) (interface{}, error) {
var r int16 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnInt16) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Int16,
FieldName: c.name,
}
data := make([]int32, 0, c.Len())
for i := 0; i < c.Len(); i++ {
data = append(data, int32(c.values[i]))
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: data,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnInt16) ValueByIdx(idx int) (int16, error) {
var r int16 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnInt16) AppendValue(i interface{}) error {
v, ok := i.(int16)
if !ok {
return fmt.Errorf("invalid type, expected int16, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnInt16) Data() []int16 {
return c.values
}
// NewColumnInt16 auto generated constructor
func NewColumnInt16(name string, values []int16) *ColumnInt16 {
return &ColumnInt16{
name: name,
values: values,
}
}
// ColumnInt32 generated columns type for Int32
type ColumnInt32 struct {
ColumnBase
name string
values []int32
}
// Name returns column name
func (c *ColumnInt32) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnInt32) Type() entity.FieldType {
return entity.FieldTypeInt32
}
// Len returns column values length
func (c *ColumnInt32) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnInt32) Get(idx int) (interface{}, error) {
var r int32 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnInt32) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Int32,
FieldName: c.name,
}
data := make([]int32, 0, c.Len())
for i := 0; i < c.Len(); i++ {
data = append(data, int32(c.values[i]))
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: data,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnInt32) ValueByIdx(idx int) (int32, error) {
var r int32 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnInt32) AppendValue(i interface{}) error {
v, ok := i.(int32)
if !ok {
return fmt.Errorf("invalid type, expected int32, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnInt32) Data() []int32 {
return c.values
}
// NewColumnInt32 auto generated constructor
func NewColumnInt32(name string, values []int32) *ColumnInt32 {
return &ColumnInt32{
name: name,
values: values,
}
}
// ColumnInt64 generated columns type for Int64
type ColumnInt64 struct {
ColumnBase
name string
values []int64
}
// Name returns column name
func (c *ColumnInt64) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnInt64) Type() entity.FieldType {
return entity.FieldTypeInt64
}
// Len returns column values length
func (c *ColumnInt64) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnInt64) Get(idx int) (interface{}, error) {
var r int64 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnInt64) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldName: c.name,
}
data := make([]int64, 0, c.Len())
for i := 0; i < c.Len(); i++ {
data = append(data, int64(c.values[i]))
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: data,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnInt64) ValueByIdx(idx int) (int64, error) {
var r int64 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnInt64) AppendValue(i interface{}) error {
v, ok := i.(int64)
if !ok {
return fmt.Errorf("invalid type, expected int64, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnInt64) Data() []int64 {
return c.values
}
// NewColumnInt64 auto generated constructor
func NewColumnInt64(name string, values []int64) *ColumnInt64 {
return &ColumnInt64{
name: name,
values: values,
}
}
// ColumnFloat generated columns type for Float
type ColumnFloat struct {
ColumnBase
name string
values []float32
}
// Name returns column name
func (c *ColumnFloat) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnFloat) Type() entity.FieldType {
return entity.FieldTypeFloat
}
// Len returns column values length
func (c *ColumnFloat) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnFloat) Get(idx int) (interface{}, error) {
var r float32 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnFloat) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Float,
FieldName: c.name,
}
data := make([]float32, 0, c.Len())
for i := 0; i < c.Len(); i++ {
data = append(data, float32(c.values[i]))
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_FloatData{
FloatData: &schemapb.FloatArray{
Data: data,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnFloat) ValueByIdx(idx int) (float32, error) {
var r float32 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnFloat) AppendValue(i interface{}) error {
v, ok := i.(float32)
if !ok {
return fmt.Errorf("invalid type, expected float32, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnFloat) Data() []float32 {
return c.values
}
// NewColumnFloat auto generated constructor
func NewColumnFloat(name string, values []float32) *ColumnFloat {
return &ColumnFloat{
name: name,
values: values,
}
}
// ColumnDouble generated columns type for Double
type ColumnDouble struct {
ColumnBase
name string
values []float64
}
// Name returns column name
func (c *ColumnDouble) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnDouble) Type() entity.FieldType {
return entity.FieldTypeDouble
}
// Len returns column values length
func (c *ColumnDouble) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnDouble) Get(idx int) (interface{}, error) {
var r float64 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnDouble) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Double,
FieldName: c.name,
}
data := make([]float64, 0, c.Len())
for i := 0; i < c.Len(); i++ {
data = append(data, float64(c.values[i]))
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_DoubleData{
DoubleData: &schemapb.DoubleArray{
Data: data,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnDouble) ValueByIdx(idx int) (float64, error) {
var r float64 // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnDouble) AppendValue(i interface{}) error {
v, ok := i.(float64)
if !ok {
return fmt.Errorf("invalid type, expected float64, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnDouble) Data() []float64 {
return c.values
}
// NewColumnDouble auto generated constructor
func NewColumnDouble(name string, values []float64) *ColumnDouble {
return &ColumnDouble{
name: name,
values: values,
}
}
// ColumnString generated columns type for String
type ColumnString struct {
ColumnBase
name string
values []string
}
// Name returns column name
func (c *ColumnString) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnString) Type() entity.FieldType {
return entity.FieldTypeString
}
// Len returns column values length
func (c *ColumnString) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnString) Get(idx int) (interface{}, error) {
var r string // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnString) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_String,
FieldName: c.name,
}
data := make([]string, 0, c.Len())
for i := 0; i < c.Len(); i++ {
data = append(data, string(c.values[i]))
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: data,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnString) ValueByIdx(idx int) (string, error) {
var r string // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnString) AppendValue(i interface{}) error {
v, ok := i.(string)
if !ok {
return fmt.Errorf("invalid type, expected string, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnString) Data() []string {
return c.values
}
// NewColumnString auto generated constructor
func NewColumnString(name string, values []string) *ColumnString {
return &ColumnString{
name: name,
values: values,
}
}

View File

@ -0,0 +1,855 @@
// Code generated by go generate; DO NOT EDIT
// This file is generated by go generated
package column
import (
"fmt"
"math/rand"
"testing"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/stretchr/testify/assert"
)
func TestColumnBool(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_Bool_%d", rand.Int())
columnLen := 8 + rand.Intn(10)
v := make([]bool, columnLen)
column := NewColumnBool(columnName, v)
t.Run("test meta", func(t *testing.T) {
ft := entity.FieldTypeBool
assert.Equal(t, "Bool", ft.Name())
assert.Equal(t, "bool", ft.String())
pbName, pbType := ft.PbFieldType()
assert.Equal(t, "Bool", pbName)
assert.Equal(t, "bool", pbType)
})
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeBool, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.EqualValues(t, v, column.Data())
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
})
t.Run("test column value by idx", func(t *testing.T) {
_, err := column.ValueByIdx(-1)
assert.NotNil(t, err)
_, err = column.ValueByIdx(columnLen)
assert.NotNil(t, err)
for i := 0; i < columnLen; i++ {
v, err := column.ValueByIdx(i)
assert.Nil(t, err)
assert.Equal(t, column.values[i], v)
}
})
}
func TestFieldDataBoolColumn(t *testing.T) {
len := rand.Intn(10) + 8
name := fmt.Sprintf("fd_Bool_%d", rand.Int())
fd := &schemapb.FieldData{
Type: schemapb.DataType_Bool,
FieldName: name,
}
t.Run("normal usage", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_BoolData{
BoolData: &schemapb.BoolArray{
Data: make([]bool, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, len)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeBool, column.Type())
var ev bool
err = column.AppendValue(ev)
assert.Equal(t, len+1, column.Len())
assert.Nil(t, err)
err = column.AppendValue(struct{}{})
assert.Equal(t, len+1, column.Len())
assert.NotNil(t, err)
})
t.Run("nil data", func(t *testing.T) {
fd.Field = nil
_, err := FieldDataColumn(fd, 0, len)
assert.NotNil(t, err)
})
t.Run("get all data", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_BoolData{
BoolData: &schemapb.BoolArray{
Data: make([]bool, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, -1)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeBool, column.Type())
})
}
func TestColumnInt8(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_Int8_%d", rand.Int())
columnLen := 8 + rand.Intn(10)
v := make([]int8, columnLen)
column := NewColumnInt8(columnName, v)
t.Run("test meta", func(t *testing.T) {
ft := entity.FieldTypeInt8
assert.Equal(t, "Int8", ft.Name())
assert.Equal(t, "int8", ft.String())
pbName, pbType := ft.PbFieldType()
assert.Equal(t, "Int", pbName)
assert.Equal(t, "int32", pbType)
})
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeInt8, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.EqualValues(t, v, column.Data())
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
})
t.Run("test column value by idx", func(t *testing.T) {
_, err := column.ValueByIdx(-1)
assert.NotNil(t, err)
_, err = column.ValueByIdx(columnLen)
assert.NotNil(t, err)
for i := 0; i < columnLen; i++ {
v, err := column.ValueByIdx(i)
assert.Nil(t, err)
assert.Equal(t, column.values[i], v)
}
})
}
func TestFieldDataInt8Column(t *testing.T) {
len := rand.Intn(10) + 8
name := fmt.Sprintf("fd_Int8_%d", rand.Int())
fd := &schemapb.FieldData{
Type: schemapb.DataType_Int8,
FieldName: name,
}
t.Run("normal usage", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: make([]int32, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, len)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeInt8, column.Type())
var ev int8
err = column.AppendValue(ev)
assert.Equal(t, len+1, column.Len())
assert.Nil(t, err)
err = column.AppendValue(struct{}{})
assert.Equal(t, len+1, column.Len())
assert.NotNil(t, err)
})
t.Run("nil data", func(t *testing.T) {
fd.Field = nil
_, err := FieldDataColumn(fd, 0, len)
assert.NotNil(t, err)
})
t.Run("get all data", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: make([]int32, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, -1)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeInt8, column.Type())
})
}
func TestColumnInt16(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_Int16_%d", rand.Int())
columnLen := 8 + rand.Intn(10)
v := make([]int16, columnLen)
column := NewColumnInt16(columnName, v)
t.Run("test meta", func(t *testing.T) {
ft := entity.FieldTypeInt16
assert.Equal(t, "Int16", ft.Name())
assert.Equal(t, "int16", ft.String())
pbName, pbType := ft.PbFieldType()
assert.Equal(t, "Int", pbName)
assert.Equal(t, "int32", pbType)
})
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeInt16, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.EqualValues(t, v, column.Data())
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
})
t.Run("test column value by idx", func(t *testing.T) {
_, err := column.ValueByIdx(-1)
assert.NotNil(t, err)
_, err = column.ValueByIdx(columnLen)
assert.NotNil(t, err)
for i := 0; i < columnLen; i++ {
v, err := column.ValueByIdx(i)
assert.Nil(t, err)
assert.Equal(t, column.values[i], v)
}
})
}
func TestFieldDataInt16Column(t *testing.T) {
len := rand.Intn(10) + 8
name := fmt.Sprintf("fd_Int16_%d", rand.Int())
fd := &schemapb.FieldData{
Type: schemapb.DataType_Int16,
FieldName: name,
}
t.Run("normal usage", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: make([]int32, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, len)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeInt16, column.Type())
var ev int16
err = column.AppendValue(ev)
assert.Equal(t, len+1, column.Len())
assert.Nil(t, err)
err = column.AppendValue(struct{}{})
assert.Equal(t, len+1, column.Len())
assert.NotNil(t, err)
})
t.Run("nil data", func(t *testing.T) {
fd.Field = nil
_, err := FieldDataColumn(fd, 0, len)
assert.NotNil(t, err)
})
t.Run("get all data", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: make([]int32, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, -1)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeInt16, column.Type())
})
}
func TestColumnInt32(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_Int32_%d", rand.Int())
columnLen := 8 + rand.Intn(10)
v := make([]int32, columnLen)
column := NewColumnInt32(columnName, v)
t.Run("test meta", func(t *testing.T) {
ft := entity.FieldTypeInt32
assert.Equal(t, "Int32", ft.Name())
assert.Equal(t, "int32", ft.String())
pbName, pbType := ft.PbFieldType()
assert.Equal(t, "Int", pbName)
assert.Equal(t, "int32", pbType)
})
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeInt32, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.EqualValues(t, v, column.Data())
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
})
t.Run("test column value by idx", func(t *testing.T) {
_, err := column.ValueByIdx(-1)
assert.NotNil(t, err)
_, err = column.ValueByIdx(columnLen)
assert.NotNil(t, err)
for i := 0; i < columnLen; i++ {
v, err := column.ValueByIdx(i)
assert.Nil(t, err)
assert.Equal(t, column.values[i], v)
}
})
}
func TestFieldDataInt32Column(t *testing.T) {
len := rand.Intn(10) + 8
name := fmt.Sprintf("fd_Int32_%d", rand.Int())
fd := &schemapb.FieldData{
Type: schemapb.DataType_Int32,
FieldName: name,
}
t.Run("normal usage", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: make([]int32, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, len)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeInt32, column.Type())
var ev int32
err = column.AppendValue(ev)
assert.Equal(t, len+1, column.Len())
assert.Nil(t, err)
err = column.AppendValue(struct{}{})
assert.Equal(t, len+1, column.Len())
assert.NotNil(t, err)
})
t.Run("nil data", func(t *testing.T) {
fd.Field = nil
_, err := FieldDataColumn(fd, 0, len)
assert.NotNil(t, err)
})
t.Run("get all data", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{
Data: make([]int32, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, -1)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeInt32, column.Type())
})
}
func TestColumnInt64(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_Int64_%d", rand.Int())
columnLen := 8 + rand.Intn(10)
v := make([]int64, columnLen)
column := NewColumnInt64(columnName, v)
t.Run("test meta", func(t *testing.T) {
ft := entity.FieldTypeInt64
assert.Equal(t, "Int64", ft.Name())
assert.Equal(t, "int64", ft.String())
pbName, pbType := ft.PbFieldType()
assert.Equal(t, "Long", pbName)
assert.Equal(t, "int64", pbType)
})
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeInt64, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.EqualValues(t, v, column.Data())
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
})
t.Run("test column value by idx", func(t *testing.T) {
_, err := column.ValueByIdx(-1)
assert.NotNil(t, err)
_, err = column.ValueByIdx(columnLen)
assert.NotNil(t, err)
for i := 0; i < columnLen; i++ {
v, err := column.ValueByIdx(i)
assert.Nil(t, err)
assert.Equal(t, column.values[i], v)
}
})
}
func TestFieldDataInt64Column(t *testing.T) {
len := rand.Intn(10) + 8
name := fmt.Sprintf("fd_Int64_%d", rand.Int())
fd := &schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldName: name,
}
t.Run("normal usage", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: make([]int64, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, len)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeInt64, column.Type())
var ev int64
err = column.AppendValue(ev)
assert.Equal(t, len+1, column.Len())
assert.Nil(t, err)
err = column.AppendValue(struct{}{})
assert.Equal(t, len+1, column.Len())
assert.NotNil(t, err)
})
t.Run("nil data", func(t *testing.T) {
fd.Field = nil
_, err := FieldDataColumn(fd, 0, len)
assert.NotNil(t, err)
})
t.Run("get all data", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: make([]int64, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, -1)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeInt64, column.Type())
})
}
func TestColumnFloat(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_Float_%d", rand.Int())
columnLen := 8 + rand.Intn(10)
v := make([]float32, columnLen)
column := NewColumnFloat(columnName, v)
t.Run("test meta", func(t *testing.T) {
ft := entity.FieldTypeFloat
assert.Equal(t, "Float", ft.Name())
assert.Equal(t, "float32", ft.String())
pbName, pbType := ft.PbFieldType()
assert.Equal(t, "Float", pbName)
assert.Equal(t, "float32", pbType)
})
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeFloat, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.EqualValues(t, v, column.Data())
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
})
t.Run("test column value by idx", func(t *testing.T) {
_, err := column.ValueByIdx(-1)
assert.NotNil(t, err)
_, err = column.ValueByIdx(columnLen)
assert.NotNil(t, err)
for i := 0; i < columnLen; i++ {
v, err := column.ValueByIdx(i)
assert.Nil(t, err)
assert.Equal(t, column.values[i], v)
}
})
}
func TestFieldDataFloatColumn(t *testing.T) {
len := rand.Intn(10) + 8
name := fmt.Sprintf("fd_Float_%d", rand.Int())
fd := &schemapb.FieldData{
Type: schemapb.DataType_Float,
FieldName: name,
}
t.Run("normal usage", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_FloatData{
FloatData: &schemapb.FloatArray{
Data: make([]float32, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, len)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeFloat, column.Type())
var ev float32
err = column.AppendValue(ev)
assert.Equal(t, len+1, column.Len())
assert.Nil(t, err)
err = column.AppendValue(struct{}{})
assert.Equal(t, len+1, column.Len())
assert.NotNil(t, err)
})
t.Run("nil data", func(t *testing.T) {
fd.Field = nil
_, err := FieldDataColumn(fd, 0, len)
assert.NotNil(t, err)
})
t.Run("get all data", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_FloatData{
FloatData: &schemapb.FloatArray{
Data: make([]float32, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, -1)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeFloat, column.Type())
})
}
func TestColumnDouble(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_Double_%d", rand.Int())
columnLen := 8 + rand.Intn(10)
v := make([]float64, columnLen)
column := NewColumnDouble(columnName, v)
t.Run("test meta", func(t *testing.T) {
ft := entity.FieldTypeDouble
assert.Equal(t, "Double", ft.Name())
assert.Equal(t, "float64", ft.String())
pbName, pbType := ft.PbFieldType()
assert.Equal(t, "Double", pbName)
assert.Equal(t, "float64", pbType)
})
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeDouble, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.EqualValues(t, v, column.Data())
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
})
t.Run("test column value by idx", func(t *testing.T) {
_, err := column.ValueByIdx(-1)
assert.NotNil(t, err)
_, err = column.ValueByIdx(columnLen)
assert.NotNil(t, err)
for i := 0; i < columnLen; i++ {
v, err := column.ValueByIdx(i)
assert.Nil(t, err)
assert.Equal(t, column.values[i], v)
}
})
}
func TestFieldDataDoubleColumn(t *testing.T) {
len := rand.Intn(10) + 8
name := fmt.Sprintf("fd_Double_%d", rand.Int())
fd := &schemapb.FieldData{
Type: schemapb.DataType_Double,
FieldName: name,
}
t.Run("normal usage", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_DoubleData{
DoubleData: &schemapb.DoubleArray{
Data: make([]float64, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, len)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeDouble, column.Type())
var ev float64
err = column.AppendValue(ev)
assert.Equal(t, len+1, column.Len())
assert.Nil(t, err)
err = column.AppendValue(struct{}{})
assert.Equal(t, len+1, column.Len())
assert.NotNil(t, err)
})
t.Run("nil data", func(t *testing.T) {
fd.Field = nil
_, err := FieldDataColumn(fd, 0, len)
assert.NotNil(t, err)
})
t.Run("get all data", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_DoubleData{
DoubleData: &schemapb.DoubleArray{
Data: make([]float64, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, -1)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeDouble, column.Type())
})
}
func TestColumnString(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_String_%d", rand.Int())
columnLen := 8 + rand.Intn(10)
v := make([]string, columnLen)
column := NewColumnString(columnName, v)
t.Run("test meta", func(t *testing.T) {
ft := entity.FieldTypeString
assert.Equal(t, "String", ft.Name())
assert.Equal(t, "string", ft.String())
pbName, pbType := ft.PbFieldType()
assert.Equal(t, "String", pbName)
assert.Equal(t, "string", pbType)
})
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeString, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.EqualValues(t, v, column.Data())
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
})
t.Run("test column value by idx", func(t *testing.T) {
_, err := column.ValueByIdx(-1)
assert.NotNil(t, err)
_, err = column.ValueByIdx(columnLen)
assert.NotNil(t, err)
for i := 0; i < columnLen; i++ {
v, err := column.ValueByIdx(i)
assert.Nil(t, err)
assert.Equal(t, column.values[i], v)
}
})
}
func TestFieldDataStringColumn(t *testing.T) {
len := rand.Intn(10) + 8
name := fmt.Sprintf("fd_String_%d", rand.Int())
fd := &schemapb.FieldData{
Type: schemapb.DataType_String,
FieldName: name,
}
t.Run("normal usage", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: make([]string, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, len)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeString, column.Type())
var ev string
err = column.AppendValue(ev)
assert.Equal(t, len+1, column.Len())
assert.Nil(t, err)
err = column.AppendValue(struct{}{})
assert.Equal(t, len+1, column.Len())
assert.NotNil(t, err)
})
t.Run("nil data", func(t *testing.T) {
fd.Field = nil
_, err := FieldDataColumn(fd, 0, len)
assert.NotNil(t, err)
})
t.Run("get all data", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: make([]string, len),
},
},
},
}
column, err := FieldDataColumn(fd, 0, -1)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, len, column.Len())
assert.Equal(t, entity.FieldTypeString, column.Type())
})
}

125
client/column/sparse.go Normal file
View File

@ -0,0 +1,125 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package column
import (
"encoding/binary"
"fmt"
"math"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
)
var _ (Column) = (*ColumnSparseFloatVector)(nil)
type ColumnSparseFloatVector struct {
ColumnBase
vectors []entity.SparseEmbedding
name string
}
// Name returns column name.
func (c *ColumnSparseFloatVector) Name() string {
return c.name
}
// Type returns column FieldType.
func (c *ColumnSparseFloatVector) Type() entity.FieldType {
return entity.FieldTypeSparseVector
}
// Len returns column values length.
func (c *ColumnSparseFloatVector) Len() int {
return len(c.vectors)
}
// Get returns value at index as interface{}.
func (c *ColumnSparseFloatVector) Get(idx int) (interface{}, error) {
if idx < 0 || idx >= c.Len() {
return nil, errors.New("index out of range")
}
return c.vectors[idx], nil
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnSparseFloatVector) ValueByIdx(idx int) (entity.SparseEmbedding, error) {
var r entity.SparseEmbedding // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.vectors[idx], nil
}
func (c *ColumnSparseFloatVector) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_SparseFloatVector,
FieldName: c.name,
}
dim := int(0)
data := make([][]byte, 0, len(c.vectors))
for _, vector := range c.vectors {
row := make([]byte, 8*vector.Len())
for idx := 0; idx < vector.Len(); idx++ {
pos, value, _ := vector.Get(idx)
binary.LittleEndian.PutUint32(row[idx*8:], pos)
binary.LittleEndian.PutUint32(row[pos*8+4:], math.Float32bits(value))
}
data = append(data, row)
if vector.Dim() > dim {
dim = vector.Dim()
}
}
fd.Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_SparseFloatVector{
SparseFloatVector: &schemapb.SparseFloatArray{
Dim: int64(dim),
Contents: data,
},
},
},
}
return fd
}
func (c *ColumnSparseFloatVector) AppendValue(i interface{}) error {
v, ok := i.(entity.SparseEmbedding)
if !ok {
return fmt.Errorf("invalid type, expect SparseEmbedding interface, got %T", i)
}
c.vectors = append(c.vectors, v)
return nil
}
func (c *ColumnSparseFloatVector) Data() []entity.SparseEmbedding {
return c.vectors
}
func NewColumnSparseVectors(name string, values []entity.SparseEmbedding) *ColumnSparseFloatVector {
return &ColumnSparseFloatVector{
name: name,
vectors: values,
}
}

View File

@ -0,0 +1,81 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package column
import (
"fmt"
"math/rand"
"testing"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestColumnSparseEmbedding(t *testing.T) {
columnName := fmt.Sprintf("column_sparse_embedding_%d", rand.Int())
columnLen := 8 + rand.Intn(10)
v := make([]entity.SparseEmbedding, 0, columnLen)
for i := 0; i < columnLen; i++ {
length := 1 + rand.Intn(5)
positions := make([]uint32, length)
values := make([]float32, length)
for j := 0; j < length; j++ {
positions[j] = uint32(j)
values[j] = rand.Float32()
}
se, err := entity.NewSliceSparseEmbedding(positions, values)
require.NoError(t, err)
v = append(v, se)
}
column := NewColumnSparseVectors(columnName, v)
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeSparseVector, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.EqualValues(t, v, column.Data())
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
})
t.Run("test column value by idx", func(t *testing.T) {
_, err := column.ValueByIdx(-1)
assert.Error(t, err)
_, err = column.ValueByIdx(columnLen)
assert.Error(t, err)
_, err = column.Get(-1)
assert.Error(t, err)
_, err = column.Get(columnLen)
assert.Error(t, err)
for i := 0; i < columnLen; i++ {
v, err := column.ValueByIdx(i)
assert.NoError(t, err)
assert.Equal(t, column.vectors[i], v)
getV, err := column.Get(i)
assert.NoError(t, err)
assert.Equal(t, v, getV)
}
})
}

119
client/column/varchar.go Normal file
View File

@ -0,0 +1,119 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package column
import (
"errors"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
)
// ColumnVarChar generated columns type for VarChar
type ColumnVarChar struct {
ColumnBase
name string
values []string
}
// Name returns column name
func (c *ColumnVarChar) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnVarChar) Type() entity.FieldType {
return entity.FieldTypeVarChar
}
// Len returns column values length
func (c *ColumnVarChar) Len() int {
return len(c.values)
}
// Get returns value at index as interface{}.
func (c *ColumnVarChar) Get(idx int) (interface{}, error) {
if idx < 0 || idx > c.Len() {
return "", errors.New("index out of range")
}
return c.values[idx], nil
}
// GetAsString returns value at idx.
func (c *ColumnVarChar) GetAsString(idx int) (string, error) {
if idx < 0 || idx > c.Len() {
return "", errors.New("index out of range")
}
return c.values[idx], nil
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnVarChar) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldName: c.name,
}
data := make([]string, 0, c.Len())
for i := 0; i < c.Len(); i++ {
data = append(data, string(c.values[i]))
}
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: data,
},
},
},
}
return fd
}
// ValueByIdx returns value of the provided index
// error occurs when index out of range
func (c *ColumnVarChar) ValueByIdx(idx int) (string, error) {
var r string // use default value
if idx < 0 || idx >= c.Len() {
return r, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnVarChar) AppendValue(i interface{}) error {
v, ok := i.(string)
if !ok {
return fmt.Errorf("invalid type, expected string, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnVarChar) Data() []string {
return c.values
}
// NewColumnVarChar auto generated constructor
func NewColumnVarChar(name string, values []string) *ColumnVarChar {
return &ColumnVarChar{
name: name,
values: values,
}
}

View File

@ -0,0 +1,134 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package column
import (
"fmt"
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
)
func TestColumnVarChar(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_VarChar_%d", rand.Int())
columnLen := 8 + rand.Intn(10)
v := make([]string, columnLen)
column := NewColumnVarChar(columnName, v)
t.Run("test meta", func(t *testing.T) {
ft := entity.FieldTypeVarChar
assert.Equal(t, "VarChar", ft.Name())
assert.Equal(t, "string", ft.String())
pbName, pbType := ft.PbFieldType()
assert.Equal(t, "VarChar", pbName)
assert.Equal(t, "string", pbType)
})
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeVarChar, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.EqualValues(t, v, column.Data())
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
})
t.Run("test column value by idx", func(t *testing.T) {
_, err := column.ValueByIdx(-1)
assert.NotNil(t, err)
_, err = column.ValueByIdx(columnLen)
assert.NotNil(t, err)
for i := 0; i < columnLen; i++ {
v, err := column.ValueByIdx(i)
assert.Nil(t, err)
assert.Equal(t, column.values[i], v)
}
})
}
func TestFieldDataVarCharColumn(t *testing.T) {
colLen := rand.Intn(10) + 8
name := fmt.Sprintf("fd_VarChar_%d", rand.Int())
fd := &schemapb.FieldData{
Type: schemapb.DataType_VarChar,
FieldName: name,
}
t.Run("normal usage", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: make([]string, colLen),
},
},
},
}
column, err := FieldDataColumn(fd, 0, colLen)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, colLen, column.Len())
assert.Equal(t, entity.FieldTypeVarChar, column.Type())
var ev string
err = column.AppendValue(ev)
assert.Equal(t, colLen+1, column.Len())
assert.Nil(t, err)
err = column.AppendValue(struct{}{})
assert.Equal(t, colLen+1, column.Len())
assert.NotNil(t, err)
})
t.Run("nil data", func(t *testing.T) {
fd.Field = nil
_, err := FieldDataColumn(fd, 0, colLen)
assert.NotNil(t, err)
})
t.Run("get all data", func(t *testing.T) {
fd.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: make([]string, colLen),
},
},
},
}
column, err := FieldDataColumn(fd, 0, -1)
assert.Nil(t, err)
assert.NotNil(t, column)
assert.Equal(t, name, column.Name())
assert.Equal(t, colLen, column.Len())
assert.Equal(t, entity.FieldTypeVarChar, column.Type())
})
}

358
client/column/vector_gen.go Normal file
View File

@ -0,0 +1,358 @@
// Code generated by go generate; DO NOT EDIT
// This file is generated by go generated
package column
import (
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/cockroachdb/errors"
)
// ColumnBinaryVector generated columns type for BinaryVector
type ColumnBinaryVector struct {
ColumnBase
name string
dim int
values [][]byte
}
// Name returns column name
func (c *ColumnBinaryVector) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnBinaryVector) Type() entity.FieldType {
return entity.FieldTypeBinaryVector
}
// Len returns column data length
func (c *ColumnBinaryVector) Len() int {
return len(c.values)
}
// Dim returns vector dimension
func (c *ColumnBinaryVector) Dim() int {
return c.dim
}
// Get returns values at index as interface{}.
func (c *ColumnBinaryVector) Get(idx int) (interface{}, error) {
if idx < 0 || idx >= c.Len() {
return nil, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnBinaryVector) AppendValue(i interface{}) error {
v, ok := i.([]byte)
if !ok {
return fmt.Errorf("invalid type, expected []byte, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnBinaryVector) Data() [][]byte {
return c.values
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnBinaryVector) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_BinaryVector,
FieldName: c.name,
}
data := make([]byte, 0, len(c.values)*c.dim)
for _, vector := range c.values {
data = append(data, vector...)
}
fd.Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(c.dim),
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: data,
},
},
}
return fd
}
// NewColumnBinaryVector auto generated constructor
func NewColumnBinaryVector(name string, dim int, values [][]byte) *ColumnBinaryVector {
return &ColumnBinaryVector{
name: name,
dim: dim,
values: values,
}
}
// ColumnFloatVector generated columns type for FloatVector
type ColumnFloatVector struct {
ColumnBase
name string
dim int
values [][]float32
}
// Name returns column name
func (c *ColumnFloatVector) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnFloatVector) Type() entity.FieldType {
return entity.FieldTypeFloatVector
}
// Len returns column data length
func (c *ColumnFloatVector) Len() int {
return len(c.values)
}
// Dim returns vector dimension
func (c *ColumnFloatVector) Dim() int {
return c.dim
}
// Get returns values at index as interface{}.
func (c *ColumnFloatVector) Get(idx int) (interface{}, error) {
if idx < 0 || idx >= c.Len() {
return nil, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnFloatVector) AppendValue(i interface{}) error {
v, ok := i.([]float32)
if !ok {
return fmt.Errorf("invalid type, expected []float32, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnFloatVector) Data() [][]float32 {
return c.values
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnFloatVector) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_FloatVector,
FieldName: c.name,
}
data := make([]float32, 0, len(c.values)*c.dim)
for _, vector := range c.values {
data = append(data, vector...)
}
fd.Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(c.dim),
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: data,
},
},
},
}
return fd
}
// NewColumnFloatVector auto generated constructor
func NewColumnFloatVector(name string, dim int, values [][]float32) *ColumnFloatVector {
return &ColumnFloatVector{
name: name,
dim: dim,
values: values,
}
}
// ColumnFloat16Vector generated columns type for Float16Vector
type ColumnFloat16Vector struct {
ColumnBase
name string
dim int
values [][]byte
}
// Name returns column name
func (c *ColumnFloat16Vector) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnFloat16Vector) Type() entity.FieldType {
return entity.FieldTypeFloat16Vector
}
// Len returns column data length
func (c *ColumnFloat16Vector) Len() int {
return len(c.values)
}
// Dim returns vector dimension
func (c *ColumnFloat16Vector) Dim() int {
return c.dim
}
// Get returns values at index as interface{}.
func (c *ColumnFloat16Vector) Get(idx int) (interface{}, error) {
if idx < 0 || idx >= c.Len() {
return nil, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnFloat16Vector) AppendValue(i interface{}) error {
v, ok := i.([]byte)
if !ok {
return fmt.Errorf("invalid type, expected []byte, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnFloat16Vector) Data() [][]byte {
return c.values
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnFloat16Vector) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Float16Vector,
FieldName: c.name,
}
data := make([]byte, 0, len(c.values)*c.dim)
for _, vector := range c.values {
data = append(data, vector...)
}
fd.Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(c.dim),
Data: &schemapb.VectorField_Float16Vector{
Float16Vector: data,
},
},
}
return fd
}
// NewColumnFloat16Vector auto generated constructor
func NewColumnFloat16Vector(name string, dim int, values [][]byte) *ColumnFloat16Vector {
return &ColumnFloat16Vector{
name: name,
dim: dim,
values: values,
}
}
// ColumnBFloat16Vector generated columns type for BFloat16Vector
type ColumnBFloat16Vector struct {
ColumnBase
name string
dim int
values [][]byte
}
// Name returns column name
func (c *ColumnBFloat16Vector) Name() string {
return c.name
}
// Type returns column entity.FieldType
func (c *ColumnBFloat16Vector) Type() entity.FieldType {
return entity.FieldTypeBFloat16Vector
}
// Len returns column data length
func (c *ColumnBFloat16Vector) Len() int {
return len(c.values)
}
// Dim returns vector dimension
func (c *ColumnBFloat16Vector) Dim() int {
return c.dim
}
// Get returns values at index as interface{}.
func (c *ColumnBFloat16Vector) Get(idx int) (interface{}, error) {
if idx < 0 || idx >= c.Len() {
return nil, errors.New("index out of range")
}
return c.values[idx], nil
}
// AppendValue append value into column
func (c *ColumnBFloat16Vector) AppendValue(i interface{}) error {
v, ok := i.([]byte)
if !ok {
return fmt.Errorf("invalid type, expected []byte, got %T", i)
}
c.values = append(c.values, v)
return nil
}
// Data returns column data
func (c *ColumnBFloat16Vector) Data() [][]byte {
return c.values
}
// FieldData return column data mapped to schemapb.FieldData
func (c *ColumnBFloat16Vector) FieldData() *schemapb.FieldData {
fd := &schemapb.FieldData{
Type: schemapb.DataType_BFloat16Vector,
FieldName: c.name,
}
data := make([]byte, 0, len(c.values)*c.dim)
for _, vector := range c.values {
data = append(data, vector...)
}
fd.Field = &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: int64(c.dim),
Data: &schemapb.VectorField_Bfloat16Vector{
Bfloat16Vector: data,
},
},
}
return fd
}
// NewColumnBFloat16Vector auto generated constructor
func NewColumnBFloat16Vector(name string, dim int, values [][]byte) *ColumnBFloat16Vector {
return &ColumnBFloat16Vector{
name: name,
dim: dim,
values: values,
}
}

View File

@ -0,0 +1,264 @@
// Code generated by go generate; DO NOT EDIT
// This file is generated by go generated
package column
import (
"fmt"
"math/rand"
"testing"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/stretchr/testify/assert"
)
func TestColumnBinaryVector(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_BinaryVector_%d", rand.Int())
columnLen := 12 + rand.Intn(10)
dim := ([]int{64, 128, 256, 512})[rand.Intn(4)]
v := make([][]byte, 0, columnLen)
dlen := dim
dlen /= 8
for i := 0; i < columnLen; i++ {
entry := make([]byte, dlen)
v = append(v, entry)
}
column := NewColumnBinaryVector(columnName, dim, v)
t.Run("test meta", func(t *testing.T) {
ft := entity.FieldTypeBinaryVector
assert.Equal(t, "BinaryVector", ft.Name())
assert.Equal(t, "[]byte", ft.String())
pbName, pbType := ft.PbFieldType()
assert.Equal(t, "[]byte", pbName)
assert.Equal(t, "", pbType)
})
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeBinaryVector, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.Equal(t, dim, column.Dim())
assert.Equal(t, v, column.Data())
var ev []byte
err := column.AppendValue(ev)
assert.Equal(t, columnLen+1, column.Len())
assert.Nil(t, err)
err = column.AppendValue(struct{}{})
assert.Equal(t, columnLen+1, column.Len())
assert.NotNil(t, err)
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
c, err := FieldDataVector(fd)
assert.NotNil(t, c)
assert.NoError(t, err)
})
t.Run("test column field data error", func(t *testing.T) {
fd := &schemapb.FieldData{
Type: schemapb.DataType_BinaryVector,
FieldName: columnName,
}
_, err := FieldDataVector(fd)
assert.Error(t, err)
})
}
func TestColumnFloatVector(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_FloatVector_%d", rand.Int())
columnLen := 12 + rand.Intn(10)
dim := ([]int{64, 128, 256, 512})[rand.Intn(4)]
v := make([][]float32, 0, columnLen)
dlen := dim
for i := 0; i < columnLen; i++ {
entry := make([]float32, dlen)
v = append(v, entry)
}
column := NewColumnFloatVector(columnName, dim, v)
t.Run("test meta", func(t *testing.T) {
ft := entity.FieldTypeFloatVector
assert.Equal(t, "FloatVector", ft.Name())
assert.Equal(t, "[]float32", ft.String())
pbName, pbType := ft.PbFieldType()
assert.Equal(t, "[]float32", pbName)
assert.Equal(t, "", pbType)
})
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeFloatVector, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.Equal(t, dim, column.Dim())
assert.Equal(t, v, column.Data())
var ev []float32
err := column.AppendValue(ev)
assert.Equal(t, columnLen+1, column.Len())
assert.Nil(t, err)
err = column.AppendValue(struct{}{})
assert.Equal(t, columnLen+1, column.Len())
assert.NotNil(t, err)
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
c, err := FieldDataVector(fd)
assert.NotNil(t, c)
assert.NoError(t, err)
})
t.Run("test column field data error", func(t *testing.T) {
fd := &schemapb.FieldData{
Type: schemapb.DataType_FloatVector,
FieldName: columnName,
}
_, err := FieldDataVector(fd)
assert.Error(t, err)
})
}
func TestColumnFloat16Vector(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_Float16Vector_%d", rand.Int())
columnLen := 12 + rand.Intn(10)
dim := ([]int{64, 128, 256, 512})[rand.Intn(4)]
v := make([][]byte, 0, columnLen)
dlen := dim
dlen *= 2
for i := 0; i < columnLen; i++ {
entry := make([]byte, dlen)
v = append(v, entry)
}
column := NewColumnFloat16Vector(columnName, dim, v)
t.Run("test meta", func(t *testing.T) {
ft := entity.FieldTypeFloat16Vector
assert.Equal(t, "Float16Vector", ft.Name())
assert.Equal(t, "[]byte", ft.String())
pbName, pbType := ft.PbFieldType()
assert.Equal(t, "[]byte", pbName)
assert.Equal(t, "", pbType)
})
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeFloat16Vector, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.Equal(t, dim, column.Dim())
assert.Equal(t, v, column.Data())
var ev []byte
err := column.AppendValue(ev)
assert.Equal(t, columnLen+1, column.Len())
assert.Nil(t, err)
err = column.AppendValue(struct{}{})
assert.Equal(t, columnLen+1, column.Len())
assert.NotNil(t, err)
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
c, err := FieldDataVector(fd)
assert.NotNil(t, c)
assert.NoError(t, err)
})
t.Run("test column field data error", func(t *testing.T) {
fd := &schemapb.FieldData{
Type: schemapb.DataType_Float16Vector,
FieldName: columnName,
}
_, err := FieldDataVector(fd)
assert.Error(t, err)
})
}
func TestColumnBFloat16Vector(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_BFloat16Vector_%d", rand.Int())
columnLen := 12 + rand.Intn(10)
dim := ([]int{64, 128, 256, 512})[rand.Intn(4)]
v := make([][]byte, 0, columnLen)
dlen := dim
dlen *= 2
for i := 0; i < columnLen; i++ {
entry := make([]byte, dlen)
v = append(v, entry)
}
column := NewColumnBFloat16Vector(columnName, dim, v)
t.Run("test meta", func(t *testing.T) {
ft := entity.FieldTypeBFloat16Vector
assert.Equal(t, "BFloat16Vector", ft.Name())
assert.Equal(t, "[]byte", ft.String())
pbName, pbType := ft.PbFieldType()
assert.Equal(t, "[]byte", pbName)
assert.Equal(t, "", pbType)
})
t.Run("test column attribute", func(t *testing.T) {
assert.Equal(t, columnName, column.Name())
assert.Equal(t, entity.FieldTypeBFloat16Vector, column.Type())
assert.Equal(t, columnLen, column.Len())
assert.Equal(t, dim, column.Dim())
assert.Equal(t, v, column.Data())
var ev []byte
err := column.AppendValue(ev)
assert.Equal(t, columnLen+1, column.Len())
assert.Nil(t, err)
err = column.AppendValue(struct{}{})
assert.Equal(t, columnLen+1, column.Len())
assert.NotNil(t, err)
})
t.Run("test column field data", func(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
c, err := FieldDataVector(fd)
assert.NotNil(t, c)
assert.NoError(t, err)
})
t.Run("test column field data error", func(t *testing.T) {
fd := &schemapb.FieldData{
Type: schemapb.DataType_BFloat16Vector,
FieldName: columnName,
}
_, err := FieldDataVector(fd)
assert.Error(t, err)
})
}

44
client/common.go Normal file
View File

@ -0,0 +1,44 @@
package client
import (
"context"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
// CollectionCache stores the cached collection schema information.
type CollectionCache struct {
sf conc.Singleflight[*entity.Collection]
collections *typeutil.ConcurrentMap[string, *entity.Collection]
fetcher func(context.Context, string) (*entity.Collection, error)
}
func (c *CollectionCache) GetCollection(ctx context.Context, collName string) (*entity.Collection, error) {
coll, ok := c.collections.Get(collName)
if ok {
return coll, nil
}
coll, err, _ := c.sf.Do(collName, func() (*entity.Collection, error) {
coll, err := c.fetcher(ctx, collName)
if err != nil {
return nil, err
}
c.collections.Insert(collName, coll)
return coll, nil
})
return coll, err
}
func NewCollectionCache(fetcher func(context.Context, string) (*entity.Collection, error)) *CollectionCache {
return &CollectionCache{
collections: typeutil.NewConcurrentMap[string, *entity.Collection](),
fetcher: fetcher,
}
}
func (c *Client) getCollection(ctx context.Context, collName string) (*entity.Collection, error) {
return c.collCache.GetCollection(ctx, collName)
}

22
client/common/version.go Normal file
View File

@ -0,0 +1,22 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
const (
// SDKVersion const value for current version
SDKVersion = `2.4.0-dev`
)

View File

@ -0,0 +1,29 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
import (
"testing"
"github.com/blang/semver/v4"
"github.com/stretchr/testify/assert"
)
func TestVersion(t *testing.T) {
_, err := semver.Parse(SDKVersion)
assert.NoError(t, err)
}

60
client/database.go Normal file
View File

@ -0,0 +1,60 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/util/merr"
)
func (c *Client) ListDatabase(ctx context.Context, option ListDatabaseOption, callOptions ...grpc.CallOption) (databaseNames []string, err error) {
req := option.Request()
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.ListDatabases(ctx, req, callOptions...)
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
databaseNames = resp.GetDbNames()
return nil
})
return databaseNames, err
}
func (c *Client) CreateDatabase(ctx context.Context, option CreateDatabaseOption, callOptions ...grpc.CallOption) error {
req := option.Request()
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.CreateDatabase(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
}
func (c *Client) DropDatabase(ctx context.Context, option DropDatabaseOption, callOptions ...grpc.CallOption) error {
req := option.Request()
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DropDatabase(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
}

View File

@ -0,0 +1,74 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
// ListDatabaseOption is a builder interface for ListDatabase request.
type ListDatabaseOption interface {
Request() *milvuspb.ListDatabasesRequest
}
type listDatabaseOption struct{}
func (opt *listDatabaseOption) Request() *milvuspb.ListDatabasesRequest {
return &milvuspb.ListDatabasesRequest{}
}
func NewListDatabaseOption() *listDatabaseOption {
return &listDatabaseOption{}
}
type CreateDatabaseOption interface {
Request() *milvuspb.CreateDatabaseRequest
}
type createDatabaseOption struct {
dbName string
}
func (opt *createDatabaseOption) Request() *milvuspb.CreateDatabaseRequest {
return &milvuspb.CreateDatabaseRequest{
DbName: opt.dbName,
}
}
func NewCreateDatabaseOption(dbName string) *createDatabaseOption {
return &createDatabaseOption{
dbName: dbName,
}
}
type DropDatabaseOption interface {
Request() *milvuspb.DropDatabaseRequest
}
type dropDatabaseOption struct {
dbName string
}
func (opt *dropDatabaseOption) Request() *milvuspb.DropDatabaseRequest {
return &milvuspb.DropDatabaseRequest{
DbName: opt.dbName,
}
}
func NewDropDatabaseOption(dbName string) *dropDatabaseOption {
return &dropDatabaseOption{
dbName: dbName,
}
}

92
client/database_test.go Normal file
View File

@ -0,0 +1,92 @@
package client
import (
"context"
"fmt"
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/util/merr"
mock "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)
type DatabaseSuite struct {
MockSuiteBase
}
func (s *DatabaseSuite) TestListDatabases() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
s.mock.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
Status: merr.Success(),
DbNames: []string{"default", "db1"},
}, nil).Once()
names, err := s.client.ListDatabase(ctx, NewListDatabaseOption())
s.NoError(err)
s.ElementsMatch([]string{"default", "db1"}, names)
})
s.Run("failure", func() {
s.mock.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.ListDatabase(ctx, NewListDatabaseOption())
s.Error(err)
})
}
func (s *DatabaseSuite) TestCreateDatabase() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
dbName := fmt.Sprintf("dt_%s", s.randString(6))
s.mock.EXPECT().CreateDatabase(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cdr *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
s.Equal(dbName, cdr.GetDbName())
return merr.Success(), nil
}).Once()
err := s.client.CreateDatabase(ctx, NewCreateDatabaseOption(dbName))
s.NoError(err)
})
s.Run("failure", func() {
dbName := fmt.Sprintf("dt_%s", s.randString(6))
s.mock.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.CreateDatabase(ctx, NewCreateDatabaseOption(dbName))
s.Error(err)
})
}
func (s *DatabaseSuite) TestDropDatabase() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
dbName := fmt.Sprintf("dt_%s", s.randString(6))
s.mock.EXPECT().DropDatabase(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ddr *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
s.Equal(dbName, ddr.GetDbName())
return merr.Success(), nil
}).Once()
err := s.client.DropDatabase(ctx, NewDropDatabaseOption(dbName))
s.NoError(err)
})
s.Run("failure", func() {
dbName := fmt.Sprintf("dt_%s", s.randString(6))
s.mock.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.DropDatabase(ctx, NewDropDatabaseOption(dbName))
s.Error(err)
})
}
func TestDatabase(t *testing.T) {
suite.Run(t, new(DatabaseSuite))
}

18
client/doc.go Normal file
View File

@ -0,0 +1,18 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package milvusclient implements the official Go Milvus client for v2.
package client

View File

@ -0,0 +1,56 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package entity
// DefaultShardNumber const value for using Milvus default shard number.
const DefaultShardNumber int32 = 0
// DefaultConsistencyLevel const value for using Milvus default consistency level setting.
const DefaultConsistencyLevel ConsistencyLevel = ClBounded
// Collection represent collection meta in Milvus
type Collection struct {
ID int64 // collection id
Name string // collection name
Schema *Schema // collection schema, with fields schema and primary key definition
PhysicalChannels []string
VirtualChannels []string
Loaded bool
ConsistencyLevel ConsistencyLevel
ShardNum int32
}
// Partition represent partition meta in Milvus
type Partition struct {
ID int64 // partition id
Name string // partition name
Loaded bool // partition loaded
}
// ReplicaGroup represents a replica group
type ReplicaGroup struct {
ReplicaID int64
NodeIDs []int64
ShardReplicas []*ShardReplica
}
// ShardReplica represents a shard in the ReplicaGroup
type ShardReplica struct {
LeaderID int64
NodesIDs []int64
DmChannelName string
}

View File

@ -0,0 +1,96 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package entity
import (
"strconv"
"github.com/cockroachdb/errors"
)
const (
// cakTTL const for collection attribute key TTL in seconds.
cakTTL = `collection.ttl.seconds`
// cakAutoCompaction const for collection attribute key autom compaction enabled.
cakAutoCompaction = `collection.autocompaction.enabled`
)
// CollectionAttribute is the interface for altering collection attributes.
type CollectionAttribute interface {
KeyValue() (string, string)
Valid() error
}
type collAttrBase struct {
key string
value string
}
// KeyValue implements CollectionAttribute.
func (ca collAttrBase) KeyValue() (string, string) {
return ca.key, ca.value
}
type ttlCollAttr struct {
collAttrBase
}
// Valid implements CollectionAttribute.
// checks ttl seconds is valid positive integer.
func (ca collAttrBase) Valid() error {
val, err := strconv.ParseInt(ca.value, 10, 64)
if err != nil {
return errors.Wrap(err, "ttl is not a valid positive integer")
}
if val < 0 {
return errors.New("ttl needs to be a positive integer")
}
return nil
}
// CollectionTTL returns collection attribute to set collection ttl in seconds.
func CollectionTTL(ttl int64) ttlCollAttr {
ca := ttlCollAttr{}
ca.key = cakTTL
ca.value = strconv.FormatInt(ttl, 10)
return ca
}
type autoCompactionCollAttr struct {
collAttrBase
}
// Valid implements CollectionAttribute.
// checks collection auto compaction is valid bool.
func (ca autoCompactionCollAttr) Valid() error {
_, err := strconv.ParseBool(ca.value)
if err != nil {
return errors.Wrap(err, "auto compaction setting is not valid boolean")
}
return nil
}
// CollectionAutoCompactionEnabled returns collection attribute to set collection auto compaction enabled.
func CollectionAutoCompactionEnabled(enabled bool) autoCompactionCollAttr {
ca := autoCompactionCollAttr{}
ca.key = cakAutoCompaction
ca.value = strconv.FormatBool(enabled)
return ca
}

View File

@ -0,0 +1,136 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package entity
import (
"fmt"
"strconv"
"testing"
"github.com/stretchr/testify/suite"
)
type CollectionTTLSuite struct {
suite.Suite
}
func (s *CollectionTTLSuite) TestValid() {
type testCase struct {
input string
expectErr bool
}
cases := []testCase{
{input: "a", expectErr: true},
{input: "1000", expectErr: false},
{input: "0", expectErr: false},
{input: "-10", expectErr: true},
}
for _, tc := range cases {
s.Run(tc.input, func() {
ca := ttlCollAttr{}
ca.value = tc.input
err := ca.Valid()
if tc.expectErr {
s.Error(err)
} else {
s.NoError(err)
}
})
}
}
func (s *CollectionTTLSuite) TestCollectionTTL() {
type testCase struct {
input int64
expectErr bool
}
cases := []testCase{
{input: 1000, expectErr: false},
{input: 0, expectErr: false},
{input: -10, expectErr: true},
}
for _, tc := range cases {
s.Run(fmt.Sprintf("%d", tc.input), func() {
ca := CollectionTTL(tc.input)
key, value := ca.KeyValue()
s.Equal(cakTTL, key)
s.Equal(strconv.FormatInt(tc.input, 10), value)
err := ca.Valid()
if tc.expectErr {
s.Error(err)
} else {
s.NoError(err)
}
})
}
}
func TestCollectionTTL(t *testing.T) {
suite.Run(t, new(CollectionTTLSuite))
}
type CollectionAutoCompactionSuite struct {
suite.Suite
}
func (s *CollectionAutoCompactionSuite) TestValid() {
type testCase struct {
input string
expectErr bool
}
cases := []testCase{
{input: "a", expectErr: true},
{input: "true", expectErr: false},
{input: "false", expectErr: false},
{input: "", expectErr: true},
}
for _, tc := range cases {
s.Run(tc.input, func() {
ca := autoCompactionCollAttr{}
ca.value = tc.input
err := ca.Valid()
if tc.expectErr {
s.Error(err)
} else {
s.NoError(err)
}
})
}
}
func (s *CollectionAutoCompactionSuite) TestCollectionAutoCompactionEnabled() {
cases := []bool{true, false}
for _, tc := range cases {
s.Run(fmt.Sprintf("%v", tc), func() {
ca := CollectionAutoCompactionEnabled(tc)
key, value := ca.KeyValue()
s.Equal(cakAutoCompaction, key)
s.Equal(strconv.FormatBool(tc), value)
})
}
}
func TestCollectionAutoCompaction(t *testing.T) {
suite.Run(t, new(CollectionAutoCompactionSuite))
}

32
client/entity/common.go Normal file
View File

@ -0,0 +1,32 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package entity
// MetricType metric type
type MetricType string
// Metric Constants
const (
L2 MetricType = "L2"
IP MetricType = "IP"
COSINE MetricType = "COSINE"
HAMMING MetricType = "HAMMING"
JACCARD MetricType = "JACCARD"
TANIMOTO MetricType = "TANIMOTO"
SUBSTRUCTURE MetricType = "SUBSTRUCTURE"
SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE"
)

171
client/entity/field_type.go Normal file
View File

@ -0,0 +1,171 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package entity
// FieldType field data type alias type
// used in go:generate trick, DO NOT modify names & string
type FieldType int32
// Name returns field type name
func (t FieldType) Name() string {
switch t {
case FieldTypeBool:
return "Bool"
case FieldTypeInt8:
return "Int8"
case FieldTypeInt16:
return "Int16"
case FieldTypeInt32:
return "Int32"
case FieldTypeInt64:
return "Int64"
case FieldTypeFloat:
return "Float"
case FieldTypeDouble:
return "Double"
case FieldTypeString:
return "String"
case FieldTypeVarChar:
return "VarChar"
case FieldTypeArray:
return "Array"
case FieldTypeJSON:
return "JSON"
case FieldTypeBinaryVector:
return "BinaryVector"
case FieldTypeFloatVector:
return "FloatVector"
case FieldTypeFloat16Vector:
return "Float16Vector"
case FieldTypeBFloat16Vector:
return "BFloat16Vector"
default:
return "undefined"
}
}
// String returns field type
func (t FieldType) String() string {
switch t {
case FieldTypeBool:
return "bool"
case FieldTypeInt8:
return "int8"
case FieldTypeInt16:
return "int16"
case FieldTypeInt32:
return "int32"
case FieldTypeInt64:
return "int64"
case FieldTypeFloat:
return "float32"
case FieldTypeDouble:
return "float64"
case FieldTypeString:
return "string"
case FieldTypeVarChar:
return "string"
case FieldTypeArray:
return "Array"
case FieldTypeJSON:
return "JSON"
case FieldTypeBinaryVector:
return "[]byte"
case FieldTypeFloatVector:
return "[]float32"
case FieldTypeFloat16Vector:
return "[]byte"
case FieldTypeBFloat16Vector:
return "[]byte"
default:
return "undefined"
}
}
// PbFieldType represents FieldType corresponding schema pb type
func (t FieldType) PbFieldType() (string, string) {
switch t {
case FieldTypeBool:
return "Bool", "bool"
case FieldTypeInt8:
fallthrough
case FieldTypeInt16:
fallthrough
case FieldTypeInt32:
return "Int", "int32"
case FieldTypeInt64:
return "Long", "int64"
case FieldTypeFloat:
return "Float", "float32"
case FieldTypeDouble:
return "Double", "float64"
case FieldTypeString:
return "String", "string"
case FieldTypeVarChar:
return "VarChar", "string"
case FieldTypeJSON:
return "JSON", "JSON"
case FieldTypeBinaryVector:
return "[]byte", ""
case FieldTypeFloatVector:
return "[]float32", ""
case FieldTypeFloat16Vector:
return "[]byte", ""
case FieldTypeBFloat16Vector:
return "[]byte", ""
default:
return "undefined", ""
}
}
// Match schema definition
const (
// FieldTypeNone zero value place holder
FieldTypeNone FieldType = 0 // zero value place holder
// FieldTypeBool field type boolean
FieldTypeBool FieldType = 1
// FieldTypeInt8 field type int8
FieldTypeInt8 FieldType = 2
// FieldTypeInt16 field type int16
FieldTypeInt16 FieldType = 3
// FieldTypeInt32 field type int32
FieldTypeInt32 FieldType = 4
// FieldTypeInt64 field type int64
FieldTypeInt64 FieldType = 5
// FieldTypeFloat field type float
FieldTypeFloat FieldType = 10
// FieldTypeDouble field type double
FieldTypeDouble FieldType = 11
// FieldTypeString field type string
FieldTypeString FieldType = 20
// FieldTypeVarChar field type varchar
FieldTypeVarChar FieldType = 21 // variable-length strings with a specified maximum length
// FieldTypeArray field type Array
FieldTypeArray FieldType = 22
// FieldTypeJSON field type JSON
FieldTypeJSON FieldType = 23
// FieldTypeBinaryVector field type binary vector
FieldTypeBinaryVector FieldType = 100
// FieldTypeFloatVector field type float vector
FieldTypeFloatVector FieldType = 101
// FieldTypeBinaryVector field type float16 vector
FieldTypeFloat16Vector FieldType = 102
// FieldTypeBinaryVector field type bf16 vector
FieldTypeBFloat16Vector FieldType = 103
// FieldTypeBinaryVector field type sparse vector
FieldTypeSparseVector FieldType = 104
)

341
client/entity/schema.go Normal file
View File

@ -0,0 +1,341 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package entity
import (
"strconv"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
const (
// TypeParamDim is the const for field type param dimension
TypeParamDim = "dim"
// TypeParamMaxLength is the const for varchar type maximal length
TypeParamMaxLength = "max_length"
// TypeParamMaxCapacity is the const for array type max capacity
TypeParamMaxCapacity = `max_capacity`
// ClStrong strong consistency level
ClStrong ConsistencyLevel = ConsistencyLevel(commonpb.ConsistencyLevel_Strong)
// ClBounded bounded consistency level with default tolerance of 5 seconds
ClBounded ConsistencyLevel = ConsistencyLevel(commonpb.ConsistencyLevel_Bounded)
// ClSession session consistency level
ClSession ConsistencyLevel = ConsistencyLevel(commonpb.ConsistencyLevel_Session)
// ClEvenually eventually consistency level
ClEventually ConsistencyLevel = ConsistencyLevel(commonpb.ConsistencyLevel_Eventually)
// ClCustomized customized consistency level and users pass their own `guarantee_timestamp`.
ClCustomized ConsistencyLevel = ConsistencyLevel(commonpb.ConsistencyLevel_Customized)
)
// ConsistencyLevel enum type for collection Consistency Level
type ConsistencyLevel commonpb.ConsistencyLevel
// CommonConsistencyLevel returns corresponding commonpb.ConsistencyLevel
func (cl ConsistencyLevel) CommonConsistencyLevel() commonpb.ConsistencyLevel {
return commonpb.ConsistencyLevel(cl)
}
// Schema represents schema info of collection in milvus
type Schema struct {
CollectionName string
Description string
AutoID bool
Fields []*Field
EnableDynamicField bool
}
// NewSchema creates an empty schema object.
func NewSchema() *Schema {
return &Schema{}
}
// WithName sets the name value of schema, returns schema itself.
func (s *Schema) WithName(name string) *Schema {
s.CollectionName = name
return s
}
// WithDescription sets the description value of schema, returns schema itself.
func (s *Schema) WithDescription(desc string) *Schema {
s.Description = desc
return s
}
func (s *Schema) WithAutoID(autoID bool) *Schema {
s.AutoID = autoID
return s
}
func (s *Schema) WithDynamicFieldEnabled(dynamicEnabled bool) *Schema {
s.EnableDynamicField = dynamicEnabled
return s
}
// WithField adds a field into schema and returns schema itself.
func (s *Schema) WithField(f *Field) *Schema {
s.Fields = append(s.Fields, f)
return s
}
// ProtoMessage returns corresponding server.CollectionSchema
func (s *Schema) ProtoMessage() *schemapb.CollectionSchema {
r := &schemapb.CollectionSchema{
Name: s.CollectionName,
Description: s.Description,
AutoID: s.AutoID,
EnableDynamicField: s.EnableDynamicField,
}
r.Fields = make([]*schemapb.FieldSchema, 0, len(s.Fields))
for _, field := range s.Fields {
r.Fields = append(r.Fields, field.ProtoMessage())
}
return r
}
// ReadProto parses proto Collection Schema
func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema {
s.Description = p.GetDescription()
s.CollectionName = p.GetName()
s.Fields = make([]*Field, 0, len(p.GetFields()))
for _, fp := range p.GetFields() {
if fp.GetAutoID() {
s.AutoID = true
}
s.Fields = append(s.Fields, NewField().ReadProto(fp))
}
s.EnableDynamicField = p.GetEnableDynamicField()
return s
}
// PKFieldName returns pk field name for this schemapb.
func (s *Schema) PKFieldName() string {
for _, field := range s.Fields {
if field.PrimaryKey {
return field.Name
}
}
return ""
}
// Field represent field schema in milvus
type Field struct {
ID int64 // field id, generated when collection is created, input value is ignored
Name string // field name
PrimaryKey bool // is primary key
AutoID bool // is auto id
Description string
DataType FieldType
TypeParams map[string]string
IndexParams map[string]string
IsDynamic bool
IsPartitionKey bool
ElementType FieldType
}
// ProtoMessage generates corresponding FieldSchema
func (f *Field) ProtoMessage() *schemapb.FieldSchema {
return &schemapb.FieldSchema{
FieldID: f.ID,
Name: f.Name,
Description: f.Description,
IsPrimaryKey: f.PrimaryKey,
AutoID: f.AutoID,
DataType: schemapb.DataType(f.DataType),
TypeParams: MapKvPairs(f.TypeParams),
IndexParams: MapKvPairs(f.IndexParams),
IsDynamic: f.IsDynamic,
IsPartitionKey: f.IsPartitionKey,
ElementType: schemapb.DataType(f.ElementType),
}
}
// NewField creates a new Field with map initialized.
func NewField() *Field {
return &Field{
TypeParams: make(map[string]string),
IndexParams: make(map[string]string),
}
}
func (f *Field) WithName(name string) *Field {
f.Name = name
return f
}
func (f *Field) WithDescription(desc string) *Field {
f.Description = desc
return f
}
func (f *Field) WithDataType(dataType FieldType) *Field {
f.DataType = dataType
return f
}
func (f *Field) WithIsPrimaryKey(isPrimaryKey bool) *Field {
f.PrimaryKey = isPrimaryKey
return f
}
func (f *Field) WithIsAutoID(isAutoID bool) *Field {
f.AutoID = isAutoID
return f
}
func (f *Field) WithIsDynamic(isDynamic bool) *Field {
f.IsDynamic = isDynamic
return f
}
func (f *Field) WithIsPartitionKey(isPartitionKey bool) *Field {
f.IsPartitionKey = isPartitionKey
return f
}
/*
func (f *Field) WithDefaultValueBool(defaultValue bool) *Field {
f.DefaultValue = &schemapb.ValueField{
Data: &schemapb.ValueField_BoolData{
BoolData: defaultValue,
},
}
return f
}
func (f *Field) WithDefaultValueInt(defaultValue int32) *Field {
f.DefaultValue = &schemapb.ValueField{
Data: &schemapb.ValueField_IntData{
IntData: defaultValue,
},
}
return f
}
func (f *Field) WithDefaultValueLong(defaultValue int64) *Field {
f.DefaultValue = &schemapb.ValueField{
Data: &schemapb.ValueField_LongData{
LongData: defaultValue,
},
}
return f
}
func (f *Field) WithDefaultValueFloat(defaultValue float32) *Field {
f.DefaultValue = &schemapb.ValueField{
Data: &schemapb.ValueField_FloatData{
FloatData: defaultValue,
},
}
return f
}
func (f *Field) WithDefaultValueDouble(defaultValue float64) *Field {
f.DefaultValue = &schemapb.ValueField{
Data: &schemapb.ValueField_DoubleData{
DoubleData: defaultValue,
},
}
return f
}
func (f *Field) WithDefaultValueString(defaultValue string) *Field {
f.DefaultValue = &schemapb.ValueField{
Data: &schemapb.ValueField_StringData{
StringData: defaultValue,
},
}
return f
}*/
func (f *Field) WithTypeParams(key string, value string) *Field {
if f.TypeParams == nil {
f.TypeParams = make(map[string]string)
}
f.TypeParams[key] = value
return f
}
func (f *Field) WithDim(dim int64) *Field {
if f.TypeParams == nil {
f.TypeParams = make(map[string]string)
}
f.TypeParams[TypeParamDim] = strconv.FormatInt(dim, 10)
return f
}
func (f *Field) WithMaxLength(maxLen int64) *Field {
if f.TypeParams == nil {
f.TypeParams = make(map[string]string)
}
f.TypeParams[TypeParamMaxLength] = strconv.FormatInt(maxLen, 10)
return f
}
func (f *Field) WithElementType(eleType FieldType) *Field {
f.ElementType = eleType
return f
}
func (f *Field) WithMaxCapacity(maxCap int64) *Field {
if f.TypeParams == nil {
f.TypeParams = make(map[string]string)
}
f.TypeParams[TypeParamMaxCapacity] = strconv.FormatInt(maxCap, 10)
return f
}
// ReadProto parses FieldSchema
func (f *Field) ReadProto(p *schemapb.FieldSchema) *Field {
f.ID = p.GetFieldID()
f.Name = p.GetName()
f.PrimaryKey = p.GetIsPrimaryKey()
f.AutoID = p.GetAutoID()
f.Description = p.GetDescription()
f.DataType = FieldType(p.GetDataType())
f.TypeParams = KvPairsMap(p.GetTypeParams())
f.IndexParams = KvPairsMap(p.GetIndexParams())
f.IsDynamic = p.GetIsDynamic()
f.IsPartitionKey = p.GetIsPartitionKey()
f.ElementType = FieldType(p.GetElementType())
return f
}
// MapKvPairs converts map into commonpb.KeyValuePair slice
func MapKvPairs(m map[string]string) []*commonpb.KeyValuePair {
pairs := make([]*commonpb.KeyValuePair, 0, len(m))
for k, v := range m {
pairs = append(pairs, &commonpb.KeyValuePair{
Key: k,
Value: v,
})
}
return pairs
}
// KvPairsMap converts commonpb.KeyValuePair slices into map
func KvPairsMap(kvps []*commonpb.KeyValuePair) map[string]string {
m := make(map[string]string)
for _, kvp := range kvps {
m[kvp.Key] = kvp.Value
}
return m
}

View File

@ -0,0 +1,138 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package entity
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)
func TestCL_CommonCL(t *testing.T) {
cls := []ConsistencyLevel{
ClStrong,
ClBounded,
ClSession,
ClEventually,
}
for _, cl := range cls {
assert.EqualValues(t, commonpb.ConsistencyLevel(cl), cl.CommonConsistencyLevel())
}
}
func TestFieldSchema(t *testing.T) {
fields := []*Field{
NewField().WithName("int_field").WithDataType(FieldTypeInt64).WithIsAutoID(true).WithIsPrimaryKey(true).WithDescription("int_field desc"),
NewField().WithName("string_field").WithDataType(FieldTypeString).WithIsAutoID(false).WithIsPrimaryKey(true).WithIsDynamic(false).WithTypeParams("max_len", "32").WithDescription("string_field desc"),
NewField().WithName("partition_key").WithDataType(FieldTypeInt32).WithIsPartitionKey(true),
NewField().WithName("array_field").WithDataType(FieldTypeArray).WithElementType(FieldTypeBool).WithMaxCapacity(128),
/*
NewField().WithName("default_value_bool").WithDataType(FieldTypeBool).WithDefaultValueBool(true),
NewField().WithName("default_value_int").WithDataType(FieldTypeInt32).WithDefaultValueInt(1),
NewField().WithName("default_value_long").WithDataType(FieldTypeInt64).WithDefaultValueLong(1),
NewField().WithName("default_value_float").WithDataType(FieldTypeFloat).WithDefaultValueFloat(1),
NewField().WithName("default_value_double").WithDataType(FieldTypeDouble).WithDefaultValueDouble(1),
NewField().WithName("default_value_string").WithDataType(FieldTypeString).WithDefaultValueString("a"),*/
}
for _, field := range fields {
fieldSchema := field.ProtoMessage()
assert.Equal(t, field.ID, fieldSchema.GetFieldID())
assert.Equal(t, field.Name, fieldSchema.GetName())
assert.EqualValues(t, field.DataType, fieldSchema.GetDataType())
assert.Equal(t, field.AutoID, fieldSchema.GetAutoID())
assert.Equal(t, field.PrimaryKey, fieldSchema.GetIsPrimaryKey())
assert.Equal(t, field.IsPartitionKey, fieldSchema.GetIsPartitionKey())
assert.Equal(t, field.IsDynamic, fieldSchema.GetIsDynamic())
assert.Equal(t, field.Description, fieldSchema.GetDescription())
assert.Equal(t, field.TypeParams, KvPairsMap(fieldSchema.GetTypeParams()))
assert.EqualValues(t, field.ElementType, fieldSchema.GetElementType())
// marshal & unmarshal, still equals
nf := &Field{}
nf = nf.ReadProto(fieldSchema)
assert.Equal(t, field.ID, nf.ID)
assert.Equal(t, field.Name, nf.Name)
assert.EqualValues(t, field.DataType, nf.DataType)
assert.Equal(t, field.AutoID, nf.AutoID)
assert.Equal(t, field.PrimaryKey, nf.PrimaryKey)
assert.Equal(t, field.Description, nf.Description)
assert.Equal(t, field.IsDynamic, nf.IsDynamic)
assert.Equal(t, field.IsPartitionKey, nf.IsPartitionKey)
assert.EqualValues(t, field.TypeParams, nf.TypeParams)
assert.EqualValues(t, field.ElementType, nf.ElementType)
}
assert.NotPanics(t, func() {
(&Field{}).WithTypeParams("a", "b")
})
}
type SchemaSuite struct {
suite.Suite
}
func (s *SchemaSuite) TestBasic() {
cases := []struct {
tag string
input *Schema
pkName string
}{
{
"test_collection",
NewSchema().WithName("test_collection_1").WithDescription("test_collection_1 desc").WithAutoID(false).
WithField(NewField().WithName("ID").WithDataType(FieldTypeInt64).WithIsPrimaryKey(true)).
WithField(NewField().WithName("vector").WithDataType(FieldTypeFloatVector).WithDim(128)),
"ID",
},
{
"dynamic_schema",
NewSchema().WithName("dynamic_schema").WithDescription("dynamic_schema desc").WithAutoID(true).WithDynamicFieldEnabled(true).
WithField(NewField().WithName("ID").WithDataType(FieldTypeVarChar).WithMaxLength(256)).
WithField(NewField().WithName("$meta").WithIsDynamic(true)),
"",
},
}
for _, c := range cases {
s.Run(c.tag, func() {
sch := c.input
p := sch.ProtoMessage()
s.Equal(sch.CollectionName, p.GetName())
s.Equal(sch.AutoID, p.GetAutoID())
s.Equal(sch.Description, p.GetDescription())
s.Equal(sch.EnableDynamicField, p.GetEnableDynamicField())
s.Equal(len(sch.Fields), len(p.GetFields()))
nsch := &Schema{}
nsch = nsch.ReadProto(p)
s.Equal(sch.CollectionName, nsch.CollectionName)
s.Equal(sch.Description, nsch.Description)
s.Equal(sch.EnableDynamicField, nsch.EnableDynamicField)
s.Equal(len(sch.Fields), len(nsch.Fields))
s.Equal(c.pkName, sch.PKFieldName())
s.Equal(c.pkName, nsch.PKFieldName())
})
}
}
func TestSchema(t *testing.T) {
suite.Run(t, new(SchemaSuite))
}

124
client/entity/sparse.go Normal file
View File

@ -0,0 +1,124 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package entity
import (
"encoding/binary"
"math"
"sort"
"github.com/cockroachdb/errors"
)
type SparseEmbedding interface {
Dim() int // the dimension
Len() int // the actual items in this vector
Get(idx int) (pos uint32, value float32, ok bool)
Serialize() []byte
}
var (
_ SparseEmbedding = sliceSparseEmbedding{}
_ Vector = sliceSparseEmbedding{}
)
type sliceSparseEmbedding struct {
positions []uint32
values []float32
dim int
len int
}
func (e sliceSparseEmbedding) Dim() int {
return e.dim
}
func (e sliceSparseEmbedding) Len() int {
return e.len
}
func (e sliceSparseEmbedding) FieldType() FieldType {
return FieldTypeSparseVector
}
func (e sliceSparseEmbedding) Get(idx int) (uint32, float32, bool) {
if idx < 0 || idx >= int(e.len) {
return 0, 0, false
}
return e.positions[idx], e.values[idx], true
}
func (e sliceSparseEmbedding) Serialize() []byte {
row := make([]byte, 8*e.Len())
for idx := 0; idx < e.Len(); idx++ {
pos, value, _ := e.Get(idx)
binary.LittleEndian.PutUint32(row[idx*8:], pos)
binary.LittleEndian.PutUint32(row[pos*8+4:], math.Float32bits(value))
}
return row
}
// Less implements sort.Interce
func (e sliceSparseEmbedding) Less(i, j int) bool {
return e.positions[i] < e.positions[j]
}
func (e sliceSparseEmbedding) Swap(i, j int) {
e.positions[i], e.positions[j] = e.positions[j], e.positions[i]
e.values[i], e.values[j] = e.values[j], e.values[i]
}
func deserializeSliceSparceEmbedding(bs []byte) (sliceSparseEmbedding, error) {
length := len(bs)
if length%8 != 0 {
return sliceSparseEmbedding{}, errors.New("not valid sparse embedding bytes")
}
length = length / 8
result := sliceSparseEmbedding{
positions: make([]uint32, length),
values: make([]float32, length),
len: length,
}
for i := 0; i < length; i++ {
result.positions[i] = binary.LittleEndian.Uint32(bs[i*8 : i*8+4])
result.values[i] = math.Float32frombits(binary.LittleEndian.Uint32(bs[i*8+4 : i*8+8]))
}
return result, nil
}
func NewSliceSparseEmbedding(positions []uint32, values []float32) (SparseEmbedding, error) {
if len(positions) != len(values) {
return nil, errors.New("invalid sparse embedding input, positions shall have same number of values")
}
se := sliceSparseEmbedding{
positions: positions,
values: values,
len: len(positions),
}
sort.Sort(se)
if se.len > 0 {
se.dim = int(se.positions[se.len-1]) + 1
}
return se, nil
}

View File

@ -0,0 +1,68 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package entity
import (
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSliceSparseEmbedding(t *testing.T) {
t.Run("normal_case", func(t *testing.T) {
length := 1 + rand.Intn(5)
positions := make([]uint32, length)
values := make([]float32, length)
for i := 0; i < length; i++ {
positions[i] = uint32(i)
values[i] = rand.Float32()
}
se, err := NewSliceSparseEmbedding(positions, values)
require.NoError(t, err)
assert.EqualValues(t, length, se.Dim())
assert.EqualValues(t, length, se.Len())
bs := se.Serialize()
nv, err := deserializeSliceSparceEmbedding(bs)
require.NoError(t, err)
for i := 0; i < length; i++ {
pos, val, ok := se.Get(i)
require.True(t, ok)
assert.Equal(t, positions[i], pos)
assert.Equal(t, values[i], val)
npos, nval, ok := nv.Get(i)
require.True(t, ok)
assert.Equal(t, positions[i], npos)
assert.Equal(t, values[i], nval)
}
_, _, ok := se.Get(-1)
assert.False(t, ok)
_, _, ok = se.Get(length)
assert.False(t, ok)
})
t.Run("position values not match", func(t *testing.T) {
_, err := NewSliceSparseEmbedding([]uint32{1}, []float32{})
assert.Error(t, err)
})
}

106
client/entity/vectors.go Normal file
View File

@ -0,0 +1,106 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package entity
import (
"encoding/binary"
"math"
)
// Vector interface vector used int search
type Vector interface {
Dim() int
Serialize() []byte
FieldType() FieldType
}
// FloatVector float32 vector wrapper.
type FloatVector []float32
// Dim returns vector dimension.
func (fv FloatVector) Dim() int {
return len(fv)
}
// entity.FieldType returns coresponding field type.
func (fv FloatVector) FieldType() FieldType {
return FieldTypeFloatVector
}
// Serialize serializes vector into byte slice, used in search placeholder
// LittleEndian is used for convention
func (fv FloatVector) Serialize() []byte {
data := make([]byte, 0, 4*len(fv)) // float32 occupies 4 bytes
buf := make([]byte, 4)
for _, f := range fv {
binary.LittleEndian.PutUint32(buf, math.Float32bits(f))
data = append(data, buf...)
}
return data
}
// FloatVector float32 vector wrapper.
type Float16Vector []byte
// Dim returns vector dimension.
func (fv Float16Vector) Dim() int {
return len(fv) / 2
}
// entity.FieldType returns coresponding field type.
func (fv Float16Vector) FieldType() FieldType {
return FieldTypeFloat16Vector
}
func (fv Float16Vector) Serialize() []byte {
return fv
}
// FloatVector float32 vector wrapper.
type BFloat16Vector []byte
// Dim returns vector dimension.
func (fv BFloat16Vector) Dim() int {
return len(fv) / 2
}
// entity.FieldType returns coresponding field type.
func (fv BFloat16Vector) FieldType() FieldType {
return FieldTypeBFloat16Vector
}
func (fv BFloat16Vector) Serialize() []byte {
return fv
}
// BinaryVector []byte vector wrapper
type BinaryVector []byte
// Dim return vector dimension, note that binary vector is bits count
func (bv BinaryVector) Dim() int {
return 8 * len(bv)
}
// Serialize just return bytes
func (bv BinaryVector) Serialize() []byte {
return bv
}
// entity.FieldType returns coresponding field type.
func (bv BinaryVector) FieldType() FieldType {
return FieldTypeBinaryVector
}

View File

@ -0,0 +1,51 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package entity
import (
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
)
func TestVectors(t *testing.T) {
dim := rand.Intn(127) + 1
t.Run("test float vector", func(t *testing.T) {
raw := make([]float32, dim)
for i := 0; i < dim; i++ {
raw[i] = rand.Float32()
}
fv := FloatVector(raw)
assert.Equal(t, dim, fv.Dim())
assert.Equal(t, dim*4, len(fv.Serialize()))
})
t.Run("test binary vector", func(t *testing.T) {
raw := make([]byte, dim)
_, err := rand.Read(raw)
assert.Nil(t, err)
bv := BinaryVector(raw)
assert.Equal(t, dim*8, bv.Dim())
assert.ElementsMatch(t, raw, bv.Serialize())
})
}

125
client/go.mod Normal file
View File

@ -0,0 +1,125 @@
module github.com/milvus-io/milvus/client/v2
go 1.21.8
require (
github.com/blang/semver/v4 v4.0.0
github.com/cockroachdb/errors v1.9.1
github.com/gogo/status v1.1.0
github.com/golang/protobuf v1.5.3
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240228061649-a922b16f2a46
github.com/milvus-io/milvus/pkg v0.0.2-0.20240317152703-17b4938985f3
github.com/samber/lo v1.27.0
github.com/stretchr/testify v1.8.4
github.com/tidwall/gjson v1.17.1
go.uber.org/atomic v1.10.0
google.golang.org/grpc v1.54.0
)
require (
github.com/benbjohnson/clock v1.1.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v4 v4.2.0 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cilium/ebpf v0.11.0 // indirect
github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect
github.com/cockroachdb/redact v1.1.3 // indirect
github.com/containerd/cgroups/v3 v3.0.3 // indirect
github.com/coreos/go-semver v0.3.0 // indirect
github.com/coreos/go-systemd/v22 v22.3.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/docker/go-units v0.4.0 // indirect
github.com/dustin/go-humanize v1.0.0 // indirect
github.com/form3tech-oss/jwt-go v3.2.3+incompatible // indirect
github.com/fsnotify/fsnotify v1.4.9 // indirect
github.com/getsentry/sentry-go v0.12.0 // indirect
github.com/go-logr/logr v1.3.0 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/godbus/dbus/v5 v5.0.4 // indirect
github.com/gogo/googleapis v1.4.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jonboulle/clockwork v0.2.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/magiconair/properties v1.8.5 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/mitchellh/mapstructure v1.4.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/opencontainers/runtime-spec v1.0.2 // indirect
github.com/panjf2000/ants/v2 v2.7.2 // indirect
github.com/pelletier/go-toml v1.9.3 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/prometheus/client_golang v1.14.0 // indirect
github.com/prometheus/client_model v0.3.0 // indirect
github.com/prometheus/common v0.42.0 // indirect
github.com/prometheus/procfs v0.9.0 // indirect
github.com/rogpeppe/go-internal v1.10.0 // indirect
github.com/shirou/gopsutil/v3 v3.22.9 // indirect
github.com/sirupsen/logrus v1.9.0 // indirect
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/spf13/afero v1.6.0 // indirect
github.com/spf13/cast v1.3.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/spf13/viper v1.8.1 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/subosito/gotenv v1.2.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tklauser/go-sysconf v0.3.10 // indirect
github.com/tklauser/numcpus v0.4.0 // indirect
github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect
github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect
github.com/yusufpapurcu/wmi v1.2.2 // indirect
go.etcd.io/bbolt v1.3.6 // indirect
go.etcd.io/etcd/api/v3 v3.5.5 // indirect
go.etcd.io/etcd/client/pkg/v3 v3.5.5 // indirect
go.etcd.io/etcd/client/v2 v2.305.5 // indirect
go.etcd.io/etcd/client/v3 v3.5.5 // indirect
go.etcd.io/etcd/pkg/v3 v3.5.5 // indirect
go.etcd.io/etcd/raft/v3 v3.5.5 // indirect
go.etcd.io/etcd/server/v3 v3.5.5 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.38.0 // indirect
go.opentelemetry.io/otel v1.13.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.13.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.13.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.13.0 // indirect
go.opentelemetry.io/otel/metric v0.35.0 // indirect
go.opentelemetry.io/otel/sdk v1.13.0 // indirect
go.opentelemetry.io/otel/trace v1.13.0 // indirect
go.opentelemetry.io/proto/otlp v0.19.0 // indirect
go.uber.org/automaxprocs v1.5.2 // indirect
go.uber.org/multierr v1.7.0 // indirect
go.uber.org/zap v1.20.0 // indirect
golang.org/x/crypto v0.16.0 // indirect
golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 // indirect
golang.org/x/net v0.19.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.3.0 // indirect
google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/ini.v1 v1.62.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/apimachinery v0.28.6 // indirect
sigs.k8s.io/yaml v1.3.0 // indirect
)

1121
client/go.sum Normal file

File diff suppressed because it is too large Load Diff

159
client/index.go Normal file
View File

@ -0,0 +1,159 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"fmt"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/index"
"github.com/milvus-io/milvus/pkg/util/merr"
"google.golang.org/grpc"
)
type CreateIndexTask struct {
client *Client
collectionName string
fieldName string
indexName string
interval time.Duration
}
func (t *CreateIndexTask) Await(ctx context.Context) error {
ticker := time.NewTicker(t.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
finished := false
err := t.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{
CollectionName: t.collectionName,
FieldName: t.fieldName,
IndexName: t.indexName,
})
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
for _, info := range resp.GetIndexDescriptions() {
if (t.indexName == "" && info.GetFieldName() == t.fieldName) || t.indexName == info.GetIndexName() {
switch info.GetState() {
case commonpb.IndexState_Finished:
finished = true
return nil
case commonpb.IndexState_Failed:
return fmt.Errorf("create index failed, reason: %s", info.GetIndexStateFailReason())
}
}
}
return nil
})
if err != nil {
return err
}
if finished {
return nil
}
ticker.Reset(t.interval)
case <-ctx.Done():
return ctx.Err()
}
}
}
func (c *Client) CreateIndex(ctx context.Context, option CreateIndexOption, callOptions ...grpc.CallOption) (*CreateIndexTask, error) {
req := option.Request()
var task *CreateIndexTask
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.CreateIndex(ctx, req, callOptions...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
task = &CreateIndexTask{
client: c,
collectionName: req.GetCollectionName(),
fieldName: req.GetFieldName(),
indexName: req.GetIndexName(),
interval: time.Millisecond * 100,
}
return nil
})
return task, err
}
func (c *Client) ListIndexes(ctx context.Context, opt ListIndexOption, callOptions ...grpc.CallOption) ([]string, error) {
req := opt.Request()
var indexes []string
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DescribeIndex(ctx, req, callOptions...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
for _, idxDef := range resp.GetIndexDescriptions() {
if opt.Matches(idxDef) {
indexes = append(indexes, idxDef.GetIndexName())
}
}
return nil
})
return indexes, err
}
func (c *Client) DescribeIndex(ctx context.Context, opt DescribeIndexOption, callOptions ...grpc.CallOption) (index.Index, error) {
req := opt.Request()
var idx index.Index
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DescribeIndex(ctx, req, callOptions...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
if len(resp.GetIndexDescriptions()) == 0 {
return merr.WrapErrIndexNotFound(req.GetIndexName())
}
for _, idxDef := range resp.GetIndexDescriptions() {
if idxDef.GetIndexName() == req.GetIndexName() {
idx = index.NewGenericIndex(idxDef.GetIndexName(), entity.KvPairsMap(idxDef.GetParams()))
}
}
return nil
})
return idx, err
}
func (c *Client) DropIndex(ctx context.Context, opt DropIndexOption, callOptions ...grpc.CallOption) error {
req := opt.Request()
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DropIndex(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
}

61
client/index/common.go Normal file
View File

@ -0,0 +1,61 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package index
import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/client/v2/entity"
)
// index param field tag
const (
IndexTypeKey = `index_type`
MetricTypeKey = `metric_type`
ParamsKey = `params`
)
// IndexState export index state
type IndexState commonpb.IndexState
// IndexType index type
type IndexType string
// MetricType alias for `entity.MetricsType`.
type MetricType = entity.MetricType
// Index Constants
const (
Flat IndexType = "FLAT" // faiss
BinFlat IndexType = "BIN_FLAT"
IvfFlat IndexType = "IVF_FLAT" // faiss
BinIvfFlat IndexType = "BIN_IVF_FLAT"
IvfPQ IndexType = "IVF_PQ" // faiss
IvfSQ8 IndexType = "IVF_SQ8"
HNSW IndexType = "HNSW"
IvfHNSW IndexType = "IVF_HNSW"
AUTOINDEX IndexType = "AUTOINDEX"
DISKANN IndexType = "DISKANN"
SCANN IndexType = "SCANN"
GPUIvfFlat IndexType = "GPU_IVF_FLAT"
GPUIvfPQ IndexType = "GPU_IVF_PQ"
GPUCagra IndexType = "GPU_CAGRA"
GPUBruteForce IndexType = "GPU_BRUTE_FORCE"
Scalar IndexType = "SCALAR"
)

38
client/index/disk_ann.go Normal file
View File

@ -0,0 +1,38 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package index
var _ Index = diskANNIndex{}
type diskANNIndex struct {
baseIndex
}
func (idx diskANNIndex) Params() map[string]string {
return map[string]string{
MetricTypeKey: string(idx.metricType),
IndexTypeKey: string(DISKANN),
}
}
func NewDiskANNIndex(metricType MetricType) Index {
return &diskANNIndex{
baseIndex: baseIndex{
metricType: metricType,
},
}
}

38
client/index/flat.go Normal file
View File

@ -0,0 +1,38 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package index
var _ Index = flatIndex{}
type flatIndex struct {
baseIndex
}
func (idx flatIndex) Params() map[string]string {
return map[string]string{
MetricTypeKey: string(idx.metricType),
IndexTypeKey: string(Flat),
}
}
func NewFlatIndex(metricType MetricType) Index {
return flatIndex{
baseIndex: baseIndex{
metricType: metricType,
},
}
}

53
client/index/hnsw.go Normal file
View File

@ -0,0 +1,53 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package index
import "strconv"
const (
hnswMKey = `M`
hsnwEfConstruction = `efConstruction`
)
var _ Index = hnswIndex{}
type hnswIndex struct {
baseIndex
m int
efConstruction int // exploratory factor when building index
}
func (idx hnswIndex) Params() map[string]string {
return map[string]string{
MetricTypeKey: string(idx.metricType),
IndexTypeKey: string(HNSW),
hnswMKey: strconv.Itoa(idx.m),
hsnwEfConstruction: strconv.Itoa(idx.efConstruction),
}
}
func NewHNSWIndex(metricType MetricType, M int, efConstruction int) Index {
return hnswIndex{
baseIndex: baseIndex{
metricType: metricType,
indexType: HNSW,
},
m: M,
efConstruction: efConstruction,
}
}

79
client/index/index.go Normal file
View File

@ -0,0 +1,79 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package index
import "encoding/json"
// Index represent index definition in milvus.
type Index interface {
Name() string
IndexType() IndexType
Params() map[string]string
}
type baseIndex struct {
name string
metricType MetricType
indexType IndexType
params map[string]string
}
func (idx baseIndex) Name() string {
return idx.name
}
func (idx baseIndex) IndexType() IndexType {
return idx.indexType
}
func (idx baseIndex) Params() map[string]string {
return idx.params
}
func (idx baseIndex) getExtraParams(params map[string]any) string {
bs, _ := json.Marshal(params)
return string(bs)
}
var _ Index = GenericIndex{}
type GenericIndex struct {
baseIndex
params map[string]string
}
// Params implements Index
func (gi GenericIndex) Params() map[string]string {
m := make(map[string]string)
if gi.baseIndex.indexType != "" {
m[IndexTypeKey] = string(gi.IndexType())
}
for k, v := range gi.params {
m[k] = v
}
return m
}
// NewGenericIndex create generic index instance
func NewGenericIndex(name string, params map[string]string) Index {
return GenericIndex{
baseIndex: baseIndex{
name: name,
},
params: params,
}
}

View File

@ -0,0 +1,17 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package index

108
client/index/ivf.go Normal file
View File

@ -0,0 +1,108 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package index
import "strconv"
const (
ivfNlistKey = `nlist`
ivfPQMKey = `m`
ivfPQNbits = `nbits`
)
var _ Index = ivfFlatIndex{}
type ivfFlatIndex struct {
baseIndex
nlist int
}
func (idx ivfFlatIndex) Params() map[string]string {
return map[string]string{
MetricTypeKey: string(idx.metricType),
IndexTypeKey: string(IvfFlat),
ivfNlistKey: strconv.Itoa(idx.nlist),
}
}
func NewIvfFlatIndex(metricType MetricType, nlist int) Index {
return ivfFlatIndex{
baseIndex: baseIndex{
metricType: metricType,
indexType: IvfFlat,
},
nlist: nlist,
}
}
type ivfPQIndex struct {
baseIndex
nlist int
m int
nbits int
}
func (idx ivfPQIndex) Params() map[string]string {
return map[string]string{
MetricTypeKey: string(idx.metricType),
IndexTypeKey: string(IvfPQ),
ivfNlistKey: strconv.Itoa(idx.nlist),
ivfPQMKey: strconv.Itoa(idx.m),
ivfPQNbits: strconv.Itoa(idx.nbits),
}
}
func NewIvfPQIndex(metricType MetricType, nlist int, m int, nbits int) Index {
return ivfPQIndex{
baseIndex: baseIndex{
metricType: metricType,
indexType: IvfPQ,
},
nlist: nlist,
m: m,
nbits: nbits,
}
}
type ivfSQ8Index struct {
baseIndex
nlist int
}
func (idx ivfSQ8Index) Params() map[string]string {
return map[string]string{
MetricTypeKey: string(idx.metricType),
IndexTypeKey: string(IvfSQ8),
ivfNlistKey: strconv.Itoa(idx.nlist),
}
}
func NewIvfSQ8Index(metricType MetricType, nlist int) Index {
return ivfPQIndex{
baseIndex: baseIndex{
metricType: metricType,
indexType: IvfSQ8,
},
nlist: nlist,
}
}

50
client/index/scann.go Normal file
View File

@ -0,0 +1,50 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package index
import "strconv"
const (
scannNlistKey = `nlist`
scannWithRawDataKey = `with_raw_data`
)
type scannIndex struct {
baseIndex
nlist int
withRawData bool
}
func (idx scannIndex) Params() map[string]string {
return map[string]string{
MetricTypeKey: string(idx.metricType),
IndexTypeKey: string(IvfFlat),
ivfNlistKey: strconv.Itoa(idx.nlist),
}
}
func NewSCANNIndex(metricType MetricType, nlist int) Index {
return ivfFlatIndex{
baseIndex: baseIndex{
metricType: metricType,
indexType: IvfFlat,
},
nlist: nlist,
}
}

137
client/index_options.go Normal file
View File

@ -0,0 +1,137 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/index"
)
type CreateIndexOption interface {
Request() *milvuspb.CreateIndexRequest
}
type createIndexOption struct {
collectionName string
fieldName string
indexName string
indexDef index.Index
}
func (opt *createIndexOption) Request() *milvuspb.CreateIndexRequest {
return &milvuspb.CreateIndexRequest{
CollectionName: opt.collectionName,
FieldName: opt.fieldName,
IndexName: opt.indexName,
ExtraParams: entity.MapKvPairs(opt.indexDef.Params()),
}
}
func (opt *createIndexOption) WithIndexName(indexName string) *createIndexOption {
opt.indexName = indexName
return opt
}
func NewCreateIndexOption(collectionName string, fieldName string, index index.Index) *createIndexOption {
return &createIndexOption{
collectionName: collectionName,
fieldName: fieldName,
indexDef: index,
}
}
type ListIndexOption interface {
Request() *milvuspb.DescribeIndexRequest
Matches(*milvuspb.IndexDescription) bool
}
var _ ListIndexOption = (*listIndexOption)(nil)
type listIndexOption struct {
collectionName string
fieldName string
}
func (opt *listIndexOption) WithFieldName(fieldName string) *listIndexOption {
opt.fieldName = fieldName
return opt
}
func (opt *listIndexOption) Matches(idxDef *milvuspb.IndexDescription) bool {
return opt.fieldName == "" || idxDef.GetFieldName() == opt.fieldName
}
func (opt *listIndexOption) Request() *milvuspb.DescribeIndexRequest {
return &milvuspb.DescribeIndexRequest{
CollectionName: opt.collectionName,
FieldName: opt.fieldName,
}
}
func NewListIndexOption(collectionName string) *listIndexOption {
return &listIndexOption{
collectionName: collectionName,
}
}
type DescribeIndexOption interface {
Request() *milvuspb.DescribeIndexRequest
}
type describeIndexOption struct {
collectionName string
fieldName string
indexName string
}
func (opt *describeIndexOption) Request() *milvuspb.DescribeIndexRequest {
return &milvuspb.DescribeIndexRequest{
CollectionName: opt.collectionName,
IndexName: opt.indexName,
}
}
func NewDescribeIndexOption(collectionName string, indexName string) *describeIndexOption {
return &describeIndexOption{
collectionName: collectionName,
indexName: indexName,
}
}
type DropIndexOption interface {
Request() *milvuspb.DropIndexRequest
}
type dropIndexOption struct {
collectionName string
indexName string
}
func (opt *dropIndexOption) Request() *milvuspb.DropIndexRequest {
return &milvuspb.DropIndexRequest{
CollectionName: opt.collectionName,
IndexName: opt.indexName,
}
}
func NewDropIndexOption(collectionName string, indexName string) *dropIndexOption {
return &dropIndexOption{
collectionName: collectionName,
indexName: indexName,
}
}

221
client/index_test.go Normal file
View File

@ -0,0 +1,221 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"fmt"
"testing"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/client/v2/index"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
)
type IndexSuite struct {
MockSuiteBase
}
func (s *IndexSuite) TestCreateIndex() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
fieldName := fmt.Sprintf("field_%s", s.randString(4))
indexName := fmt.Sprintf("idx_%s", s.randString(6))
done := atomic.NewBool(false)
s.mock.EXPECT().CreateIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cir *milvuspb.CreateIndexRequest) (*commonpb.Status, error) {
s.Equal(collectionName, cir.GetCollectionName())
s.Equal(fieldName, cir.GetFieldName())
s.Equal(indexName, cir.GetIndexName())
return merr.Success(), nil
}).Once()
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dir *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
state := commonpb.IndexState_InProgress
if done.Load() {
state = commonpb.IndexState_Finished
}
return &milvuspb.DescribeIndexResponse{
Status: merr.Success(),
IndexDescriptions: []*milvuspb.IndexDescription{
{
FieldName: fieldName,
IndexName: indexName,
State: state,
},
},
}, nil
})
defer s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Unset()
task, err := s.client.CreateIndex(ctx, NewCreateIndexOption(collectionName, fieldName, index.NewHNSWIndex(entity.L2, 32, 128)).WithIndexName(indexName))
s.NoError(err)
ch := make(chan struct{})
go func() {
defer close(ch)
err := task.Await(ctx)
s.NoError(err)
}()
select {
case <-ch:
s.FailNow("task done before index state set to finish")
case <-time.After(time.Second):
}
done.Store(true)
select {
case <-ch:
case <-time.After(time.Second):
s.FailNow("task not done after index set finished")
}
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
fieldName := fmt.Sprintf("field_%s", s.randString(4))
indexName := fmt.Sprintf("idx_%s", s.randString(6))
s.mock.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.CreateIndex(ctx, NewCreateIndexOption(collectionName, fieldName, index.NewHNSWIndex(entity.L2, 32, 128)).WithIndexName(indexName))
s.Error(err)
})
}
func (s *IndexSuite) TestListIndexes() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dir *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
s.Equal(collectionName, dir.GetCollectionName())
return &milvuspb.DescribeIndexResponse{
Status: merr.Success(),
IndexDescriptions: []*milvuspb.IndexDescription{
{IndexName: "test_idx"},
},
}, nil
}).Once()
names, err := s.client.ListIndexes(ctx, NewListIndexOption(collectionName))
s.NoError(err)
s.ElementsMatch([]string{"test_idx"}, names)
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.ListIndexes(ctx, NewListIndexOption(collectionName))
s.Error(err)
})
}
func (s *IndexSuite) TestDescribeIndex() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
indexName := fmt.Sprintf("idx_%s", s.randString(6))
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dir *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
s.Equal(collectionName, dir.GetCollectionName())
s.Equal(indexName, dir.GetIndexName())
return &milvuspb.DescribeIndexResponse{
Status: merr.Success(),
IndexDescriptions: []*milvuspb.IndexDescription{
{IndexName: indexName, Params: []*commonpb.KeyValuePair{
{Key: index.IndexTypeKey, Value: string(index.HNSW)},
}},
},
}, nil
}).Once()
index, err := s.client.DescribeIndex(ctx, NewDescribeIndexOption(collectionName, indexName))
s.NoError(err)
s.Equal(indexName, index.Name())
})
s.Run("no_index_found", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
indexName := fmt.Sprintf("idx_%s", s.randString(6))
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dir *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
s.Equal(collectionName, dir.GetCollectionName())
s.Equal(indexName, dir.GetIndexName())
return &milvuspb.DescribeIndexResponse{
Status: merr.Success(),
IndexDescriptions: []*milvuspb.IndexDescription{},
}, nil
}).Once()
_, err := s.client.DescribeIndex(ctx, NewDescribeIndexOption(collectionName, indexName))
s.Error(err)
s.ErrorIs(err, merr.ErrIndexNotFound)
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
indexName := fmt.Sprintf("idx_%s", s.randString(6))
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.DescribeIndex(ctx, NewDescribeIndexOption(collectionName, indexName))
s.Error(err)
})
}
func (s *IndexSuite) TestDropIndexOption() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
indexName := fmt.Sprintf("idx_%s", s.randString(6))
opt := NewDropIndexOption(collectionName, indexName)
req := opt.Request()
s.Equal(collectionName, req.GetCollectionName())
s.Equal(indexName, req.GetIndexName())
}
func (s *IndexSuite) TestDropIndex() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
s.mock.EXPECT().DropIndex(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once()
err := s.client.DropIndex(ctx, NewDropIndexOption("testCollection", "testIndex"))
s.NoError(err)
})
s.Run("failure", func() {
s.mock.EXPECT().DropIndex(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.DropIndex(ctx, NewDropIndexOption("testCollection", "testIndex"))
s.Error(err)
})
}
func TestIndex(t *testing.T) {
suite.Run(t, new(IndexSuite))
}

171
client/maintenance.go Normal file
View File

@ -0,0 +1,171 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"time"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/util/merr"
)
type LoadTask struct {
client *Client
collectionName string
partitionNames []string
interval time.Duration
}
func (t *LoadTask) Await(ctx context.Context) error {
ticker := time.NewTicker(t.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
loaded := false
t.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
CollectionName: t.collectionName,
PartitionNames: t.partitionNames,
})
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
loaded = resp.GetProgress() == 100
return nil
})
if loaded {
return nil
}
ticker.Reset(t.interval)
case <-ctx.Done():
return ctx.Err()
}
}
}
func (c *Client) LoadCollection(ctx context.Context, option LoadCollectionOption, callOptions ...grpc.CallOption) (LoadTask, error) {
req := option.Request()
var task LoadTask
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.LoadCollection(ctx, req, callOptions...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
task = LoadTask{
client: c,
collectionName: req.GetCollectionName(),
interval: option.CheckInterval(),
}
return nil
})
return task, err
}
func (c *Client) LoadPartitions(ctx context.Context, option LoadPartitionsOption, callOptions ...grpc.CallOption) (LoadTask, error) {
req := option.Request()
var task LoadTask
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.LoadPartitions(ctx, req, callOptions...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
task = LoadTask{
client: c,
collectionName: req.GetCollectionName(),
partitionNames: req.GetPartitionNames(),
interval: option.CheckInterval(),
}
return nil
})
return task, err
}
type FlushTask struct {
client *Client
collectionName string
segmentIDs []int64
flushTs uint64
interval time.Duration
}
func (t *FlushTask) Await(ctx context.Context) error {
ticker := time.NewTicker(t.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
flushed := false
t.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
CollectionName: t.collectionName,
SegmentIDs: t.segmentIDs,
FlushTs: t.flushTs,
})
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
flushed = resp.GetFlushed()
return nil
})
if flushed {
return nil
}
ticker.Reset(t.interval)
case <-ctx.Done():
return ctx.Err()
}
}
}
func (c *Client) Flush(ctx context.Context, option FlushOption, callOptions ...grpc.CallOption) (*FlushTask, error) {
req := option.Request()
collectionName := option.CollectionName()
var task *FlushTask
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Flush(ctx, req, callOptions...)
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
task = &FlushTask{
client: c,
collectionName: collectionName,
segmentIDs: resp.GetCollSegIDs()[collectionName].GetData(),
flushTs: resp.GetCollFlushTs()[collectionName],
interval: option.CheckInterval(),
}
return nil
})
return task, err
}

View File

@ -0,0 +1,125 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
)
type LoadCollectionOption interface {
Request() *milvuspb.LoadCollectionRequest
CheckInterval() time.Duration
}
type loadCollectionOption struct {
collectionName string
interval time.Duration
replicaNum int
}
func (opt *loadCollectionOption) Request() *milvuspb.LoadCollectionRequest {
return &milvuspb.LoadCollectionRequest{
CollectionName: opt.collectionName,
ReplicaNumber: int32(opt.replicaNum),
}
}
func (opt *loadCollectionOption) CheckInterval() time.Duration {
return opt.interval
}
func (opt *loadCollectionOption) WithReplica(num int) *loadCollectionOption {
opt.replicaNum = num
return opt
}
func NewLoadCollectionOption(collectionName string) *loadCollectionOption {
return &loadCollectionOption{
collectionName: collectionName,
replicaNum: 1,
interval: time.Millisecond * 200,
}
}
type LoadPartitionsOption interface {
Request() *milvuspb.LoadPartitionsRequest
CheckInterval() time.Duration
}
var _ LoadPartitionsOption = (*loadPartitionsOption)(nil)
type loadPartitionsOption struct {
collectionName string
partitionNames []string
interval time.Duration
replicaNum int
}
func (opt *loadPartitionsOption) Request() *milvuspb.LoadPartitionsRequest {
return &milvuspb.LoadPartitionsRequest{
CollectionName: opt.collectionName,
PartitionNames: opt.partitionNames,
ReplicaNumber: int32(opt.replicaNum),
}
}
func (opt *loadPartitionsOption) CheckInterval() time.Duration {
return opt.interval
}
func NewLoadPartitionsOption(collectionName string, partitionsNames []string) *loadPartitionsOption {
return &loadPartitionsOption{
collectionName: collectionName,
partitionNames: partitionsNames,
replicaNum: 1,
interval: time.Millisecond * 200,
}
}
type FlushOption interface {
Request() *milvuspb.FlushRequest
CollectionName() string
CheckInterval() time.Duration
}
type flushOption struct {
collectionName string
interval time.Duration
}
func (opt *flushOption) Request() *milvuspb.FlushRequest {
return &milvuspb.FlushRequest{
CollectionNames: []string{opt.collectionName},
}
}
func (opt *flushOption) CollectionName() string {
return opt.collectionName
}
func (opt *flushOption) CheckInterval() time.Duration {
return opt.interval
}
func NewFlushOption(collName string) *flushOption {
return &flushOption{
collectionName: collName,
interval: time.Millisecond * 200,
}
}

229
client/maintenance_test.go Normal file
View File

@ -0,0 +1,229 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"fmt"
"testing"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
)
type MaintenanceSuite struct {
MockSuiteBase
}
func (s *MaintenanceSuite) TestLoadCollection() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
done := atomic.NewBool(false)
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lcr *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) {
s.Equal(collectionName, lcr.GetCollectionName())
return merr.Success(), nil
}).Once()
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
s.Equal(collectionName, glpr.GetCollectionName())
progress := int64(50)
if done.Load() {
progress = 100
}
return &milvuspb.GetLoadingProgressResponse{
Status: merr.Success(),
Progress: progress,
}, nil
})
defer s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Unset()
task, err := s.client.LoadCollection(ctx, NewLoadCollectionOption(collectionName))
s.NoError(err)
ch := make(chan struct{})
go func() {
defer close(ch)
err := task.Await(ctx)
s.NoError(err)
}()
select {
case <-ch:
s.FailNow("task done before index state set to finish")
case <-time.After(time.Second):
}
done.Store(true)
select {
case <-ch:
case <-time.After(time.Second):
s.FailNow("task not done after index set finished")
}
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.LoadCollection(ctx, NewLoadCollectionOption(collectionName))
s.Error(err)
})
}
func (s *MaintenanceSuite) TestLoadPartitions() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
done := atomic.NewBool(false)
s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lpr *milvuspb.LoadPartitionsRequest) (*commonpb.Status, error) {
s.Equal(collectionName, lpr.GetCollectionName())
s.ElementsMatch([]string{partitionName}, lpr.GetPartitionNames())
return merr.Success(), nil
}).Once()
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
s.Equal(collectionName, glpr.GetCollectionName())
s.ElementsMatch([]string{partitionName}, glpr.GetPartitionNames())
progress := int64(50)
if done.Load() {
progress = 100
}
return &milvuspb.GetLoadingProgressResponse{
Status: merr.Success(),
Progress: progress,
}, nil
})
defer s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Unset()
task, err := s.client.LoadPartitions(ctx, NewLoadPartitionsOption(collectionName, []string{partitionName}))
s.NoError(err)
ch := make(chan struct{})
go func() {
defer close(ch)
err := task.Await(ctx)
s.NoError(err)
}()
select {
case <-ch:
s.FailNow("task done before index state set to finish")
case <-time.After(time.Second):
}
done.Store(true)
select {
case <-ch:
case <-time.After(time.Second):
s.FailNow("task not done after index set finished")
}
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.LoadPartitions(ctx, NewLoadPartitionsOption(collectionName, []string{partitionName}))
s.Error(err)
})
}
func (s *MaintenanceSuite) TestFlush() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
done := atomic.NewBool(false)
s.mock.EXPECT().Flush(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, fr *milvuspb.FlushRequest) (*milvuspb.FlushResponse, error) {
s.ElementsMatch([]string{collectionName}, fr.GetCollectionNames())
return &milvuspb.FlushResponse{
Status: merr.Success(),
CollSegIDs: map[string]*schemapb.LongArray{
collectionName: {Data: []int64{1, 2, 3}},
},
CollFlushTs: map[string]uint64{collectionName: 321},
}, nil
}).Once()
s.mock.EXPECT().GetFlushState(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, gfsr *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) {
s.Equal(collectionName, gfsr.GetCollectionName())
s.ElementsMatch([]int64{1, 2, 3}, gfsr.GetSegmentIDs())
s.EqualValues(321, gfsr.GetFlushTs())
return &milvuspb.GetFlushStateResponse{
Status: merr.Success(),
Flushed: done.Load(),
}, nil
})
defer s.mock.EXPECT().GetFlushState(mock.Anything, mock.Anything).Unset()
task, err := s.client.Flush(ctx, NewFlushOption(collectionName))
s.NoError(err)
ch := make(chan struct{})
go func() {
defer close(ch)
err := task.Await(ctx)
s.NoError(err)
}()
select {
case <-ch:
s.FailNow("task done before index state set to finish")
case <-time.After(time.Second):
}
done.Store(true)
select {
case <-ch:
case <-time.After(time.Second):
s.FailNow("task not done after index set finished")
}
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().Flush(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.Flush(ctx, NewFlushOption(collectionName))
s.Error(err)
})
}
func TestMaintenance(t *testing.T) {
suite.Run(t, new(MaintenanceSuite))
}

File diff suppressed because it is too large Load Diff

77
client/partition.go Normal file
View File

@ -0,0 +1,77 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/util/merr"
"google.golang.org/grpc"
)
// CreatePartition is the API for creating a partition for a collection.
func (c *Client) CreatePartition(ctx context.Context, opt CreatePartitionOption, callOptions ...grpc.CallOption) error {
req := opt.Request()
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.CreatePartition(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
return err
}
func (c *Client) DropPartition(ctx context.Context, opt DropPartitionOption, callOptions ...grpc.CallOption) error {
req := opt.Request()
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.DropPartition(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
return err
}
func (c *Client) HasPartition(ctx context.Context, opt HasPartitionOption, callOptions ...grpc.CallOption) (has bool, err error) {
req := opt.Request()
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.HasPartition(ctx, req, callOptions...)
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
has = resp.GetValue()
return nil
})
return has, err
}
func (c *Client) ListPartitions(ctx context.Context, opt ListPartitionsOption, callOptions ...grpc.CallOption) (partitionNames []string, err error) {
req := opt.Request()
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.ShowPartitions(ctx, req, callOptions...)
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
partitionNames = resp.GetPartitionNames()
return nil
})
return partitionNames, err
}

119
client/partition_options.go Normal file
View File

@ -0,0 +1,119 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
// CreatePartitionOption is the interface builds Create Partition request.
type CreatePartitionOption interface {
// Request is the method returns the composed request.
Request() *milvuspb.CreatePartitionRequest
}
type createPartitionOpt struct {
collectionName string
partitionName string
}
func (opt *createPartitionOpt) Request() *milvuspb.CreatePartitionRequest {
return &milvuspb.CreatePartitionRequest{
CollectionName: opt.collectionName,
PartitionName: opt.partitionName,
}
}
func NewCreatePartitionOption(collectionName string, partitionName string) *createPartitionOpt {
return &createPartitionOpt{
collectionName: collectionName,
partitionName: partitionName,
}
}
// DropPartitionOption is the interface that builds Drop Partition request.
type DropPartitionOption interface {
// Request is the method returns the composed request.
Request() *milvuspb.DropPartitionRequest
}
type dropPartitionOpt struct {
collectionName string
partitionName string
}
func (opt *dropPartitionOpt) Request() *milvuspb.DropPartitionRequest {
return &milvuspb.DropPartitionRequest{
CollectionName: opt.collectionName,
PartitionName: opt.partitionName,
}
}
func NewDropPartitionOption(collectionName string, partitionName string) *dropPartitionOpt {
return &dropPartitionOpt{
collectionName: collectionName,
partitionName: partitionName,
}
}
// HasPartitionOption is the interface builds HasPartition request.
type HasPartitionOption interface {
// Request is the method returns the composed request.
Request() *milvuspb.HasPartitionRequest
}
var _ HasPartitionOption = (*hasPartitionOpt)(nil)
type hasPartitionOpt struct {
collectionName string
partitionName string
}
func (opt *hasPartitionOpt) Request() *milvuspb.HasPartitionRequest {
return &milvuspb.HasPartitionRequest{
CollectionName: opt.collectionName,
PartitionName: opt.partitionName,
}
}
func NewHasPartitionOption(collectionName string, partitionName string) *hasPartitionOpt {
return &hasPartitionOpt{
collectionName: collectionName,
partitionName: partitionName,
}
}
// ListPartitionsOption is the interface builds List Partition request.
type ListPartitionsOption interface {
// Request is the method returns the composed request.
Request() *milvuspb.ShowPartitionsRequest
}
type listPartitionsOpt struct {
collectionName string
}
func (opt *listPartitionsOpt) Request() *milvuspb.ShowPartitionsRequest {
return &milvuspb.ShowPartitionsRequest{
CollectionName: opt.collectionName,
Type: milvuspb.ShowType_All,
}
}
func NewListPartitionOption(collectionName string) *listPartitionsOpt {
return &listPartitionsOpt{
collectionName: collectionName,
}
}

166
client/partition_test.go Normal file
View File

@ -0,0 +1,166 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"fmt"
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)
type PartitionSuite struct {
MockSuiteBase
}
func (s *PartitionSuite) TestListPartitions() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().ShowPartitions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, spr *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) {
s.Equal(collectionName, spr.GetCollectionName())
return &milvuspb.ShowPartitionsResponse{
Status: merr.Success(),
PartitionNames: []string{"_default", "part_1"},
PartitionIDs: []int64{100, 101},
}, nil
}).Once()
names, err := s.client.ListPartitions(ctx, NewListPartitionOption(collectionName))
s.NoError(err)
s.ElementsMatch([]string{"_default", "part_1"}, names)
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.mock.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.ListPartitions(ctx, NewListPartitionOption(collectionName))
s.Error(err)
})
}
func (s *PartitionSuite) TestCreatePartition() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.mock.EXPECT().CreatePartition(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cpr *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) {
s.Equal(collectionName, cpr.GetCollectionName())
s.Equal(partitionName, cpr.GetPartitionName())
return merr.Success(), nil
}).Once()
err := s.client.CreatePartition(ctx, NewCreatePartitionOption(collectionName, partitionName))
s.NoError(err)
})
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.mock.EXPECT().CreatePartition(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.CreatePartition(ctx, NewCreatePartitionOption(collectionName, partitionName))
s.Error(err)
})
}
func (s *PartitionSuite) TestHasPartition() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, hpr *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) {
s.Equal(collectionName, hpr.GetCollectionName())
s.Equal(partitionName, hpr.GetPartitionName())
return &milvuspb.BoolResponse{Status: merr.Success()}, nil
}).Once()
has, err := s.client.HasPartition(ctx, NewHasPartitionOption(collectionName, partitionName))
s.NoError(err)
s.False(has)
s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, hpr *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) {
s.Equal(collectionName, hpr.GetCollectionName())
s.Equal(partitionName, hpr.GetPartitionName())
return &milvuspb.BoolResponse{
Status: merr.Success(),
Value: true,
}, nil
}).Once()
has, err = s.client.HasPartition(ctx, NewHasPartitionOption(collectionName, partitionName))
s.NoError(err)
s.True(has)
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
_, err := s.client.HasPartition(ctx, NewHasPartitionOption(collectionName, partitionName))
s.Error(err)
})
}
func (s *PartitionSuite) TestDropPartition() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.mock.EXPECT().DropPartition(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dpr *milvuspb.DropPartitionRequest) (*commonpb.Status, error) {
s.Equal(collectionName, dpr.GetCollectionName())
s.Equal(partitionName, dpr.GetPartitionName())
return merr.Success(), nil
}).Once()
err := s.client.DropPartition(ctx, NewDropPartitionOption(collectionName, partitionName))
s.NoError(err)
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.mock.EXPECT().DropPartition(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.DropPartition(ctx, NewDropPartitionOption(collectionName, partitionName))
s.Error(err)
})
}
func TestPartition(t *testing.T) {
suite.Run(t, new(PartitionSuite))
}

220
client/read.go Normal file
View File

@ -0,0 +1,220 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"google.golang.org/grpc"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/column"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/util/merr"
)
type ResultSets struct{}
type ResultSet struct {
ResultCount int // the returning entry count
GroupByValue any
IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API
Fields DataSet // output field data
Scores []float32 // distance to the target vector
Err error // search error if any
}
// DataSet is an alias type for column slice.
type DataSet []column.Column
func (c *Client) Search(ctx context.Context, option SearchOption, callOptions ...grpc.CallOption) ([]ResultSet, error) {
req := option.Request()
collection, err := c.getCollection(ctx, req.GetCollectionName())
if err != nil {
return nil, err
}
var resultSets []ResultSet
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Search(ctx, req, callOptions...)
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
resultSets, err = c.handleSearchResult(collection.Schema, req.GetOutputFields(), int(req.GetNq()), resp)
return err
})
return resultSets, err
}
func (c *Client) handleSearchResult(schema *entity.Schema, outputFields []string, nq int, resp *milvuspb.SearchResults) ([]ResultSet, error) {
var err error
sr := make([]ResultSet, 0, nq)
results := resp.GetResults()
offset := 0
fieldDataList := results.GetFieldsData()
gb := results.GetGroupByFieldValue()
var gbc column.Column
if gb != nil {
gbc, err = column.FieldDataColumn(gb, 0, -1)
if err != nil {
return nil, err
}
}
for i := 0; i < int(results.GetNumQueries()); i++ {
rc := int(results.GetTopks()[i]) // result entry count for current query
entry := ResultSet{
ResultCount: rc,
Scores: results.GetScores()[offset : offset+rc],
}
if gbc != nil {
entry.GroupByValue, _ = gbc.Get(i)
}
// parse result set if current nq is not empty
if rc > 0 {
entry.IDs, entry.Err = column.IDColumns(results.GetIds(), offset, offset+rc)
if entry.Err != nil {
offset += rc
continue
}
entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc)
sr = append(sr, entry)
}
offset += rc
}
return sr, nil
}
func (c *Client) parseSearchResult(sch *entity.Schema, outputFields []string, fieldDataList []*schemapb.FieldData, _, from, to int) ([]column.Column, error) {
var wildcard bool
outputFields, wildcard = expandWildcard(sch, outputFields)
// duplicated name will have only one column now
outputSet := make(map[string]struct{})
for _, output := range outputFields {
outputSet[output] = struct{}{}
}
// fields := make(map[string]*schemapb.FieldData)
columns := make([]column.Column, 0, len(outputFields))
var dynamicColumn *column.ColumnJSONBytes
for _, fieldData := range fieldDataList {
col, err := column.FieldDataColumn(fieldData, from, to)
if err != nil {
return nil, err
}
if fieldData.GetIsDynamic() {
var ok bool
dynamicColumn, ok = col.(*column.ColumnJSONBytes)
if !ok {
return nil, errors.New("dynamic field not json")
}
// return json column only explicitly specified in output fields and not in wildcard mode
if _, ok := outputSet[fieldData.GetFieldName()]; !ok && !wildcard {
continue
}
}
// remove processed field
delete(outputSet, fieldData.GetFieldName())
columns = append(columns, col)
}
if len(outputSet) > 0 && dynamicColumn == nil {
var extraFields []string
for output := range outputSet {
extraFields = append(extraFields, output)
}
return nil, errors.Newf("extra output fields %v found and result does not dynamic field", extraFields)
}
// add dynamic column for extra fields
for outputField := range outputSet {
column := column.NewColumnDynamic(dynamicColumn, outputField)
columns = append(columns, column)
}
return columns, nil
}
func (c *Client) Query(ctx context.Context, option QueryOption, callOptions ...grpc.CallOption) (ResultSet, error) {
req := option.Request()
var resultSet ResultSet
collection, err := c.getCollection(ctx, req.GetCollectionName())
if err != nil {
return resultSet, err
}
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Query(ctx, req, callOptions...)
err = merr.CheckRPCCall(resp, err)
if err != nil {
return err
}
columns, err := c.parseSearchResult(collection.Schema, resp.GetOutputFields(), resp.GetFieldsData(), 0, 0, -1)
if err != nil {
return err
}
resultSet = ResultSet{
Fields: columns,
}
if len(columns) > 0 {
resultSet.ResultCount = columns[0].Len()
}
return nil
})
return resultSet, err
}
func expandWildcard(schema *entity.Schema, outputFields []string) ([]string, bool) {
wildcard := false
for _, outputField := range outputFields {
if outputField == "*" {
wildcard = true
}
}
if !wildcard {
return outputFields, false
}
set := make(map[string]struct{})
result := make([]string, 0, len(schema.Fields))
for _, field := range schema.Fields {
result = append(result, field.Name)
set[field.Name] = struct{}{}
}
// add dynamic fields output
for _, output := range outputFields {
if output == "*" {
continue
}
_, ok := set[output]
if !ok {
result = append(result, output)
}
}
return result, true
}

250
client/read_options.go Normal file
View File

@ -0,0 +1,250 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"encoding/json"
"strconv"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity"
)
const (
spAnnsField = `anns_field`
spTopK = `topk`
spOffset = `offset`
spParams = `params`
spMetricsType = `metric_type`
spRoundDecimal = `round_decimal`
spIgnoreGrowing = `ignore_growing`
spGroupBy = `group_by_field`
)
type SearchOption interface {
Request() *milvuspb.SearchRequest
}
var _ SearchOption = (*searchOption)(nil)
type searchOption struct {
collectionName string
partitionNames []string
topK int
offset int
outputFields []string
consistencyLevel entity.ConsistencyLevel
useDefaultConsistencyLevel bool
ignoreGrowing bool
expr string
// normal search request
request *annRequest
// TODO add sub request when support hybrid search
}
type annRequest struct {
vectors []entity.Vector
annField string
metricsType entity.MetricType
searchParam map[string]string
groupByField string
}
func (opt *searchOption) Request() *milvuspb.SearchRequest {
// TODO check whether search is hybrid after logic merged
return opt.prepareSearchRequest(opt.request)
}
func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.SearchRequest {
request := &milvuspb.SearchRequest{
CollectionName: opt.collectionName,
PartitionNames: opt.partitionNames,
Dsl: opt.expr,
DslType: commonpb.DslType_BoolExprV1,
ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel),
}
if annRequest != nil {
// nq
request.Nq = int64(len(annRequest.vectors))
// search param
bs, _ := json.Marshal(annRequest.searchParam)
request.SearchParams = entity.MapKvPairs(map[string]string{
spAnnsField: annRequest.annField,
spTopK: strconv.Itoa(opt.topK),
spOffset: strconv.Itoa(opt.offset),
spParams: string(bs),
spMetricsType: string(annRequest.metricsType),
spRoundDecimal: "-1",
spIgnoreGrowing: strconv.FormatBool(opt.ignoreGrowing),
spGroupBy: annRequest.groupByField,
})
// placeholder group
request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors)
}
return request
}
func (opt *searchOption) WithFilter(expr string) *searchOption {
opt.expr = expr
return opt
}
func (opt *searchOption) WithOffset(offset int) *searchOption {
opt.offset = offset
return opt
}
func (opt *searchOption) WithOutputFields(fieldNames []string) *searchOption {
opt.outputFields = fieldNames
return opt
}
func (opt *searchOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *searchOption {
opt.consistencyLevel = consistencyLevel
opt.useDefaultConsistencyLevel = false
return opt
}
func (opt *searchOption) WithANNSField(annsField string) *searchOption {
opt.request.annField = annsField
return opt
}
func (opt *searchOption) WithPartitions(partitionNames []string) *searchOption {
opt.partitionNames = partitionNames
return opt
}
func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption {
return &searchOption{
collectionName: collectionName,
topK: limit,
request: &annRequest{
vectors: vectors,
},
useDefaultConsistencyLevel: true,
consistencyLevel: entity.ClBounded,
}
}
func vector2PlaceholderGroupBytes(vectors []entity.Vector) []byte {
phg := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
vector2Placeholder(vectors),
},
}
bs, _ := proto.Marshal(phg)
return bs
}
func vector2Placeholder(vectors []entity.Vector) *commonpb.PlaceholderValue {
var placeHolderType commonpb.PlaceholderType
ph := &commonpb.PlaceholderValue{
Tag: "$0",
Values: make([][]byte, 0, len(vectors)),
}
if len(vectors) == 0 {
return ph
}
switch vectors[0].(type) {
case entity.FloatVector:
placeHolderType = commonpb.PlaceholderType_FloatVector
case entity.BinaryVector:
placeHolderType = commonpb.PlaceholderType_BinaryVector
case entity.BFloat16Vector:
placeHolderType = commonpb.PlaceholderType_BFloat16Vector
case entity.Float16Vector:
placeHolderType = commonpb.PlaceholderType_Float16Vector
case entity.SparseEmbedding:
placeHolderType = commonpb.PlaceholderType_SparseFloatVector
}
ph.Type = placeHolderType
for _, vector := range vectors {
ph.Values = append(ph.Values, vector.Serialize())
}
return ph
}
type QueryOption interface {
Request() *milvuspb.QueryRequest
}
type queryOption struct {
collectionName string
partitionNames []string
limit int
offset int
outputFields []string
consistencyLevel entity.ConsistencyLevel
useDefaultConsistencyLevel bool
ignoreGrowing bool
expr string
}
func (opt *queryOption) Request() *milvuspb.QueryRequest {
return &milvuspb.QueryRequest{
CollectionName: opt.collectionName,
PartitionNames: opt.partitionNames,
OutputFields: opt.outputFields,
Expr: opt.expr,
ConsistencyLevel: opt.consistencyLevel.CommonConsistencyLevel(),
}
}
func (opt *queryOption) WithFilter(expr string) *queryOption {
opt.expr = expr
return opt
}
func (opt *queryOption) WithOffset(offset int) *queryOption {
opt.offset = offset
return opt
}
func (opt *queryOption) WithOutputFields(fieldNames []string) *queryOption {
opt.outputFields = fieldNames
return opt
}
func (opt *queryOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *queryOption {
opt.consistencyLevel = consistencyLevel
opt.useDefaultConsistencyLevel = false
return opt
}
func (opt *queryOption) WithPartitions(partitionNames []string) *queryOption {
opt.partitionNames = partitionNames
return opt
}
func NewQueryOption(collectionName string) *queryOption {
return &queryOption{
collectionName: collectionName,
useDefaultConsistencyLevel: true,
consistencyLevel: entity.ClBounded,
}
}

154
client/read_test.go Normal file
View File

@ -0,0 +1,154 @@
package client
import (
"context"
"fmt"
"math/rand"
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)
type ReadSuite struct {
MockSuiteBase
schema *entity.Schema
schemaDyn *entity.Schema
}
func (s *ReadSuite) SetupSuite() {
s.MockSuiteBase.SetupSuite()
s.schema = entity.NewSchema().
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
s.schemaDyn = entity.NewSchema().WithDynamicFieldEnabled(true).
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
}
func (s *ReadSuite) TestSearch() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.setupCache(collectionName, s.schema)
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
s.Equal(collectionName, sr.GetCollectionName())
s.ElementsMatch([]string{partitionName}, sr.GetPartitionNames())
return &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: 1,
TopK: 10,
FieldsData: []*schemapb.FieldData{
s.getInt64FieldData("ID", []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
},
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
},
},
},
Scores: make([]float32, 10),
Topks: []int64{10},
},
}, nil
}).Once()
_, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
})),
}).WithPartitions([]string{partitionName}))
s.NoError(err)
})
s.Run("dynamic_schema", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.setupCache(collectionName, s.schemaDyn)
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
return &milvuspb.SearchResults{
Status: merr.Success(),
Results: &schemapb.SearchResultData{
NumQueries: 1,
TopK: 2,
FieldsData: []*schemapb.FieldData{
s.getInt64FieldData("ID", []int64{1, 2}),
s.getJSONBytesFieldData("$meta", [][]byte{
[]byte(`{"A": 123, "B": "456"}`),
[]byte(`{"B": "abc", "A": 456}`),
}, true),
},
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2},
},
},
},
Scores: make([]float32, 2),
Topks: []int64{2},
},
}, nil
}).Once()
_, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
})),
}).WithPartitions([]string{partitionName}))
s.NoError(err)
})
s.Run("failure", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
s.setupCache(collectionName, s.schemaDyn)
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
return nil, merr.WrapErrServiceInternal("mocked")
}).Once()
_, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
})),
}))
s.Error(err)
})
}
func (s *ReadSuite) TestQuery() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
s.setupCache(collectionName, s.schema)
s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) {
s.Equal(collectionName, qr.GetCollectionName())
return &milvuspb.QueryResults{}, nil
}).Once()
_, err := s.client.Query(ctx, NewQueryOption(collectionName).WithPartitions([]string{partitionName}))
s.NoError(err)
})
}
func TestRead(t *testing.T) {
suite.Run(t, new(ReadSuite))
}

74
client/write.go Normal file
View File

@ -0,0 +1,74 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/util/merr"
)
func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ...grpc.CallOption) error {
collection, err := c.getCollection(ctx, option.CollectionName())
if err != nil {
return err
}
req, err := option.InsertRequest(collection)
if err != nil {
return err
}
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Insert(ctx, req, callOptions...)
return merr.CheckRPCCall(resp, err)
})
return err
}
func (c *Client) Delete(ctx context.Context, option DeleteOption, callOptions ...grpc.CallOption) error {
req := option.Request()
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Delete(ctx, req, callOptions...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
return nil
})
}
func (c *Client) Upsert(ctx context.Context, option UpsertOption, callOptions ...grpc.CallOption) error {
collection, err := c.getCollection(ctx, option.CollectionName())
if err != nil {
return err
}
req, err := option.UpsertRequest(collection)
if err != nil {
return err
}
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
resp, err := milvusService.Upsert(ctx, req, callOptions...)
if err = merr.CheckRPCCall(resp, err); err != nil {
return err
}
return nil
})
}

View File

@ -0,0 +1,39 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"fmt"
"testing"
"github.com/stretchr/testify/suite"
)
type DeleteOptionSuite struct {
MockSuiteBase
}
func (s *DeleteOptionSuite) TestBasic() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
opt := NewDeleteOption(collectionName)
s.Equal(collectionName, opt.Request().GetCollectionName())
}
func TestDeleteOption(t *testing.T) {
suite.Run(t, new(DeleteOptionSuite))
}

290
client/write_options.go Normal file
View File

@ -0,0 +1,290 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"encoding/json"
"fmt"
"strings"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/client/v2/column"
"github.com/milvus-io/milvus/client/v2/entity"
)
type InsertOption interface {
InsertRequest(coll *entity.Collection) (*milvuspb.InsertRequest, error)
CollectionName() string
}
type UpsertOption interface {
UpsertRequest(coll *entity.Collection) (*milvuspb.UpsertRequest, error)
CollectionName() string
}
var (
_ UpsertOption = (*columnBasedDataOption)(nil)
_ InsertOption = (*columnBasedDataOption)(nil)
)
type columnBasedDataOption struct {
collName string
partitionName string
columns []column.Column
}
func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema, columns ...column.Column) ([]*schemapb.FieldData, int, error) {
// setup dynamic related var
isDynamic := colSchema.EnableDynamicField
// check columns and field matches
var rowSize int
mNameField := make(map[string]*entity.Field)
for _, field := range colSchema.Fields {
mNameField[field.Name] = field
}
mNameColumn := make(map[string]column.Column)
var dynamicColumns []column.Column
for _, col := range columns {
_, dup := mNameColumn[col.Name()]
if dup {
return nil, 0, fmt.Errorf("duplicated column %s found", col.Name())
}
l := col.Len()
if rowSize == 0 {
rowSize = l
} else {
if rowSize != l {
return nil, 0, errors.New("column size not match")
}
}
field, has := mNameField[col.Name()]
if !has {
if !isDynamic {
return nil, 0, fmt.Errorf("field %s does not exist in collection %s", col.Name(), colSchema.CollectionName)
}
// add to dynamic column list for further processing
dynamicColumns = append(dynamicColumns, col)
continue
}
mNameColumn[col.Name()] = col
if col.Type() != field.DataType {
return nil, 0, fmt.Errorf("param column %s has type %v but collection field definition is %v", col.Name(), col.FieldData(), field.DataType)
}
if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector {
dim := 0
switch column := col.(type) {
case *column.ColumnFloatVector:
dim = column.Dim()
case *column.ColumnBinaryVector:
dim = column.Dim()
}
if fmt.Sprintf("%d", dim) != field.TypeParams[entity.TypeParamDim] {
return nil, 0, fmt.Errorf("params column %s vector dim %d not match collection definition, which has dim of %s", field.Name, dim, field.TypeParams[entity.TypeParamDim])
}
}
}
// check all fixed field pass value
for _, field := range colSchema.Fields {
_, has := mNameColumn[field.Name]
if !has &&
!field.AutoID && !field.IsDynamic {
return nil, 0, fmt.Errorf("field %s not passed", field.Name)
}
}
fieldsData := make([]*schemapb.FieldData, 0, len(mNameColumn)+1)
for _, fixedColumn := range mNameColumn {
fieldsData = append(fieldsData, fixedColumn.FieldData())
}
if len(dynamicColumns) > 0 {
// use empty column name here
col, err := opt.mergeDynamicColumns("", rowSize, dynamicColumns)
if err != nil {
return nil, 0, err
}
fieldsData = append(fieldsData, col)
}
return fieldsData, rowSize, nil
}
func (opt *columnBasedDataOption) mergeDynamicColumns(dynamicName string, rowSize int, columns []column.Column) (*schemapb.FieldData, error) {
values := make([][]byte, 0, rowSize)
for i := 0; i < rowSize; i++ {
m := make(map[string]interface{})
for _, column := range columns {
// range guaranteed
m[column.Name()], _ = column.Get(i)
}
bs, err := json.Marshal(m)
if err != nil {
return nil, err
}
values = append(values, bs)
}
return &schemapb.FieldData{
Type: schemapb.DataType_JSON,
FieldName: dynamicName,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_JsonData{
JsonData: &schemapb.JSONArray{
Data: values,
},
},
},
},
IsDynamic: true,
}, nil
}
func (opt *columnBasedDataOption) WithColumns(columns ...column.Column) *columnBasedDataOption {
opt.columns = append(opt.columns, columns...)
return opt
}
func (opt *columnBasedDataOption) WithBoolColumn(colName string, data []bool) *columnBasedDataOption {
column := column.NewColumnBool(colName, data)
return opt.WithColumns(column)
}
func (opt *columnBasedDataOption) WithInt8Column(colName string, data []int8) *columnBasedDataOption {
column := column.NewColumnInt8(colName, data)
return opt.WithColumns(column)
}
func (opt *columnBasedDataOption) WithInt16Column(colName string, data []int16) *columnBasedDataOption {
column := column.NewColumnInt16(colName, data)
return opt.WithColumns(column)
}
func (opt *columnBasedDataOption) WithInt32Column(colName string, data []int32) *columnBasedDataOption {
column := column.NewColumnInt32(colName, data)
return opt.WithColumns(column)
}
func (opt *columnBasedDataOption) WithInt64Column(colName string, data []int64) *columnBasedDataOption {
column := column.NewColumnInt64(colName, data)
return opt.WithColumns(column)
}
func (opt *columnBasedDataOption) WithVarcharColumn(colName string, data []string) *columnBasedDataOption {
column := column.NewColumnVarChar(colName, data)
return opt.WithColumns(column)
}
func (opt *columnBasedDataOption) WithFloatVectorColumn(colName string, dim int, data [][]float32) *columnBasedDataOption {
column := column.NewColumnFloatVector(colName, dim, data)
return opt.WithColumns(column)
}
func (opt *columnBasedDataOption) WithBinaryVectorColumn(colName string, dim int, data [][]byte) *columnBasedDataOption {
column := column.NewColumnBinaryVector(colName, dim, data)
return opt.WithColumns(column)
}
func (opt *columnBasedDataOption) WithPartition(partitionName string) *columnBasedDataOption {
opt.partitionName = partitionName
return opt
}
func (opt *columnBasedDataOption) CollectionName() string {
return opt.collName
}
func (opt *columnBasedDataOption) InsertRequest(coll *entity.Collection) (*milvuspb.InsertRequest, error) {
fieldsData, rowNum, err := opt.processInsertColumns(coll.Schema, opt.columns...)
if err != nil {
return nil, err
}
return &milvuspb.InsertRequest{
CollectionName: opt.collName,
PartitionName: opt.partitionName,
FieldsData: fieldsData,
NumRows: uint32(rowNum),
}, nil
}
func (opt *columnBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvuspb.UpsertRequest, error) {
fieldsData, rowNum, err := opt.processInsertColumns(coll.Schema, opt.columns...)
if err != nil {
return nil, err
}
return &milvuspb.UpsertRequest{
CollectionName: opt.collName,
PartitionName: opt.partitionName,
FieldsData: fieldsData,
NumRows: uint32(rowNum),
}, nil
}
func NewColumnBasedInsertOption(collName string, columns ...column.Column) *columnBasedDataOption {
return &columnBasedDataOption{
columns: columns,
collName: collName,
// leave partition name empty, using default partition
}
}
type DeleteOption interface {
Request() *milvuspb.DeleteRequest
}
type deleteOption struct {
collectionName string
partitionName string
expr string
}
func (opt *deleteOption) Request() *milvuspb.DeleteRequest {
return &milvuspb.DeleteRequest{
CollectionName: opt.collectionName,
PartitionName: opt.partitionName,
Expr: opt.expr,
}
}
func (opt *deleteOption) WithExpr(expr string) *deleteOption {
opt.expr = expr
return opt
}
func (opt *deleteOption) WithInt64IDs(fieldName string, ids []int64) *deleteOption {
opt.expr = fmt.Sprintf("%s in %s", fieldName, strings.Join(strings.Fields(fmt.Sprint(ids)), ","))
return opt
}
func (opt *deleteOption) WithStringIDs(fieldName string, ids []string) *deleteOption {
opt.expr = fmt.Sprintf("%s in [%s]", fieldName, strings.Join(lo.Map(ids, func(id string, _ int) string { return fmt.Sprintf("\"%s\"", id) }), ","))
return opt
}
func (opt *deleteOption) WithPartition(partitionName string) *deleteOption {
opt.partitionName = partitionName
return opt
}
func NewDeleteOption(collectionName string) *deleteOption {
return &deleteOption{collectionName: collectionName}
}

330
client/write_test.go Normal file
View File

@ -0,0 +1,330 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
"context"
"fmt"
"math/rand"
"testing"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/client/v2/entity"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)
type WriteSuite struct {
MockSuiteBase
schema *entity.Schema
schemaDyn *entity.Schema
}
func (s *WriteSuite) SetupSuite() {
s.MockSuiteBase.SetupSuite()
s.schema = entity.NewSchema().
WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
s.schemaDyn = entity.NewSchema().WithDynamicFieldEnabled(true).
WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
}
func (s *WriteSuite) TestInsert() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collName := fmt.Sprintf("coll_%s", s.randString(6))
partName := fmt.Sprintf("part_%s", s.randString(6))
s.setupCache(collName, s.schema)
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ir *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) {
s.Equal(collName, ir.GetCollectionName())
s.Equal(partName, ir.GetPartitionName())
s.Require().Len(ir.GetFieldsData(), 2)
s.EqualValues(3, ir.GetNumRows())
return &milvuspb.MutationResult{
Status: merr.Success(),
}, nil
}).Once()
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})).
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
s.NoError(err)
})
s.Run("dynamic_schema", func() {
collName := fmt.Sprintf("coll_%s", s.randString(6))
partName := fmt.Sprintf("part_%s", s.randString(6))
s.setupCache(collName, s.schemaDyn)
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ir *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) {
s.Equal(collName, ir.GetCollectionName())
s.Equal(partName, ir.GetPartitionName())
s.Require().Len(ir.GetFieldsData(), 3)
s.EqualValues(3, ir.GetNumRows())
return &milvuspb.MutationResult{
Status: merr.Success(),
}, nil
}).Once()
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})).
WithVarcharColumn("extra", []string{"a", "b", "c"}).
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
s.NoError(err)
})
s.Run("bad_input", func() {
collName := fmt.Sprintf("coll_%s", s.randString(6))
s.setupCache(collName, s.schema)
type badCase struct {
tag string
input InsertOption
}
cases := []badCase{
{
tag: "missing_column",
input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}),
},
{
tag: "row_count_not_match",
input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})),
},
{
tag: "duplicated_columns",
input: NewColumnBasedInsertOption(collName).
WithInt64Column("id", []int64{1}).
WithInt64Column("id", []int64{2}).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})),
},
{
tag: "different_data_type",
input: NewColumnBasedInsertOption(collName).
WithVarcharColumn("id", []string{"1"}).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})),
},
}
for _, tc := range cases {
s.Run(tc.tag, func() {
err := s.client.Insert(ctx, tc.input)
s.Error(err)
})
}
})
s.Run("failure", func() {
collName := fmt.Sprintf("coll_%s", s.randString(6))
s.setupCache(collName, s.schema)
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})).
WithInt64Column("id", []int64{1, 2, 3}))
s.Error(err)
})
}
func (s *WriteSuite) TestUpsert() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collName := fmt.Sprintf("coll_%s", s.randString(6))
partName := fmt.Sprintf("part_%s", s.randString(6))
s.setupCache(collName, s.schema)
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ur *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
s.Equal(collName, ur.GetCollectionName())
s.Equal(partName, ur.GetPartitionName())
s.Require().Len(ur.GetFieldsData(), 2)
s.EqualValues(3, ur.GetNumRows())
return &milvuspb.MutationResult{
Status: merr.Success(),
}, nil
}).Once()
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})).
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
s.NoError(err)
})
s.Run("dynamic_schema", func() {
collName := fmt.Sprintf("coll_%s", s.randString(6))
partName := fmt.Sprintf("part_%s", s.randString(6))
s.setupCache(collName, s.schemaDyn)
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ur *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
s.Equal(collName, ur.GetCollectionName())
s.Equal(partName, ur.GetPartitionName())
s.Require().Len(ur.GetFieldsData(), 3)
s.EqualValues(3, ur.GetNumRows())
return &milvuspb.MutationResult{
Status: merr.Success(),
}, nil
}).Once()
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})).
WithVarcharColumn("extra", []string{"a", "b", "c"}).
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
s.NoError(err)
})
s.Run("bad_input", func() {
collName := fmt.Sprintf("coll_%s", s.randString(6))
s.setupCache(collName, s.schema)
type badCase struct {
tag string
input UpsertOption
}
cases := []badCase{
{
tag: "missing_column",
input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}),
},
{
tag: "row_count_not_match",
input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})),
},
{
tag: "duplicated_columns",
input: NewColumnBasedInsertOption(collName).
WithInt64Column("id", []int64{1}).
WithInt64Column("id", []int64{2}).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})),
},
{
tag: "different_data_type",
input: NewColumnBasedInsertOption(collName).
WithVarcharColumn("id", []string{"1"}).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})),
},
}
for _, tc := range cases {
s.Run(tc.tag, func() {
err := s.client.Upsert(ctx, tc.input)
s.Error(err)
})
}
})
s.Run("failure", func() {
collName := fmt.Sprintf("coll_%s", s.randString(6))
s.setupCache(collName, s.schema)
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
})).
WithInt64Column("id", []int64{1, 2, 3}))
s.Error(err)
})
}
func (s *WriteSuite) TestDelete() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("success", func() {
collName := fmt.Sprintf("coll_%s", s.randString(6))
partName := fmt.Sprintf("part_%s", s.randString(6))
type testCase struct {
tag string
input DeleteOption
expectExpr string
}
cases := []testCase{
{
tag: "raw_expr",
input: NewDeleteOption(collName).WithPartition(partName).WithExpr("id > 100"),
expectExpr: "id > 100",
},
{
tag: "int_ids",
input: NewDeleteOption(collName).WithPartition(partName).WithInt64IDs("id", []int64{1, 2, 3}),
expectExpr: "id in [1,2,3]",
},
{
tag: "str_ids",
input: NewDeleteOption(collName).WithPartition(partName).WithStringIDs("id", []string{"a", "b", "c"}),
expectExpr: `id in ["a","b","c"]`,
},
}
for _, tc := range cases {
s.Run(tc.tag, func() {
s.mock.EXPECT().Delete(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dr *milvuspb.DeleteRequest) (*milvuspb.MutationResult, error) {
s.Equal(collName, dr.GetCollectionName())
s.Equal(partName, dr.GetPartitionName())
s.Equal(tc.expectExpr, dr.GetExpr())
return &milvuspb.MutationResult{
Status: merr.Success(),
}, nil
}).Once()
err := s.client.Delete(ctx, tc.input)
s.NoError(err)
})
}
})
}
func TestWrite(t *testing.T) {
suite.Run(t, new(WriteSuite))
}

View File

@ -58,6 +58,16 @@ for d in $(go list ./... | grep -v -e vendor -e kafka -e planparserv2/generated
fi
done
popd
# milvusclient
pushd client
for d in $(go list ./... | grep -v -e vendor -e kafka -e planparserv2/generated -e mocks); do
$TEST_CMD -race -tags dynamic -v -coverpkg=./... -coverprofile=profile.out -covermode=atomic "$d"
if [ -f profile.out ]; then
grep -v kafka profile.out | grep -v planparserv2/generated | grep -v mocks | sed '1d' >> ../${FILE_COVERAGE_INFO}
rm profile.out
fi
done
popd
endTime=`date +%s`
echo "Total time for go unittest:" $(($endTime-$beginTime)) "s"