mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
feat: Add milvusclient package and migrate GoSDK (#32907)
Related to #31293 Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
855192eb3d
commit
244d2c04f6
2
.github/workflows/main.yaml
vendored
2
.github/workflows/main.yaml
vendored
@ -7,6 +7,7 @@ on:
|
||||
paths:
|
||||
- 'scripts/**'
|
||||
- 'internal/**'
|
||||
- 'client/**'
|
||||
- 'pkg/**'
|
||||
- 'cmd/**'
|
||||
- 'build/**'
|
||||
@ -24,6 +25,7 @@ on:
|
||||
- 'scripts/**'
|
||||
- 'internal/**'
|
||||
- 'pkg/**'
|
||||
- 'client/**'
|
||||
- 'cmd/**'
|
||||
- 'build/**'
|
||||
- 'tests/integration/**' # run integration test
|
||||
|
||||
33
client/Makefile
Normal file
33
client/Makefile
Normal file
@ -0,0 +1,33 @@
|
||||
# Licensed to the LF AI & Data foundation under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
GO ?= go
|
||||
PWD := $(shell pwd)
|
||||
GOPATH := $(shell $(GO) env GOPATH)
|
||||
SHELL := /bin/bash
|
||||
OBJPREFIX := "github.com/milvus-io/milvus/cmd/milvus/v2"
|
||||
|
||||
# TODO pass golangci-lint path
|
||||
lint:
|
||||
@echo "Running lint checks..."
|
||||
|
||||
unittest:
|
||||
@echo "Running unittests..."
|
||||
@(env bash $(PWD)/scripts/run_unittest.sh)
|
||||
|
||||
generate-mockery:
|
||||
@echo "Generating mockery Milvus service server"
|
||||
@../bin/mockery --srcpkg=github.com/milvus-io/milvus-proto/go-api/v2/milvuspb --name=MilvusServiceServer --filename=mock_milvus_server_test.go --output=. --outpkg=client --with-expecter
|
||||
6
client/OWNERS
Normal file
6
client/OWNERS
Normal file
@ -0,0 +1,6 @@
|
||||
reviewers:
|
||||
- congqixia
|
||||
|
||||
approvers:
|
||||
- maintainers
|
||||
|
||||
157
client/client.go
Normal file
157
client/client.go
Normal file
@ -0,0 +1,157 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gogo/status"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/client/v2/common"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
conn *grpc.ClientConn
|
||||
service milvuspb.MilvusServiceClient
|
||||
config *ClientConfig
|
||||
|
||||
collCache *CollectionCache
|
||||
}
|
||||
|
||||
func New(ctx context.Context, config *ClientConfig) (*Client, error) {
|
||||
if err := config.parse(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
config: config,
|
||||
}
|
||||
|
||||
// Parse remote address.
|
||||
addr := c.config.getParsedAddress()
|
||||
|
||||
// Parse grpc options
|
||||
options := c.config.getDialOption()
|
||||
|
||||
// Connect the grpc server.
|
||||
if err := c.connect(ctx, addr, options...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.collCache = NewCollectionCache(func(ctx context.Context, collName string) (*entity.Collection, error) {
|
||||
return c.DescribeCollection(ctx, NewDescribeCollectionOption(collName))
|
||||
})
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close(ctx context.Context) error {
|
||||
if c.conn == nil {
|
||||
return nil
|
||||
}
|
||||
err := c.conn.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.conn = nil
|
||||
c.service = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context, addr string, options ...grpc.DialOption) error {
|
||||
if addr == "" {
|
||||
return fmt.Errorf("address is empty")
|
||||
}
|
||||
conn, err := grpc.DialContext(ctx, addr, options...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
c.service = milvuspb.NewMilvusServiceClient(c.conn)
|
||||
|
||||
if !c.config.DisableConn {
|
||||
err = c.connectInternal(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) connectInternal(ctx context.Context) error {
|
||||
hostName, err := os.Hostname()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &milvuspb.ConnectRequest{
|
||||
ClientInfo: &commonpb.ClientInfo{
|
||||
SdkType: "Golang",
|
||||
SdkVersion: common.SDKVersion,
|
||||
LocalTime: time.Now().String(),
|
||||
User: c.config.Username,
|
||||
Host: hostName,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := c.service.Connect(ctx, req)
|
||||
if err != nil {
|
||||
status, ok := status.FromError(err)
|
||||
if ok {
|
||||
if status.Code() == codes.Unimplemented {
|
||||
// disable unsupported feature
|
||||
c.config.addFlags(
|
||||
disableDatabase |
|
||||
disableJSON |
|
||||
disableParitionKey |
|
||||
disableDynamicSchema)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if !merr.Ok(resp.GetStatus()) {
|
||||
return merr.Error(resp.GetStatus())
|
||||
}
|
||||
|
||||
c.config.setServerInfo(resp.GetServerInfo().GetBuildTags())
|
||||
c.config.setIdentifier(strconv.FormatInt(resp.GetIdentifier(), 10))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) callService(fn func(milvusService milvuspb.MilvusServiceClient) error) error {
|
||||
service := c.service
|
||||
if service == nil {
|
||||
return merr.WrapErrServiceNotReady("SDK", 0, "not connected")
|
||||
}
|
||||
|
||||
return fn(c.service)
|
||||
}
|
||||
182
client/client_config.go
Normal file
182
client/client_config.go
Normal file
@ -0,0 +1,182 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/backoff"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
)
|
||||
|
||||
const (
|
||||
disableDatabase uint64 = 1 << iota
|
||||
disableJSON
|
||||
disableDynamicSchema
|
||||
disableParitionKey
|
||||
)
|
||||
|
||||
var regexValidScheme = regexp.MustCompile(`^https?:\/\/`)
|
||||
|
||||
// DefaultGrpcOpts is GRPC options for milvus client.
|
||||
var DefaultGrpcOpts = []grpc.DialOption{
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 5 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
PermitWithoutStream: true,
|
||||
}),
|
||||
grpc.WithConnectParams(grpc.ConnectParams{
|
||||
Backoff: backoff.Config{
|
||||
BaseDelay: 100 * time.Millisecond,
|
||||
Multiplier: 1.6,
|
||||
Jitter: 0.2,
|
||||
MaxDelay: 3 * time.Second,
|
||||
},
|
||||
MinConnectTimeout: 3 * time.Second,
|
||||
}),
|
||||
}
|
||||
|
||||
// ClientConfig for milvus client.
|
||||
type ClientConfig struct {
|
||||
Address string // Remote address, "localhost:19530".
|
||||
Username string // Username for auth.
|
||||
Password string // Password for auth.
|
||||
DBName string // DBName for this client.
|
||||
|
||||
EnableTLSAuth bool // Enable TLS Auth for transport security.
|
||||
APIKey string // API key
|
||||
|
||||
DialOptions []grpc.DialOption // Dial options for GRPC.
|
||||
|
||||
// RetryRateLimit *RetryRateLimitOption // option for retry on rate limit inteceptor
|
||||
|
||||
DisableConn bool
|
||||
|
||||
identifier string // Identifier for this connection
|
||||
ServerVersion string // ServerVersion
|
||||
parsedAddress *url.URL
|
||||
flags uint64 // internal flags
|
||||
}
|
||||
|
||||
func (cfg *ClientConfig) parse() error {
|
||||
// Prepend default fake tcp:// scheme for remote address.
|
||||
address := cfg.Address
|
||||
if !regexValidScheme.MatchString(address) {
|
||||
address = fmt.Sprintf("tcp://%s", address)
|
||||
}
|
||||
|
||||
remoteURL, err := url.Parse(address)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "milvus address parse fail")
|
||||
}
|
||||
// Remote Host should never be empty.
|
||||
if remoteURL.Host == "" {
|
||||
return errors.New("empty remote host of milvus address")
|
||||
}
|
||||
// Use DBName in remote url path.
|
||||
if cfg.DBName == "" {
|
||||
cfg.DBName = strings.TrimLeft(remoteURL.Path, "/")
|
||||
}
|
||||
// Always enable tls auth for https remote url.
|
||||
if remoteURL.Scheme == "https" {
|
||||
cfg.EnableTLSAuth = true
|
||||
}
|
||||
if remoteURL.Port() == "" && cfg.EnableTLSAuth {
|
||||
remoteURL.Host += ":443"
|
||||
}
|
||||
cfg.parsedAddress = remoteURL
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get parsed remote milvus address, should be called after parse was called.
|
||||
func (c *ClientConfig) getParsedAddress() string {
|
||||
return c.parsedAddress.Host
|
||||
}
|
||||
|
||||
// useDatabase change the inner db name.
|
||||
func (c *ClientConfig) useDatabase(dbName string) {
|
||||
c.DBName = dbName
|
||||
}
|
||||
|
||||
// useDatabase change the inner db name.
|
||||
func (c *ClientConfig) setIdentifier(identifier string) {
|
||||
c.identifier = identifier
|
||||
}
|
||||
|
||||
func (c *ClientConfig) setServerInfo(serverInfo string) {
|
||||
c.ServerVersion = serverInfo
|
||||
}
|
||||
|
||||
// Get parsed grpc dial options, should be called after parse was called.
|
||||
func (c *ClientConfig) getDialOption() []grpc.DialOption {
|
||||
options := c.DialOptions
|
||||
if c.DialOptions == nil {
|
||||
// Add default connection options.
|
||||
options = make([]grpc.DialOption, len(DefaultGrpcOpts))
|
||||
copy(options, DefaultGrpcOpts)
|
||||
}
|
||||
|
||||
// Construct dial option.
|
||||
if c.EnableTLSAuth {
|
||||
options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})))
|
||||
} else {
|
||||
options = append(options, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
}
|
||||
|
||||
options = append(options,
|
||||
grpc.WithChainUnaryInterceptor(grpc_retry.UnaryClientInterceptor(
|
||||
grpc_retry.WithMax(6),
|
||||
grpc_retry.WithBackoff(func(attempt uint) time.Duration {
|
||||
return 60 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
|
||||
}),
|
||||
grpc_retry.WithCodes(codes.Unavailable, codes.ResourceExhausted)),
|
||||
// c.getRetryOnRateLimitInterceptor(),
|
||||
))
|
||||
|
||||
// options = append(options, grpc.WithChainUnaryInterceptor(
|
||||
// createMetaDataUnaryInterceptor(c),
|
||||
// ))
|
||||
return options
|
||||
}
|
||||
|
||||
// func (c *ClientConfig) getRetryOnRateLimitInterceptor() grpc.UnaryClientInterceptor {
|
||||
// if c.RetryRateLimit == nil {
|
||||
// c.RetryRateLimit = c.defaultRetryRateLimitOption()
|
||||
// }
|
||||
|
||||
// return RetryOnRateLimitInterceptor(c.RetryRateLimit.MaxRetry, c.RetryRateLimit.MaxBackoff, func(ctx context.Context, attempt uint) time.Duration {
|
||||
// return 10 * time.Millisecond * time.Duration(math.Pow(3, float64(attempt)))
|
||||
// })
|
||||
// }
|
||||
|
||||
// func (c *ClientConfig) defaultRetryRateLimitOption() *RetryRateLimitOption {
|
||||
// return &RetryRateLimitOption{
|
||||
// MaxRetry: 75,
|
||||
// MaxBackoff: 3 * time.Second,
|
||||
// }
|
||||
// }
|
||||
|
||||
// addFlags set internal flags
|
||||
func (c *ClientConfig) addFlags(flags uint64) {
|
||||
c.flags |= flags
|
||||
}
|
||||
|
||||
// hasFlags check flags is set
|
||||
func (c *ClientConfig) hasFlags(flags uint64) bool {
|
||||
return (c.flags & flags) > 0
|
||||
}
|
||||
|
||||
func (c *ClientConfig) resetFlags(flags uint64) {
|
||||
c.flags &= ^flags
|
||||
}
|
||||
251
client/client_suite_test.go
Normal file
251
client/client_suite_test.go
Normal file
@ -0,0 +1,251 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
const (
|
||||
bufSize = 1024 * 1024
|
||||
)
|
||||
|
||||
type MockSuiteBase struct {
|
||||
suite.Suite
|
||||
|
||||
lis *bufconn.Listener
|
||||
svr *grpc.Server
|
||||
mock *MilvusServiceServer
|
||||
|
||||
client *Client
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) SetupSuite() {
|
||||
s.lis = bufconn.Listen(bufSize)
|
||||
s.svr = grpc.NewServer()
|
||||
|
||||
s.mock = &MilvusServiceServer{}
|
||||
|
||||
milvuspb.RegisterMilvusServiceServer(s.svr, s.mock)
|
||||
|
||||
go func() {
|
||||
s.T().Log("start mock server")
|
||||
if err := s.svr.Serve(s.lis); err != nil {
|
||||
s.Fail("failed to start mock server", err.Error())
|
||||
}
|
||||
}()
|
||||
s.setupConnect()
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) TearDownSuite() {
|
||||
s.svr.Stop()
|
||||
s.lis.Close()
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) mockDialer(context.Context, string) (net.Conn, error) {
|
||||
return s.lis.Dial()
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) SetupTest() {
|
||||
c, err := New(context.Background(), &ClientConfig{
|
||||
Address: "bufnet",
|
||||
DialOptions: []grpc.DialOption{
|
||||
grpc.WithBlock(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(s.mockDialer),
|
||||
},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.setupConnect()
|
||||
|
||||
s.client = c
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) TearDownTest() {
|
||||
s.client.Close(context.Background())
|
||||
s.client = nil
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) resetMock() {
|
||||
// MetaCache.reset()
|
||||
if s.mock != nil {
|
||||
s.mock.Calls = nil
|
||||
s.mock.ExpectedCalls = nil
|
||||
s.setupConnect()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) setupConnect() {
|
||||
s.mock.EXPECT().Connect(mock.Anything, mock.AnythingOfType("*milvuspb.ConnectRequest")).
|
||||
Return(&milvuspb.ConnectResponse{
|
||||
Status: &commonpb.Status{},
|
||||
Identifier: 1,
|
||||
}, nil).Maybe()
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) setupCache(collName string, schema *entity.Schema) {
|
||||
s.client.collCache.collections.Insert(collName, &entity.Collection{
|
||||
Name: collName,
|
||||
Schema: schema,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) setupHasCollection(collNames ...string) {
|
||||
s.mock.EXPECT().HasCollection(mock.Anything, mock.AnythingOfType("*milvuspb.HasCollectionRequest")).
|
||||
Call.Return(func(ctx context.Context, req *milvuspb.HasCollectionRequest) *milvuspb.BoolResponse {
|
||||
resp := &milvuspb.BoolResponse{Status: &commonpb.Status{}}
|
||||
for _, collName := range collNames {
|
||||
if req.GetCollectionName() == collName {
|
||||
resp.Value = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) setupHasCollectionError(errorCode commonpb.ErrorCode, err error) {
|
||||
s.mock.EXPECT().HasCollection(mock.Anything, mock.AnythingOfType("*milvuspb.HasCollectionRequest")).
|
||||
Return(&milvuspb.BoolResponse{
|
||||
Status: &commonpb.Status{ErrorCode: errorCode},
|
||||
}, err)
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) setupHasPartition(collName string, partNames ...string) {
|
||||
s.mock.EXPECT().HasPartition(mock.Anything, mock.AnythingOfType("*milvuspb.HasPartitionRequest")).
|
||||
Call.Return(func(ctx context.Context, req *milvuspb.HasPartitionRequest) *milvuspb.BoolResponse {
|
||||
resp := &milvuspb.BoolResponse{Status: &commonpb.Status{}}
|
||||
if req.GetCollectionName() == collName {
|
||||
for _, partName := range partNames {
|
||||
if req.GetPartitionName() == partName {
|
||||
resp.Value = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) setupHasPartitionError(errorCode commonpb.ErrorCode, err error) {
|
||||
s.mock.EXPECT().HasPartition(mock.Anything, mock.AnythingOfType("*milvuspb.HasPartitionRequest")).
|
||||
Return(&milvuspb.BoolResponse{
|
||||
Status: &commonpb.Status{ErrorCode: errorCode},
|
||||
}, err)
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) setupDescribeCollection(_ string, schema *entity.Schema) {
|
||||
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.AnythingOfType("*milvuspb.DescribeCollectionRequest")).
|
||||
Call.Return(func(ctx context.Context, req *milvuspb.DescribeCollectionRequest) *milvuspb.DescribeCollectionResponse {
|
||||
return &milvuspb.DescribeCollectionResponse{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
Schema: schema.ProtoMessage(),
|
||||
}
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) setupDescribeCollectionError(errorCode commonpb.ErrorCode, err error) {
|
||||
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.AnythingOfType("*milvuspb.DescribeCollectionRequest")).
|
||||
Return(&milvuspb.DescribeCollectionResponse{
|
||||
Status: &commonpb.Status{ErrorCode: errorCode},
|
||||
}, err)
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) getInt64FieldData(name string, data []int64) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: name,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) getVarcharFieldData(name string, data []string) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldName: name,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) getJSONBytesFieldData(name string, data [][]byte, isDynamic bool) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_JSON,
|
||||
FieldName: name,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_JsonData{
|
||||
JsonData: &schemapb.JSONArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: isDynamic,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) getFloatVectorFieldData(name string, dim int64, data []float32) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: name,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: dim,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) getSuccessStatus() *commonpb.Status {
|
||||
return s.getStatus(commonpb.ErrorCode_Success, "")
|
||||
}
|
||||
|
||||
func (s *MockSuiteBase) getStatus(code commonpb.ErrorCode, reason string) *commonpb.Status {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: code,
|
||||
Reason: reason,
|
||||
}
|
||||
}
|
||||
|
||||
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
|
||||
func (s *MockSuiteBase) randString(l int) string {
|
||||
builder := strings.Builder{}
|
||||
for i := 0; i < l; i++ {
|
||||
builder.WriteRune(letters[rand.Intn(len(letters))])
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
43
client/client_test.go
Normal file
43
client/client_test.go
Normal file
@ -0,0 +1,43 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
type ClientSuite struct {
|
||||
MockSuiteBase
|
||||
}
|
||||
|
||||
func (s *ClientSuite) TestNewClient() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("Use bufconn dailer, testing case", func() {
|
||||
c, err := New(ctx,
|
||||
&ClientConfig{
|
||||
Address: "bufnet",
|
||||
DialOptions: []grpc.DialOption{
|
||||
grpc.WithBlock(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithContextDialer(s.mockDialer),
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
s.NotNil(c)
|
||||
})
|
||||
|
||||
s.Run("emtpy_addr", func() {
|
||||
_, err := New(ctx, &ClientConfig{})
|
||||
s.Error(err)
|
||||
s.T().Log(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
suite.Run(t, new(ClientSuite))
|
||||
}
|
||||
134
client/collection.go
Normal file
134
client/collection.go
Normal file
@ -0,0 +1,134 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
// CreateCollection is the API for create a collection in Milvus.
|
||||
func (c *Client) CreateCollection(ctx context.Context, option CreateCollectionOption, callOptions ...grpc.CallOption) error {
|
||||
req := option.Request()
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.CreateCollection(ctx, req, callOptions...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
indexes := option.Indexes()
|
||||
for _, indexOption := range indexes {
|
||||
task, err := c.CreateIndex(ctx, indexOption, callOptions...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = task.Await(ctx)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if option.IsFast() {
|
||||
task, err := c.LoadCollection(ctx, NewLoadCollectionOption(req.GetCollectionName()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return task.Await(ctx)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type ListCollectionOption interface {
|
||||
Request() *milvuspb.ShowCollectionsRequest
|
||||
}
|
||||
|
||||
func (c *Client) ListCollections(ctx context.Context, option ListCollectionOption, callOptions ...grpc.CallOption) (collectionNames []string, err error) {
|
||||
req := option.Request()
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.ShowCollections(ctx, req, callOptions...)
|
||||
err = merr.CheckRPCCall(resp, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
collectionNames = resp.GetCollectionNames()
|
||||
return nil
|
||||
})
|
||||
|
||||
return collectionNames, err
|
||||
}
|
||||
|
||||
func (c *Client) DescribeCollection(ctx context.Context, option *describeCollectionOption, callOptions ...grpc.CallOption) (collection *entity.Collection, err error) {
|
||||
req := option.Request()
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.DescribeCollection(ctx, req, callOptions...)
|
||||
err = merr.CheckRPCCall(resp, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
collection = &entity.Collection{
|
||||
ID: resp.GetCollectionID(),
|
||||
Schema: entity.NewSchema().ReadProto(resp.GetSchema()),
|
||||
PhysicalChannels: resp.GetPhysicalChannelNames(),
|
||||
VirtualChannels: resp.GetVirtualChannelNames(),
|
||||
ConsistencyLevel: entity.ConsistencyLevel(resp.ConsistencyLevel),
|
||||
ShardNum: resp.GetShardsNum(),
|
||||
}
|
||||
collection.Name = collection.Schema.CollectionName
|
||||
return nil
|
||||
})
|
||||
|
||||
return collection, err
|
||||
}
|
||||
|
||||
func (c *Client) HasCollection(ctx context.Context, option HasCollectionOption, callOptions ...grpc.CallOption) (has bool, err error) {
|
||||
req := option.Request()
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.DescribeCollection(ctx, req, callOptions...)
|
||||
err = merr.CheckRPCCall(resp, err)
|
||||
if err != nil {
|
||||
// ErrCollectionNotFound for collection not exist
|
||||
if errors.Is(err, merr.ErrCollectionNotFound) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
has = true
|
||||
return nil
|
||||
})
|
||||
return has, err
|
||||
}
|
||||
|
||||
func (c *Client) DropCollection(ctx context.Context, option DropCollectionOption, callOptions ...grpc.CallOption) error {
|
||||
req := option.Request()
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.DropCollection(ctx, req, callOptions...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
return err
|
||||
}
|
||||
232
client/collection_options.go
Normal file
232
client/collection_options.go
Normal file
@ -0,0 +1,232 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/client/v2/index"
|
||||
)
|
||||
|
||||
// CreateCollectionOption is the interface builds CreateCollectionRequest.
|
||||
type CreateCollectionOption interface {
|
||||
// Request is the method returns the composed request.
|
||||
Request() *milvuspb.CreateCollectionRequest
|
||||
// Indexes is the method returns IndexOption to create
|
||||
Indexes() []CreateIndexOption
|
||||
IsFast() bool
|
||||
}
|
||||
|
||||
// createCollectionOption contains all the parameters to create collection.
|
||||
type createCollectionOption struct {
|
||||
name string
|
||||
shardNum int32
|
||||
|
||||
// fast create collection params
|
||||
varcharPK bool
|
||||
varcharPKMaxLength int
|
||||
pkFieldName string
|
||||
vectorFieldName string
|
||||
dim int64
|
||||
autoID bool
|
||||
enabledDynamicSchema bool
|
||||
|
||||
// advanced create collection params
|
||||
schema *entity.Schema
|
||||
consistencyLevel entity.ConsistencyLevel
|
||||
properties map[string]string
|
||||
|
||||
// partition key
|
||||
numPartitions int64
|
||||
|
||||
// is fast create collection
|
||||
isFast bool
|
||||
// fast creation with index
|
||||
metricType entity.MetricType
|
||||
}
|
||||
|
||||
func (opt *createCollectionOption) WithAutoID(autoID bool) *createCollectionOption {
|
||||
opt.autoID = autoID
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *createCollectionOption) WithShardNum(shardNum int32) *createCollectionOption {
|
||||
opt.shardNum = shardNum
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *createCollectionOption) WithDynamicSchema(dynamicSchema bool) *createCollectionOption {
|
||||
opt.enabledDynamicSchema = dynamicSchema
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *createCollectionOption) WithVarcharPK(varcharPK bool, maxLen int) *createCollectionOption {
|
||||
opt.varcharPK = varcharPK
|
||||
opt.varcharPKMaxLength = maxLen
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *createCollectionOption) Request() *milvuspb.CreateCollectionRequest {
|
||||
// fast create collection
|
||||
if opt.isFast || opt.schema == nil {
|
||||
var pkField *entity.Field
|
||||
if opt.varcharPK {
|
||||
pkField = entity.NewField().WithDataType(entity.FieldTypeVarChar).WithMaxLength(int64(opt.varcharPKMaxLength))
|
||||
} else {
|
||||
pkField = entity.NewField().WithDataType(entity.FieldTypeInt64)
|
||||
}
|
||||
pkField = pkField.WithName(opt.pkFieldName).WithIsPrimaryKey(true).WithIsAutoID(opt.autoID)
|
||||
opt.schema = entity.NewSchema().
|
||||
WithName(opt.name).
|
||||
WithAutoID(opt.autoID).
|
||||
WithDynamicFieldEnabled(opt.enabledDynamicSchema).
|
||||
WithField(pkField).
|
||||
WithField(entity.NewField().WithName(opt.vectorFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(opt.dim))
|
||||
}
|
||||
|
||||
schemaProto := opt.schema.ProtoMessage()
|
||||
schemaBytes, _ := proto.Marshal(schemaProto)
|
||||
|
||||
return &milvuspb.CreateCollectionRequest{
|
||||
DbName: "", // reserved fields, not used for now
|
||||
CollectionName: opt.name,
|
||||
Schema: schemaBytes,
|
||||
ShardsNum: opt.shardNum,
|
||||
ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel),
|
||||
NumPartitions: opt.numPartitions,
|
||||
Properties: entity.MapKvPairs(opt.properties),
|
||||
}
|
||||
}
|
||||
|
||||
func (opt *createCollectionOption) Indexes() []CreateIndexOption {
|
||||
// fast create
|
||||
if opt.isFast {
|
||||
return []CreateIndexOption{
|
||||
NewCreateIndexOption(opt.name, opt.vectorFieldName, index.NewGenericIndex("", map[string]string{})),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (opt *createCollectionOption) IsFast() bool {
|
||||
return opt.isFast
|
||||
}
|
||||
|
||||
// SimpleCreateCollectionOptions returns a CreateCollectionOption with default fast collection options.
|
||||
func SimpleCreateCollectionOptions(name string, dim int64) *createCollectionOption {
|
||||
return &createCollectionOption{
|
||||
name: name,
|
||||
shardNum: 1,
|
||||
|
||||
pkFieldName: "id",
|
||||
vectorFieldName: "vector",
|
||||
autoID: true,
|
||||
dim: dim,
|
||||
enabledDynamicSchema: true,
|
||||
|
||||
isFast: true,
|
||||
metricType: entity.COSINE,
|
||||
}
|
||||
}
|
||||
|
||||
// NewCreateCollectionOption returns a CreateCollectionOption with customized collection schema
|
||||
func NewCreateCollectionOption(name string, collectionSchema *entity.Schema) *createCollectionOption {
|
||||
return &createCollectionOption{
|
||||
name: name,
|
||||
shardNum: 1,
|
||||
schema: collectionSchema,
|
||||
|
||||
metricType: entity.COSINE,
|
||||
}
|
||||
}
|
||||
|
||||
type listCollectionOption struct{}
|
||||
|
||||
func (opt *listCollectionOption) Request() *milvuspb.ShowCollectionsRequest {
|
||||
return &milvuspb.ShowCollectionsRequest{}
|
||||
}
|
||||
|
||||
func NewListCollectionOption() *listCollectionOption {
|
||||
return &listCollectionOption{}
|
||||
}
|
||||
|
||||
// DescribeCollectionOption is the interface builds DescribeCollection request.
|
||||
type DescribeCollectionOption interface {
|
||||
// Request is the method returns the composed request.
|
||||
Request() *milvuspb.DescribeCollectionRequest
|
||||
}
|
||||
|
||||
type describeCollectionOption struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (opt *describeCollectionOption) Request() *milvuspb.DescribeCollectionRequest {
|
||||
return &milvuspb.DescribeCollectionRequest{
|
||||
CollectionName: opt.name,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDescribeCollectionOption composes a describeCollectionOption with provided collection name.
|
||||
func NewDescribeCollectionOption(name string) *describeCollectionOption {
|
||||
return &describeCollectionOption{
|
||||
name: name,
|
||||
}
|
||||
}
|
||||
|
||||
// HasCollectionOption is the interface to build DescribeCollectionRequest.
|
||||
type HasCollectionOption interface {
|
||||
Request() *milvuspb.DescribeCollectionRequest
|
||||
}
|
||||
|
||||
type hasCollectionOpt struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (opt *hasCollectionOpt) Request() *milvuspb.DescribeCollectionRequest {
|
||||
return &milvuspb.DescribeCollectionRequest{
|
||||
CollectionName: opt.name,
|
||||
}
|
||||
}
|
||||
|
||||
func NewHasCollectionOption(name string) HasCollectionOption {
|
||||
return &hasCollectionOpt{
|
||||
name: name,
|
||||
}
|
||||
}
|
||||
|
||||
// The DropCollectionOption interface builds DropCollectionRequest.
|
||||
type DropCollectionOption interface {
|
||||
Request() *milvuspb.DropCollectionRequest
|
||||
}
|
||||
|
||||
type dropCollectionOption struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (opt *dropCollectionOption) Request() *milvuspb.DropCollectionRequest {
|
||||
return &milvuspb.DropCollectionRequest{
|
||||
CollectionName: opt.name,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDropCollectionOption(name string) *dropCollectionOption {
|
||||
return &dropCollectionOption{
|
||||
name: name,
|
||||
}
|
||||
}
|
||||
253
client/collection_test.go
Normal file
253
client/collection_test.go
Normal file
@ -0,0 +1,253 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
type CollectionSuite struct {
|
||||
MockSuiteBase
|
||||
}
|
||||
|
||||
func (s *CollectionSuite) TestListCollection() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.Run("success", func() {
|
||||
s.mock.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{
|
||||
CollectionNames: []string{"test1", "test2", "test3"},
|
||||
}, nil).Once()
|
||||
|
||||
names, err := s.client.ListCollections(ctx, NewListCollectionOption())
|
||||
s.NoError(err)
|
||||
s.ElementsMatch([]string{"test1", "test2", "test3"}, names)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.ListCollections(ctx, NewListCollectionOption())
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *CollectionSuite) TestCreateCollection() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
s.mock.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once()
|
||||
s.mock.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once()
|
||||
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once()
|
||||
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(&milvuspb.DescribeIndexResponse{
|
||||
Status: merr.Success(),
|
||||
IndexDescriptions: []*milvuspb.IndexDescription{
|
||||
{FieldName: "vector", State: commonpb.IndexState_Finished},
|
||||
},
|
||||
}, nil).Once()
|
||||
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Return(&milvuspb.GetLoadingProgressResponse{
|
||||
Status: merr.Success(),
|
||||
Progress: 100,
|
||||
}, nil).Once()
|
||||
|
||||
err := s.client.CreateCollection(ctx, SimpleCreateCollectionOptions("test_collection", 128))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.CreateCollection(ctx, SimpleCreateCollectionOptions("test_collection", 128))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *CollectionSuite) TestCreateCollectionOptions() {
|
||||
collectionName := fmt.Sprintf("test_collection_%s", s.randString(6))
|
||||
opt := SimpleCreateCollectionOptions(collectionName, 128)
|
||||
req := opt.Request()
|
||||
s.Equal(collectionName, req.GetCollectionName())
|
||||
s.EqualValues(1, req.GetShardsNum())
|
||||
|
||||
collSchema := &schemapb.CollectionSchema{}
|
||||
err := proto.Unmarshal(req.GetSchema(), collSchema)
|
||||
s.Require().NoError(err)
|
||||
s.True(collSchema.GetEnableDynamicField())
|
||||
|
||||
collectionName = fmt.Sprintf("test_collection_%s", s.randString(6))
|
||||
opt = SimpleCreateCollectionOptions(collectionName, 128).WithVarcharPK(true, 64).WithAutoID(false).WithDynamicSchema(false)
|
||||
req = opt.Request()
|
||||
s.Equal(collectionName, req.GetCollectionName())
|
||||
s.EqualValues(1, req.GetShardsNum())
|
||||
|
||||
collSchema = &schemapb.CollectionSchema{}
|
||||
err = proto.Unmarshal(req.GetSchema(), collSchema)
|
||||
s.Require().NoError(err)
|
||||
s.False(collSchema.GetEnableDynamicField())
|
||||
|
||||
collectionName = fmt.Sprintf("test_collection_%s", s.randString(6))
|
||||
schema := entity.NewSchema().
|
||||
WithField(entity.NewField().WithName("int64").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
|
||||
WithField(entity.NewField().WithName("vector").WithDim(128).WithDataType(entity.FieldTypeFloatVector))
|
||||
|
||||
opt = NewCreateCollectionOption(collectionName, schema).WithShardNum(2)
|
||||
|
||||
req = opt.Request()
|
||||
s.Equal(collectionName, req.GetCollectionName())
|
||||
s.EqualValues(2, req.GetShardsNum())
|
||||
|
||||
collSchema = &schemapb.CollectionSchema{}
|
||||
err = proto.Unmarshal(req.GetSchema(), collSchema)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *CollectionSuite) TestDescribeCollection() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
Status: merr.Success(),
|
||||
Schema: &schemapb.CollectionSchema{
|
||||
Name: "test_collection",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, DataType: schemapb.DataType_Int64, AutoID: true, Name: "ID"},
|
||||
{
|
||||
FieldID: 101, DataType: schemapb.DataType_FloatVector, Name: "vector",
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "128"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
CollectionID: 1000,
|
||||
CollectionName: "test_collection",
|
||||
}, nil).Once()
|
||||
|
||||
coll, err := s.client.DescribeCollection(ctx, NewDescribeCollectionOption("test_collection"))
|
||||
s.NoError(err)
|
||||
|
||||
s.EqualValues(1000, coll.ID)
|
||||
s.Equal("test_collection", coll.Name)
|
||||
s.Len(coll.Schema.Fields, 2)
|
||||
idField, ok := lo.Find(coll.Schema.Fields, func(field *entity.Field) bool {
|
||||
return field.ID == 100
|
||||
})
|
||||
s.Require().True(ok)
|
||||
s.Equal("ID", idField.Name)
|
||||
s.Equal(entity.FieldTypeInt64, idField.DataType)
|
||||
s.True(idField.AutoID)
|
||||
|
||||
vectorField, ok := lo.Find(coll.Schema.Fields, func(field *entity.Field) bool {
|
||||
return field.ID == 101
|
||||
})
|
||||
s.Require().True(ok)
|
||||
s.Equal("vector", vectorField.Name)
|
||||
s.Equal(entity.FieldTypeFloatVector, vectorField.DataType)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.DescribeCollection(ctx, NewDescribeCollectionOption("test_collection"))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *CollectionSuite) TestHasCollection() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
Status: merr.Success(),
|
||||
Schema: &schemapb.CollectionSchema{
|
||||
Name: "test_collection",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, DataType: schemapb.DataType_Int64, AutoID: true, Name: "ID"},
|
||||
{
|
||||
FieldID: 101, DataType: schemapb.DataType_FloatVector, Name: "vector",
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "128"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
CollectionID: 1000,
|
||||
CollectionName: "test_collection",
|
||||
}, nil).Once()
|
||||
|
||||
has, err := s.client.HasCollection(ctx, NewHasCollectionOption("test_collection"))
|
||||
s.NoError(err)
|
||||
|
||||
s.True(has)
|
||||
})
|
||||
|
||||
s.Run("collection_not_exist", func() {
|
||||
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
|
||||
Status: merr.Status(merr.WrapErrCollectionNotFound("test_collection")),
|
||||
}, nil).Once()
|
||||
|
||||
has, err := s.client.HasCollection(ctx, NewHasCollectionOption("test_collection"))
|
||||
s.NoError(err)
|
||||
|
||||
s.False(has)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.HasCollection(ctx, NewHasCollectionOption("test_collection"))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *CollectionSuite) TestDropCollection() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
s.mock.EXPECT().DropCollection(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once()
|
||||
|
||||
err := s.client.DropCollection(ctx, NewDropCollectionOption("test_collection"))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().DropCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.DropCollection(ctx, NewDropCollectionOption("test_collection"))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCollection(t *testing.T) {
|
||||
suite.Run(t, new(CollectionSuite))
|
||||
}
|
||||
125
client/column/array.go
Normal file
125
client/column/array.go
Normal file
@ -0,0 +1,125 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
// ColumnVarCharArray generated columns type for VarChar
|
||||
type ColumnVarCharArray struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values [][][]byte
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnVarCharArray) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnVarCharArray) Type() entity.FieldType {
|
||||
return entity.FieldTypeArray
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnVarCharArray) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnVarCharArray) Get(idx int) (interface{}, error) {
|
||||
var r []string // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnVarCharArray) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Array,
|
||||
FieldName: c.name,
|
||||
}
|
||||
|
||||
data := make([]*schemapb.ScalarField, 0, c.Len())
|
||||
for _, arr := range c.values {
|
||||
converted := make([]string, 0, c.Len())
|
||||
for i := 0; i < len(arr); i++ {
|
||||
converted = append(converted, string(arr[i]))
|
||||
}
|
||||
data = append(data, &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: converted,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
Data: data,
|
||||
ElementType: schemapb.DataType_VarChar,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnVarCharArray) ValueByIdx(idx int) ([][]byte, error) {
|
||||
var r [][]byte // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnVarCharArray) AppendValue(i interface{}) error {
|
||||
v, ok := i.([][]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected []string, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnVarCharArray) Data() [][][]byte {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnVarChar auto generated constructor
|
||||
func NewColumnVarCharArray(name string, values [][][]byte) *ColumnVarCharArray {
|
||||
return &ColumnVarCharArray{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
705
client/column/array_gen.go
Normal file
705
client/column/array_gen.go
Normal file
@ -0,0 +1,705 @@
|
||||
// Code generated by go generate; DO NOT EDIT
|
||||
// This file is generated by go generate
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
// ColumnBoolArray generated columns type for Bool
|
||||
type ColumnBoolArray struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values [][]bool
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnBoolArray) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnBoolArray) Type() entity.FieldType {
|
||||
return entity.FieldTypeArray
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnBoolArray) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnBoolArray) Get(idx int) (interface{}, error) {
|
||||
var r []bool // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnBoolArray) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Array,
|
||||
FieldName: c.name,
|
||||
}
|
||||
|
||||
data := make([]*schemapb.ScalarField, 0, c.Len())
|
||||
for _, arr := range c.values {
|
||||
converted := make([]bool, 0, c.Len())
|
||||
for i := 0; i < len(arr); i++ {
|
||||
converted = append(converted, bool(arr[i]))
|
||||
}
|
||||
data = append(data, &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: converted,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
Data: data,
|
||||
ElementType: schemapb.DataType_Bool,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnBoolArray) ValueByIdx(idx int) ([]bool, error) {
|
||||
var r []bool // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnBoolArray) AppendValue(i interface{}) error {
|
||||
v, ok := i.([]bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected []bool, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnBoolArray) Data() [][]bool {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnBool auto generated constructor
|
||||
func NewColumnBoolArray(name string, values [][]bool) *ColumnBoolArray {
|
||||
return &ColumnBoolArray{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnInt8Array generated columns type for Int8
|
||||
type ColumnInt8Array struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values [][]int8
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnInt8Array) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnInt8Array) Type() entity.FieldType {
|
||||
return entity.FieldTypeArray
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnInt8Array) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnInt8Array) Get(idx int) (interface{}, error) {
|
||||
var r []int8 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnInt8Array) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Array,
|
||||
FieldName: c.name,
|
||||
}
|
||||
|
||||
data := make([]*schemapb.ScalarField, 0, c.Len())
|
||||
for _, arr := range c.values {
|
||||
converted := make([]int32, 0, c.Len())
|
||||
for i := 0; i < len(arr); i++ {
|
||||
converted = append(converted, int32(arr[i]))
|
||||
}
|
||||
data = append(data, &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: converted,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
Data: data,
|
||||
ElementType: schemapb.DataType_Int8,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnInt8Array) ValueByIdx(idx int) ([]int8, error) {
|
||||
var r []int8 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnInt8Array) AppendValue(i interface{}) error {
|
||||
v, ok := i.([]int8)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected []int8, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnInt8Array) Data() [][]int8 {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnInt8 auto generated constructor
|
||||
func NewColumnInt8Array(name string, values [][]int8) *ColumnInt8Array {
|
||||
return &ColumnInt8Array{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnInt16Array generated columns type for Int16
|
||||
type ColumnInt16Array struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values [][]int16
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnInt16Array) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnInt16Array) Type() entity.FieldType {
|
||||
return entity.FieldTypeArray
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnInt16Array) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnInt16Array) Get(idx int) (interface{}, error) {
|
||||
var r []int16 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnInt16Array) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Array,
|
||||
FieldName: c.name,
|
||||
}
|
||||
|
||||
data := make([]*schemapb.ScalarField, 0, c.Len())
|
||||
for _, arr := range c.values {
|
||||
converted := make([]int32, 0, c.Len())
|
||||
for i := 0; i < len(arr); i++ {
|
||||
converted = append(converted, int32(arr[i]))
|
||||
}
|
||||
data = append(data, &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: converted,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
Data: data,
|
||||
ElementType: schemapb.DataType_Int16,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnInt16Array) ValueByIdx(idx int) ([]int16, error) {
|
||||
var r []int16 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnInt16Array) AppendValue(i interface{}) error {
|
||||
v, ok := i.([]int16)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected []int16, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnInt16Array) Data() [][]int16 {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnInt16 auto generated constructor
|
||||
func NewColumnInt16Array(name string, values [][]int16) *ColumnInt16Array {
|
||||
return &ColumnInt16Array{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnInt32Array generated columns type for Int32
|
||||
type ColumnInt32Array struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values [][]int32
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnInt32Array) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnInt32Array) Type() entity.FieldType {
|
||||
return entity.FieldTypeArray
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnInt32Array) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnInt32Array) Get(idx int) (interface{}, error) {
|
||||
var r []int32 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnInt32Array) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Array,
|
||||
FieldName: c.name,
|
||||
}
|
||||
|
||||
data := make([]*schemapb.ScalarField, 0, c.Len())
|
||||
for _, arr := range c.values {
|
||||
converted := make([]int32, 0, c.Len())
|
||||
for i := 0; i < len(arr); i++ {
|
||||
converted = append(converted, int32(arr[i]))
|
||||
}
|
||||
data = append(data, &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: converted,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
Data: data,
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnInt32Array) ValueByIdx(idx int) ([]int32, error) {
|
||||
var r []int32 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnInt32Array) AppendValue(i interface{}) error {
|
||||
v, ok := i.([]int32)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected []int32, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnInt32Array) Data() [][]int32 {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnInt32 auto generated constructor
|
||||
func NewColumnInt32Array(name string, values [][]int32) *ColumnInt32Array {
|
||||
return &ColumnInt32Array{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnInt64Array generated columns type for Int64
|
||||
type ColumnInt64Array struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values [][]int64
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnInt64Array) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnInt64Array) Type() entity.FieldType {
|
||||
return entity.FieldTypeArray
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnInt64Array) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnInt64Array) Get(idx int) (interface{}, error) {
|
||||
var r []int64 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnInt64Array) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Array,
|
||||
FieldName: c.name,
|
||||
}
|
||||
|
||||
data := make([]*schemapb.ScalarField, 0, c.Len())
|
||||
for _, arr := range c.values {
|
||||
converted := make([]int64, 0, c.Len())
|
||||
for i := 0; i < len(arr); i++ {
|
||||
converted = append(converted, int64(arr[i]))
|
||||
}
|
||||
data = append(data, &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: converted,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
Data: data,
|
||||
ElementType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnInt64Array) ValueByIdx(idx int) ([]int64, error) {
|
||||
var r []int64 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnInt64Array) AppendValue(i interface{}) error {
|
||||
v, ok := i.([]int64)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected []int64, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnInt64Array) Data() [][]int64 {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnInt64 auto generated constructor
|
||||
func NewColumnInt64Array(name string, values [][]int64) *ColumnInt64Array {
|
||||
return &ColumnInt64Array{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnFloatArray generated columns type for Float
|
||||
type ColumnFloatArray struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values [][]float32
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnFloatArray) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnFloatArray) Type() entity.FieldType {
|
||||
return entity.FieldTypeArray
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnFloatArray) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnFloatArray) Get(idx int) (interface{}, error) {
|
||||
var r []float32 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnFloatArray) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Array,
|
||||
FieldName: c.name,
|
||||
}
|
||||
|
||||
data := make([]*schemapb.ScalarField, 0, c.Len())
|
||||
for _, arr := range c.values {
|
||||
converted := make([]float32, 0, c.Len())
|
||||
for i := 0; i < len(arr); i++ {
|
||||
converted = append(converted, float32(arr[i]))
|
||||
}
|
||||
data = append(data, &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{
|
||||
Data: converted,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
Data: data,
|
||||
ElementType: schemapb.DataType_Float,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnFloatArray) ValueByIdx(idx int) ([]float32, error) {
|
||||
var r []float32 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnFloatArray) AppendValue(i interface{}) error {
|
||||
v, ok := i.([]float32)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected []float32, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnFloatArray) Data() [][]float32 {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnFloat auto generated constructor
|
||||
func NewColumnFloatArray(name string, values [][]float32) *ColumnFloatArray {
|
||||
return &ColumnFloatArray{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnDoubleArray generated columns type for Double
|
||||
type ColumnDoubleArray struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values [][]float64
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnDoubleArray) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnDoubleArray) Type() entity.FieldType {
|
||||
return entity.FieldTypeArray
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnDoubleArray) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnDoubleArray) Get(idx int) (interface{}, error) {
|
||||
var r []float64 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnDoubleArray) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Array,
|
||||
FieldName: c.name,
|
||||
}
|
||||
|
||||
data := make([]*schemapb.ScalarField, 0, c.Len())
|
||||
for _, arr := range c.values {
|
||||
converted := make([]float64, 0, c.Len())
|
||||
for i := 0; i < len(arr); i++ {
|
||||
converted = append(converted, float64(arr[i]))
|
||||
}
|
||||
data = append(data, &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_DoubleData{
|
||||
DoubleData: &schemapb.DoubleArray{
|
||||
Data: converted,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
Data: data,
|
||||
ElementType: schemapb.DataType_Double,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnDoubleArray) ValueByIdx(idx int) ([]float64, error) {
|
||||
var r []float64 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnDoubleArray) AppendValue(i interface{}) error {
|
||||
v, ok := i.([]float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected []float64, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnDoubleArray) Data() [][]float64 {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnDouble auto generated constructor
|
||||
func NewColumnDoubleArray(name string, values [][]float64) *ColumnDoubleArray {
|
||||
return &ColumnDoubleArray{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
502
client/column/columns.go
Normal file
502
client/column/columns.go
Normal file
@ -0,0 +1,502 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
//go:generate go run gen/gen.go
|
||||
|
||||
// Column interface field type for column-based data frame
|
||||
type Column interface {
|
||||
Name() string
|
||||
Type() entity.FieldType
|
||||
Len() int
|
||||
FieldData() *schemapb.FieldData
|
||||
AppendValue(interface{}) error
|
||||
Get(int) (interface{}, error)
|
||||
GetAsInt64(int) (int64, error)
|
||||
GetAsString(int) (string, error)
|
||||
GetAsDouble(int) (float64, error)
|
||||
GetAsBool(int) (bool, error)
|
||||
}
|
||||
|
||||
// ColumnBase adds conversion methods support for fixed-type columns.
|
||||
type ColumnBase struct{}
|
||||
|
||||
func (b ColumnBase) GetAsInt64(_ int) (int64, error) {
|
||||
return 0, errors.New("conversion between fixed-type column not support")
|
||||
}
|
||||
|
||||
func (b ColumnBase) GetAsString(_ int) (string, error) {
|
||||
return "", errors.New("conversion between fixed-type column not support")
|
||||
}
|
||||
|
||||
func (b ColumnBase) GetAsDouble(_ int) (float64, error) {
|
||||
return 0, errors.New("conversion between fixed-type column not support")
|
||||
}
|
||||
|
||||
func (b ColumnBase) GetAsBool(_ int) (bool, error) {
|
||||
return false, errors.New("conversion between fixed-type column not support")
|
||||
}
|
||||
|
||||
var errFieldDataTypeNotMatch = errors.New("FieldData type not matched")
|
||||
|
||||
// IDColumns converts schemapb.IDs to corresponding column
|
||||
// currently Int64 / string may be in IDs
|
||||
func IDColumns(idField *schemapb.IDs, begin, end int) (Column, error) {
|
||||
var idColumn Column
|
||||
if idField == nil {
|
||||
return nil, errors.New("nil Ids from response")
|
||||
}
|
||||
switch field := idField.GetIdField().(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
if end >= 0 {
|
||||
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:end])
|
||||
} else {
|
||||
idColumn = NewColumnInt64("", field.IntId.GetData()[begin:])
|
||||
}
|
||||
case *schemapb.IDs_StrId:
|
||||
if end >= 0 {
|
||||
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:end])
|
||||
} else {
|
||||
idColumn = NewColumnVarChar("", field.StrId.GetData()[begin:])
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported id type %v", field)
|
||||
}
|
||||
return idColumn, nil
|
||||
}
|
||||
|
||||
// FieldDataColumn converts schemapb.FieldData to Column, used int search result conversion logic
|
||||
// begin, end specifies the start and end positions
|
||||
func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
|
||||
switch fd.GetType() {
|
||||
case schemapb.DataType_Bool:
|
||||
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_BoolData)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
if end < 0 {
|
||||
return NewColumnBool(fd.GetFieldName(), data.BoolData.GetData()[begin:]), nil
|
||||
}
|
||||
return NewColumnBool(fd.GetFieldName(), data.BoolData.GetData()[begin:end]), nil
|
||||
|
||||
case schemapb.DataType_Int8:
|
||||
data, ok := getIntData(fd)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
values := make([]int8, 0, len(data.IntData.GetData()))
|
||||
for _, v := range data.IntData.GetData() {
|
||||
values = append(values, int8(v))
|
||||
}
|
||||
|
||||
if end < 0 {
|
||||
return NewColumnInt8(fd.GetFieldName(), values[begin:]), nil
|
||||
}
|
||||
|
||||
return NewColumnInt8(fd.GetFieldName(), values[begin:end]), nil
|
||||
|
||||
case schemapb.DataType_Int16:
|
||||
data, ok := getIntData(fd)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
values := make([]int16, 0, len(data.IntData.GetData()))
|
||||
for _, v := range data.IntData.GetData() {
|
||||
values = append(values, int16(v))
|
||||
}
|
||||
if end < 0 {
|
||||
return NewColumnInt16(fd.GetFieldName(), values[begin:]), nil
|
||||
}
|
||||
|
||||
return NewColumnInt16(fd.GetFieldName(), values[begin:end]), nil
|
||||
|
||||
case schemapb.DataType_Int32:
|
||||
data, ok := getIntData(fd)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
if end < 0 {
|
||||
return NewColumnInt32(fd.GetFieldName(), data.IntData.GetData()[begin:]), nil
|
||||
}
|
||||
return NewColumnInt32(fd.GetFieldName(), data.IntData.GetData()[begin:end]), nil
|
||||
|
||||
case schemapb.DataType_Int64:
|
||||
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_LongData)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
if end < 0 {
|
||||
return NewColumnInt64(fd.GetFieldName(), data.LongData.GetData()[begin:]), nil
|
||||
}
|
||||
return NewColumnInt64(fd.GetFieldName(), data.LongData.GetData()[begin:end]), nil
|
||||
|
||||
case schemapb.DataType_Float:
|
||||
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_FloatData)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
if end < 0 {
|
||||
return NewColumnFloat(fd.GetFieldName(), data.FloatData.GetData()[begin:]), nil
|
||||
}
|
||||
return NewColumnFloat(fd.GetFieldName(), data.FloatData.GetData()[begin:end]), nil
|
||||
|
||||
case schemapb.DataType_Double:
|
||||
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_DoubleData)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
if end < 0 {
|
||||
return NewColumnDouble(fd.GetFieldName(), data.DoubleData.GetData()[begin:]), nil
|
||||
}
|
||||
return NewColumnDouble(fd.GetFieldName(), data.DoubleData.GetData()[begin:end]), nil
|
||||
|
||||
case schemapb.DataType_String:
|
||||
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_StringData)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
if end < 0 {
|
||||
return NewColumnString(fd.GetFieldName(), data.StringData.GetData()[begin:]), nil
|
||||
}
|
||||
return NewColumnString(fd.GetFieldName(), data.StringData.GetData()[begin:end]), nil
|
||||
|
||||
case schemapb.DataType_VarChar:
|
||||
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_StringData)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
if end < 0 {
|
||||
return NewColumnVarChar(fd.GetFieldName(), data.StringData.GetData()[begin:]), nil
|
||||
}
|
||||
return NewColumnVarChar(fd.GetFieldName(), data.StringData.GetData()[begin:end]), nil
|
||||
|
||||
case schemapb.DataType_Array:
|
||||
data := fd.GetScalars().GetArrayData()
|
||||
if data == nil {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
var arrayData []*schemapb.ScalarField
|
||||
if end < 0 {
|
||||
arrayData = data.GetData()[begin:]
|
||||
} else {
|
||||
arrayData = data.GetData()[begin:end]
|
||||
}
|
||||
|
||||
return parseArrayData(fd.GetFieldName(), data.GetElementType(), arrayData)
|
||||
|
||||
case schemapb.DataType_JSON:
|
||||
data, ok := fd.GetScalars().GetData().(*schemapb.ScalarField_JsonData)
|
||||
isDynamic := fd.GetIsDynamic()
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
if end < 0 {
|
||||
return NewColumnJSONBytes(fd.GetFieldName(), data.JsonData.GetData()[begin:]).WithIsDynamic(isDynamic), nil
|
||||
}
|
||||
return NewColumnJSONBytes(fd.GetFieldName(), data.JsonData.GetData()[begin:end]).WithIsDynamic(isDynamic), nil
|
||||
|
||||
case schemapb.DataType_FloatVector:
|
||||
vectors := fd.GetVectors()
|
||||
x, ok := vectors.GetData().(*schemapb.VectorField_FloatVector)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
data := x.FloatVector.GetData()
|
||||
dim := int(vectors.GetDim())
|
||||
if end < 0 {
|
||||
end = int(len(data) / dim)
|
||||
}
|
||||
vector := make([][]float32, 0, end-begin) // shall not have remanunt
|
||||
for i := begin; i < end; i++ {
|
||||
v := make([]float32, dim)
|
||||
copy(v, data[i*dim:(i+1)*dim])
|
||||
vector = append(vector, v)
|
||||
}
|
||||
return NewColumnFloatVector(fd.GetFieldName(), dim, vector), nil
|
||||
|
||||
case schemapb.DataType_BinaryVector:
|
||||
vectors := fd.GetVectors()
|
||||
x, ok := vectors.GetData().(*schemapb.VectorField_BinaryVector)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
data := x.BinaryVector
|
||||
if data == nil {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
dim := int(vectors.GetDim())
|
||||
blen := dim / 8
|
||||
if end < 0 {
|
||||
end = int(len(data) / blen)
|
||||
}
|
||||
vector := make([][]byte, 0, end-begin)
|
||||
for i := begin; i < end; i++ {
|
||||
v := make([]byte, blen)
|
||||
copy(v, data[i*blen:(i+1)*blen])
|
||||
vector = append(vector, v)
|
||||
}
|
||||
return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil
|
||||
|
||||
case schemapb.DataType_Float16Vector:
|
||||
vectors := fd.GetVectors()
|
||||
x, ok := vectors.GetData().(*schemapb.VectorField_Float16Vector)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
data := x.Float16Vector
|
||||
dim := int(vectors.GetDim())
|
||||
if end < 0 {
|
||||
end = int(len(data) / dim)
|
||||
}
|
||||
vector := make([][]byte, 0, end-begin)
|
||||
for i := begin; i < end; i++ {
|
||||
v := make([]byte, dim)
|
||||
copy(v, data[i*dim:(i+1)*dim])
|
||||
vector = append(vector, v)
|
||||
}
|
||||
return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil
|
||||
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
vectors := fd.GetVectors()
|
||||
x, ok := vectors.GetData().(*schemapb.VectorField_Bfloat16Vector)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
data := x.Bfloat16Vector
|
||||
dim := int(vectors.GetDim())
|
||||
if end < 0 {
|
||||
end = int(len(data) / dim)
|
||||
}
|
||||
vector := make([][]byte, 0, end-begin) // shall not have remanunt
|
||||
for i := begin; i < end; i++ {
|
||||
v := make([]byte, dim)
|
||||
copy(v, data[i*dim:(i+1)*dim])
|
||||
vector = append(vector, v)
|
||||
}
|
||||
return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported data type %s", fd.GetType())
|
||||
}
|
||||
}
|
||||
|
||||
func parseArrayData(fieldName string, elementType schemapb.DataType, fieldDataList []*schemapb.ScalarField) (Column, error) {
|
||||
switch elementType {
|
||||
case schemapb.DataType_Bool:
|
||||
var data [][]bool
|
||||
for _, fd := range fieldDataList {
|
||||
data = append(data, fd.GetBoolData().GetData())
|
||||
}
|
||||
return NewColumnBoolArray(fieldName, data), nil
|
||||
|
||||
case schemapb.DataType_Int8:
|
||||
var data [][]int8
|
||||
for _, fd := range fieldDataList {
|
||||
raw := fd.GetIntData().GetData()
|
||||
row := make([]int8, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
row = append(row, int8(item))
|
||||
}
|
||||
data = append(data, row)
|
||||
}
|
||||
return NewColumnInt8Array(fieldName, data), nil
|
||||
|
||||
case schemapb.DataType_Int16:
|
||||
var data [][]int16
|
||||
for _, fd := range fieldDataList {
|
||||
raw := fd.GetIntData().GetData()
|
||||
row := make([]int16, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
row = append(row, int16(item))
|
||||
}
|
||||
data = append(data, row)
|
||||
}
|
||||
return NewColumnInt16Array(fieldName, data), nil
|
||||
|
||||
case schemapb.DataType_Int32:
|
||||
var data [][]int32
|
||||
for _, fd := range fieldDataList {
|
||||
data = append(data, fd.GetIntData().GetData())
|
||||
}
|
||||
return NewColumnInt32Array(fieldName, data), nil
|
||||
|
||||
case schemapb.DataType_Int64:
|
||||
var data [][]int64
|
||||
for _, fd := range fieldDataList {
|
||||
data = append(data, fd.GetLongData().GetData())
|
||||
}
|
||||
return NewColumnInt64Array(fieldName, data), nil
|
||||
|
||||
case schemapb.DataType_Float:
|
||||
var data [][]float32
|
||||
for _, fd := range fieldDataList {
|
||||
data = append(data, fd.GetFloatData().GetData())
|
||||
}
|
||||
return NewColumnFloatArray(fieldName, data), nil
|
||||
|
||||
case schemapb.DataType_Double:
|
||||
var data [][]float64
|
||||
for _, fd := range fieldDataList {
|
||||
data = append(data, fd.GetDoubleData().GetData())
|
||||
}
|
||||
return NewColumnDoubleArray(fieldName, data), nil
|
||||
|
||||
case schemapb.DataType_VarChar, schemapb.DataType_String:
|
||||
var data [][][]byte
|
||||
for _, fd := range fieldDataList {
|
||||
strs := fd.GetStringData().GetData()
|
||||
bytesData := make([][]byte, 0, len(strs))
|
||||
for _, str := range strs {
|
||||
bytesData = append(bytesData, []byte(str))
|
||||
}
|
||||
data = append(data, bytesData)
|
||||
}
|
||||
|
||||
return NewColumnVarCharArray(fieldName, data), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported element type %s", elementType)
|
||||
}
|
||||
}
|
||||
|
||||
// getIntData get int32 slice from result field data
|
||||
// also handles LongData bug (see also https://github.com/milvus-io/milvus/issues/23850)
|
||||
func getIntData(fd *schemapb.FieldData) (*schemapb.ScalarField_IntData, bool) {
|
||||
switch data := fd.GetScalars().GetData().(type) {
|
||||
case *schemapb.ScalarField_IntData:
|
||||
return data, true
|
||||
case *schemapb.ScalarField_LongData:
|
||||
// only alway empty LongData for backward compatibility
|
||||
if len(data.LongData.GetData()) == 0 {
|
||||
return &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{},
|
||||
}, true
|
||||
}
|
||||
return nil, false
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// FieldDataColumn converts schemapb.FieldData to vector Column
|
||||
func FieldDataVector(fd *schemapb.FieldData) (Column, error) {
|
||||
switch fd.GetType() {
|
||||
case schemapb.DataType_FloatVector:
|
||||
vectors := fd.GetVectors()
|
||||
x, ok := vectors.GetData().(*schemapb.VectorField_FloatVector)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
data := x.FloatVector.GetData()
|
||||
dim := int(vectors.GetDim())
|
||||
vector := make([][]float32, 0, len(data)/dim) // shall not have remanunt
|
||||
for i := 0; i < len(data)/dim; i++ {
|
||||
v := make([]float32, dim)
|
||||
copy(v, data[i*dim:(i+1)*dim])
|
||||
vector = append(vector, v)
|
||||
}
|
||||
return NewColumnFloatVector(fd.GetFieldName(), dim, vector), nil
|
||||
case schemapb.DataType_BinaryVector:
|
||||
vectors := fd.GetVectors()
|
||||
x, ok := vectors.GetData().(*schemapb.VectorField_BinaryVector)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
data := x.BinaryVector
|
||||
if data == nil {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
dim := int(vectors.GetDim())
|
||||
blen := dim / 8
|
||||
vector := make([][]byte, 0, len(data)/blen)
|
||||
for i := 0; i < len(data)/blen; i++ {
|
||||
v := make([]byte, blen)
|
||||
copy(v, data[i*blen:(i+1)*blen])
|
||||
vector = append(vector, v)
|
||||
}
|
||||
return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil
|
||||
case schemapb.DataType_Float16Vector:
|
||||
vectors := fd.GetVectors()
|
||||
x, ok := vectors.GetData().(*schemapb.VectorField_Float16Vector)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
data := x.Float16Vector
|
||||
dim := int(vectors.GetDim())
|
||||
vector := make([][]byte, 0, len(data)/dim) // shall not have remanunt
|
||||
for i := 0; i < len(data)/dim; i++ {
|
||||
v := make([]byte, dim)
|
||||
copy(v, data[i*dim:(i+1)*dim])
|
||||
vector = append(vector, v)
|
||||
}
|
||||
return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
vectors := fd.GetVectors()
|
||||
x, ok := vectors.GetData().(*schemapb.VectorField_Bfloat16Vector)
|
||||
if !ok {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
data := x.Bfloat16Vector
|
||||
dim := int(vectors.GetDim())
|
||||
vector := make([][]byte, 0, len(data)/dim) // shall not have remanunt
|
||||
for i := 0; i < len(data)/dim; i++ {
|
||||
v := make([]byte, dim)
|
||||
copy(v, data[i*dim:(i+1)*dim])
|
||||
vector = append(vector, v)
|
||||
}
|
||||
return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil
|
||||
default:
|
||||
return nil, errors.New("unsupported data type")
|
||||
}
|
||||
}
|
||||
|
||||
// defaultValueColumn will return the empty scalars column which will be fill with default value
|
||||
func DefaultValueColumn(name string, dataType entity.FieldType) (Column, error) {
|
||||
switch dataType {
|
||||
case entity.FieldTypeBool:
|
||||
return NewColumnBool(name, nil), nil
|
||||
case entity.FieldTypeInt8:
|
||||
return NewColumnInt8(name, nil), nil
|
||||
case entity.FieldTypeInt16:
|
||||
return NewColumnInt16(name, nil), nil
|
||||
case entity.FieldTypeInt32:
|
||||
return NewColumnInt32(name, nil), nil
|
||||
case entity.FieldTypeInt64:
|
||||
return NewColumnInt64(name, nil), nil
|
||||
case entity.FieldTypeFloat:
|
||||
return NewColumnFloat(name, nil), nil
|
||||
case entity.FieldTypeDouble:
|
||||
return NewColumnDouble(name, nil), nil
|
||||
case entity.FieldTypeString:
|
||||
return NewColumnString(name, nil), nil
|
||||
case entity.FieldTypeVarChar:
|
||||
return NewColumnVarChar(name, nil), nil
|
||||
case entity.FieldTypeJSON:
|
||||
return NewColumnJSONBytes(name, nil), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("default value unsupported data type %s", dataType)
|
||||
}
|
||||
}
|
||||
160
client/column/columns_test.go
Normal file
160
client/column/columns_test.go
Normal file
@ -0,0 +1,160 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func TestIDColumns(t *testing.T) {
|
||||
dataLen := rand.Intn(100) + 1
|
||||
base := rand.Intn(5000) // id start point
|
||||
|
||||
t.Run("nil id", func(t *testing.T) {
|
||||
_, err := IDColumns(nil, 0, -1)
|
||||
assert.NotNil(t, err)
|
||||
idField := &schemapb.IDs{}
|
||||
_, err = IDColumns(idField, 0, -1)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("int ids", func(t *testing.T) {
|
||||
ids := make([]int64, 0, dataLen)
|
||||
for i := 0; i < dataLen; i++ {
|
||||
ids = append(ids, int64(i+base))
|
||||
}
|
||||
idField := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := IDColumns(idField, 0, dataLen)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
assert.Equal(t, dataLen, column.Len())
|
||||
|
||||
column, err = IDColumns(idField, 0, -1) // test -1 method
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
assert.Equal(t, dataLen, column.Len())
|
||||
})
|
||||
t.Run("string ids", func(t *testing.T) {
|
||||
ids := make([]string, 0, dataLen)
|
||||
for i := 0; i < dataLen; i++ {
|
||||
ids = append(ids, strconv.FormatInt(int64(i+base), 10))
|
||||
}
|
||||
idField := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: ids,
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := IDColumns(idField, 0, dataLen)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
assert.Equal(t, dataLen, column.Len())
|
||||
|
||||
column, err = IDColumns(idField, 0, -1) // test -1 method
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
assert.Equal(t, dataLen, column.Len())
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetIntData(t *testing.T) {
|
||||
type testCase struct {
|
||||
tag string
|
||||
fd *schemapb.FieldData
|
||||
expectOK bool
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
tag: "normal_IntData",
|
||||
fd: &schemapb.FieldData{
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectOK: true,
|
||||
},
|
||||
{
|
||||
tag: "empty_LongData",
|
||||
fd: &schemapb.FieldData{
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{Data: nil},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectOK: true,
|
||||
},
|
||||
{
|
||||
tag: "nonempty_LongData",
|
||||
fd: &schemapb.FieldData{
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectOK: false,
|
||||
},
|
||||
{
|
||||
tag: "other_data",
|
||||
fd: &schemapb.FieldData{
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectOK: false,
|
||||
},
|
||||
{
|
||||
tag: "vector_data",
|
||||
fd: &schemapb.FieldData{
|
||||
Field: &schemapb.FieldData_Vectors{},
|
||||
},
|
||||
expectOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.tag, func(t *testing.T) {
|
||||
_, ok := getIntData(tc.fd)
|
||||
assert.Equal(t, tc.expectOK, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
53
client/column/conversion.go
Normal file
53
client/column/conversion.go
Normal file
@ -0,0 +1,53 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package column
|
||||
|
||||
func (c *ColumnInt8) GetAsInt64(idx int) (int64, error) {
|
||||
v, err := c.ValueByIdx(idx)
|
||||
return int64(v), err
|
||||
}
|
||||
|
||||
func (c *ColumnInt16) GetAsInt64(idx int) (int64, error) {
|
||||
v, err := c.ValueByIdx(idx)
|
||||
return int64(v), err
|
||||
}
|
||||
|
||||
func (c *ColumnInt32) GetAsInt64(idx int) (int64, error) {
|
||||
v, err := c.ValueByIdx(idx)
|
||||
return int64(v), err
|
||||
}
|
||||
|
||||
func (c *ColumnInt64) GetAsInt64(idx int) (int64, error) {
|
||||
return c.ValueByIdx(idx)
|
||||
}
|
||||
|
||||
func (c *ColumnString) GetAsString(idx int) (string, error) {
|
||||
return c.ValueByIdx(idx)
|
||||
}
|
||||
|
||||
func (c *ColumnFloat) GetAsDouble(idx int) (float64, error) {
|
||||
v, err := c.ValueByIdx(idx)
|
||||
return float64(v), err
|
||||
}
|
||||
|
||||
func (c *ColumnDouble) GetAsDouble(idx int) (float64, error) {
|
||||
return c.ValueByIdx(idx)
|
||||
}
|
||||
|
||||
func (c *ColumnBool) GetAsBool(idx int) (bool, error) {
|
||||
return c.ValueByIdx(idx)
|
||||
}
|
||||
113
client/column/dynamic.go
Normal file
113
client/column/dynamic.go
Normal file
@ -0,0 +1,113 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// ColumnDynamic is a logically wrapper for dynamic json field with provided output field.
|
||||
type ColumnDynamic struct {
|
||||
*ColumnJSONBytes
|
||||
outputField string
|
||||
}
|
||||
|
||||
func NewColumnDynamic(column *ColumnJSONBytes, outputField string) *ColumnDynamic {
|
||||
return &ColumnDynamic{
|
||||
ColumnJSONBytes: column,
|
||||
outputField: outputField,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ColumnDynamic) Name() string {
|
||||
return c.outputField
|
||||
}
|
||||
|
||||
// Get returns element at idx as interface{}.
|
||||
// Overrides internal json column behavior, returns raw json data.
|
||||
func (c *ColumnDynamic) Get(idx int) (interface{}, error) {
|
||||
bs, err := c.ColumnJSONBytes.ValueByIdx(idx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
r := gjson.GetBytes(bs, c.outputField)
|
||||
if !r.Exists() {
|
||||
return 0, errors.New("column not has value")
|
||||
}
|
||||
return r.Raw, nil
|
||||
}
|
||||
|
||||
func (c *ColumnDynamic) GetAsInt64(idx int) (int64, error) {
|
||||
bs, err := c.ColumnJSONBytes.ValueByIdx(idx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
r := gjson.GetBytes(bs, c.outputField)
|
||||
if !r.Exists() {
|
||||
return 0, errors.New("column not has value")
|
||||
}
|
||||
if r.Type != gjson.Number {
|
||||
return 0, errors.New("column not int")
|
||||
}
|
||||
return r.Int(), nil
|
||||
}
|
||||
|
||||
func (c *ColumnDynamic) GetAsString(idx int) (string, error) {
|
||||
bs, err := c.ColumnJSONBytes.ValueByIdx(idx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
r := gjson.GetBytes(bs, c.outputField)
|
||||
if !r.Exists() {
|
||||
return "", errors.New("column not has value")
|
||||
}
|
||||
if r.Type != gjson.String {
|
||||
return "", errors.New("column not string")
|
||||
}
|
||||
return r.String(), nil
|
||||
}
|
||||
|
||||
func (c *ColumnDynamic) GetAsBool(idx int) (bool, error) {
|
||||
bs, err := c.ColumnJSONBytes.ValueByIdx(idx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
r := gjson.GetBytes(bs, c.outputField)
|
||||
if !r.Exists() {
|
||||
return false, errors.New("column not has value")
|
||||
}
|
||||
if !r.IsBool() {
|
||||
return false, errors.New("column not string")
|
||||
}
|
||||
return r.Bool(), nil
|
||||
}
|
||||
|
||||
func (c *ColumnDynamic) GetAsDouble(idx int) (float64, error) {
|
||||
bs, err := c.ColumnJSONBytes.ValueByIdx(idx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
r := gjson.GetBytes(bs, c.outputField)
|
||||
if !r.Exists() {
|
||||
return 0, errors.New("column not has value")
|
||||
}
|
||||
if r.Type != gjson.Number {
|
||||
return 0, errors.New("column not string")
|
||||
}
|
||||
return r.Float(), nil
|
||||
}
|
||||
162
client/column/dynamic_test.go
Normal file
162
client/column/dynamic_test.go
Normal file
@ -0,0 +1,162 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ColumnDynamicSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *ColumnDynamicSuite) TestGetInt() {
|
||||
cases := []struct {
|
||||
input string
|
||||
expectErr bool
|
||||
expectValue int64
|
||||
}{
|
||||
{`{"field": 1000000000000000001}`, false, 1000000000000000001},
|
||||
{`{"field": 4418489049307132905}`, false, 4418489049307132905},
|
||||
{`{"other_field": 4418489049307132905}`, true, 0},
|
||||
{`{"field": "string"}`, true, 0},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
s.Run(c.input, func() {
|
||||
column := NewColumnDynamic(&ColumnJSONBytes{
|
||||
values: [][]byte{[]byte(c.input)},
|
||||
}, "field")
|
||||
v, err := column.GetAsInt64(0)
|
||||
if c.expectErr {
|
||||
s.Error(err)
|
||||
return
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Equal(c.expectValue, v)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ColumnDynamicSuite) TestGetString() {
|
||||
cases := []struct {
|
||||
input string
|
||||
expectErr bool
|
||||
expectValue string
|
||||
}{
|
||||
{`{"field": "abc"}`, false, "abc"},
|
||||
{`{"field": "test"}`, false, "test"},
|
||||
{`{"other_field": "string"}`, true, ""},
|
||||
{`{"field": 123}`, true, ""},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
s.Run(c.input, func() {
|
||||
column := NewColumnDynamic(&ColumnJSONBytes{
|
||||
values: [][]byte{[]byte(c.input)},
|
||||
}, "field")
|
||||
v, err := column.GetAsString(0)
|
||||
if c.expectErr {
|
||||
s.Error(err)
|
||||
return
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Equal(c.expectValue, v)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ColumnDynamicSuite) TestGetBool() {
|
||||
cases := []struct {
|
||||
input string
|
||||
expectErr bool
|
||||
expectValue bool
|
||||
}{
|
||||
{`{"field": true}`, false, true},
|
||||
{`{"field": false}`, false, false},
|
||||
{`{"other_field": true}`, true, false},
|
||||
{`{"field": "test"}`, true, false},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
s.Run(c.input, func() {
|
||||
column := NewColumnDynamic(&ColumnJSONBytes{
|
||||
values: [][]byte{[]byte(c.input)},
|
||||
}, "field")
|
||||
v, err := column.GetAsBool(0)
|
||||
if c.expectErr {
|
||||
s.Error(err)
|
||||
return
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Equal(c.expectValue, v)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ColumnDynamicSuite) TestGetDouble() {
|
||||
cases := []struct {
|
||||
input string
|
||||
expectErr bool
|
||||
expectValue float64
|
||||
}{
|
||||
{`{"field": 1}`, false, 1.0},
|
||||
{`{"field": 6231.123}`, false, 6231.123},
|
||||
{`{"other_field": 1.0}`, true, 0},
|
||||
{`{"field": "string"}`, true, 0},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
s.Run(c.input, func() {
|
||||
column := NewColumnDynamic(&ColumnJSONBytes{
|
||||
values: [][]byte{[]byte(c.input)},
|
||||
}, "field")
|
||||
v, err := column.GetAsDouble(0)
|
||||
if c.expectErr {
|
||||
s.Error(err)
|
||||
return
|
||||
}
|
||||
s.NoError(err)
|
||||
s.Less(v-c.expectValue, 1e-10)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ColumnDynamicSuite) TestIndexOutOfRange() {
|
||||
var err error
|
||||
column := NewColumnDynamic(&ColumnJSONBytes{}, "field")
|
||||
|
||||
s.Equal("field", column.Name())
|
||||
|
||||
_, err = column.GetAsInt64(0)
|
||||
s.Error(err)
|
||||
|
||||
_, err = column.GetAsString(0)
|
||||
s.Error(err)
|
||||
|
||||
_, err = column.GetAsBool(0)
|
||||
s.Error(err)
|
||||
|
||||
_, err = column.GetAsDouble(0)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
func TestColumnDynamic(t *testing.T) {
|
||||
suite.Run(t, new(ColumnDynamicSuite))
|
||||
}
|
||||
146
client/column/json.go
Normal file
146
client/column/json.go
Normal file
@ -0,0 +1,146 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
var _ (Column) = (*ColumnJSONBytes)(nil)
|
||||
|
||||
// ColumnJSONBytes column type for JSON.
|
||||
// all items are marshaled json bytes.
|
||||
type ColumnJSONBytes struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values [][]byte
|
||||
isDynamic bool
|
||||
}
|
||||
|
||||
// Name returns column name.
|
||||
func (c *ColumnJSONBytes) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType.
|
||||
func (c *ColumnJSONBytes) Type() entity.FieldType {
|
||||
return entity.FieldTypeJSON
|
||||
}
|
||||
|
||||
// Len returns column values length.
|
||||
func (c *ColumnJSONBytes) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnJSONBytes) Get(idx int) (interface{}, error) {
|
||||
if idx < 0 || idx > c.Len() {
|
||||
return nil, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
func (c *ColumnJSONBytes) GetAsString(idx int) (string, error) {
|
||||
bs, err := c.ValueByIdx(idx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bs), nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData.
|
||||
func (c *ColumnJSONBytes) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_JSON,
|
||||
FieldName: c.name,
|
||||
IsDynamic: c.isDynamic,
|
||||
}
|
||||
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_JsonData{
|
||||
JsonData: &schemapb.JSONArray{
|
||||
Data: c.values,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index.
|
||||
func (c *ColumnJSONBytes) ValueByIdx(idx int) ([]byte, error) {
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return nil, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column.
|
||||
func (c *ColumnJSONBytes) AppendValue(i interface{}) error {
|
||||
var v []byte
|
||||
switch raw := i.(type) {
|
||||
case []byte:
|
||||
v = raw
|
||||
default:
|
||||
k := reflect.TypeOf(i).Kind()
|
||||
if k == reflect.Ptr {
|
||||
k = reflect.TypeOf(i).Elem().Kind()
|
||||
}
|
||||
switch k {
|
||||
case reflect.Struct:
|
||||
fallthrough
|
||||
case reflect.Map:
|
||||
bs, err := json.Marshal(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v = bs
|
||||
default:
|
||||
return fmt.Errorf("expect json compatible type([]byte, struct[}, map], got %T)", i)
|
||||
}
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data.
|
||||
func (c *ColumnJSONBytes) Data() [][]byte {
|
||||
return c.values
|
||||
}
|
||||
|
||||
func (c *ColumnJSONBytes) WithIsDynamic(isDynamic bool) *ColumnJSONBytes {
|
||||
c.isDynamic = isDynamic
|
||||
return c
|
||||
}
|
||||
|
||||
// NewColumnJSONBytes composes a Column with json bytes.
|
||||
func NewColumnJSONBytes(name string, values [][]byte) *ColumnJSONBytes {
|
||||
return &ColumnJSONBytes{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
101
client/column/json_test.go
Normal file
101
client/column/json_test.go
Normal file
@ -0,0 +1,101 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
type ColumnJSONBytesSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *ColumnJSONBytesSuite) SetupSuite() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func (s *ColumnJSONBytesSuite) TestAttrMethods() {
|
||||
columnName := fmt.Sprintf("column_jsonbs_%d", rand.Int())
|
||||
columnLen := 8 + rand.Intn(10)
|
||||
|
||||
v := make([][]byte, columnLen)
|
||||
column := NewColumnJSONBytes(columnName, v).WithIsDynamic(true)
|
||||
|
||||
s.Run("test_meta", func() {
|
||||
ft := entity.FieldTypeJSON
|
||||
s.Equal("JSON", ft.Name())
|
||||
s.Equal("JSON", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
s.Equal("JSON", pbName)
|
||||
s.Equal("JSON", pbType)
|
||||
})
|
||||
|
||||
s.Run("test_column_attribute", func() {
|
||||
s.Equal(columnName, column.Name())
|
||||
s.Equal(entity.FieldTypeJSON, column.Type())
|
||||
s.Equal(columnLen, column.Len())
|
||||
s.EqualValues(v, column.Data())
|
||||
})
|
||||
|
||||
s.Run("test_column_field_data", func() {
|
||||
fd := column.FieldData()
|
||||
s.NotNil(fd)
|
||||
s.Equal(fd.GetFieldName(), columnName)
|
||||
})
|
||||
|
||||
s.Run("test_column_valuer_by_idx", func() {
|
||||
_, err := column.ValueByIdx(-1)
|
||||
s.Error(err)
|
||||
_, err = column.ValueByIdx(columnLen)
|
||||
s.Error(err)
|
||||
for i := 0; i < columnLen; i++ {
|
||||
v, err := column.ValueByIdx(i)
|
||||
s.NoError(err)
|
||||
s.Equal(column.values[i], v)
|
||||
}
|
||||
})
|
||||
|
||||
s.Run("test_append_value", func() {
|
||||
item := make([]byte, 10)
|
||||
err := column.AppendValue(item)
|
||||
s.NoError(err)
|
||||
s.Equal(columnLen+1, column.Len())
|
||||
val, err := column.ValueByIdx(columnLen)
|
||||
s.NoError(err)
|
||||
s.Equal(item, val)
|
||||
|
||||
err = column.AppendValue(&struct{ Tag string }{Tag: "abc"})
|
||||
s.NoError(err)
|
||||
|
||||
err = column.AppendValue(map[string]interface{}{"Value": 123})
|
||||
s.NoError(err)
|
||||
|
||||
err = column.AppendValue(1)
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestColumnJSONBytes(t *testing.T) {
|
||||
suite.Run(t, new(ColumnJSONBytesSuite))
|
||||
}
|
||||
708
client/column/scalar_gen.go
Normal file
708
client/column/scalar_gen.go
Normal file
@ -0,0 +1,708 @@
|
||||
// Code generated by go generate; DO NOT EDIT
|
||||
// This file is generated by go generate
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
// ColumnBool generated columns type for Bool
|
||||
type ColumnBool struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values []bool
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnBool) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnBool) Type() entity.FieldType {
|
||||
return entity.FieldTypeBool
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnBool) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnBool) Get(idx int) (interface{}, error) {
|
||||
var r bool // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnBool) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Bool,
|
||||
FieldName: c.name,
|
||||
}
|
||||
data := make([]bool, 0, c.Len())
|
||||
for i := 0; i < c.Len(); i++ {
|
||||
data = append(data, bool(c.values[i]))
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnBool) ValueByIdx(idx int) (bool, error) {
|
||||
var r bool // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnBool) AppendValue(i interface{}) error {
|
||||
v, ok := i.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected bool, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnBool) Data() []bool {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnBool auto generated constructor
|
||||
func NewColumnBool(name string, values []bool) *ColumnBool {
|
||||
return &ColumnBool{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnInt8 generated columns type for Int8
|
||||
type ColumnInt8 struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values []int8
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnInt8) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnInt8) Type() entity.FieldType {
|
||||
return entity.FieldTypeInt8
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnInt8) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnInt8) Get(idx int) (interface{}, error) {
|
||||
var r int8 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnInt8) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int8,
|
||||
FieldName: c.name,
|
||||
}
|
||||
data := make([]int32, 0, c.Len())
|
||||
for i := 0; i < c.Len(); i++ {
|
||||
data = append(data, int32(c.values[i]))
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnInt8) ValueByIdx(idx int) (int8, error) {
|
||||
var r int8 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnInt8) AppendValue(i interface{}) error {
|
||||
v, ok := i.(int8)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected int8, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnInt8) Data() []int8 {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnInt8 auto generated constructor
|
||||
func NewColumnInt8(name string, values []int8) *ColumnInt8 {
|
||||
return &ColumnInt8{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnInt16 generated columns type for Int16
|
||||
type ColumnInt16 struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values []int16
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnInt16) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnInt16) Type() entity.FieldType {
|
||||
return entity.FieldTypeInt16
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnInt16) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnInt16) Get(idx int) (interface{}, error) {
|
||||
var r int16 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnInt16) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int16,
|
||||
FieldName: c.name,
|
||||
}
|
||||
data := make([]int32, 0, c.Len())
|
||||
for i := 0; i < c.Len(); i++ {
|
||||
data = append(data, int32(c.values[i]))
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnInt16) ValueByIdx(idx int) (int16, error) {
|
||||
var r int16 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnInt16) AppendValue(i interface{}) error {
|
||||
v, ok := i.(int16)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected int16, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnInt16) Data() []int16 {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnInt16 auto generated constructor
|
||||
func NewColumnInt16(name string, values []int16) *ColumnInt16 {
|
||||
return &ColumnInt16{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnInt32 generated columns type for Int32
|
||||
type ColumnInt32 struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values []int32
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnInt32) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnInt32) Type() entity.FieldType {
|
||||
return entity.FieldTypeInt32
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnInt32) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnInt32) Get(idx int) (interface{}, error) {
|
||||
var r int32 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnInt32) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int32,
|
||||
FieldName: c.name,
|
||||
}
|
||||
data := make([]int32, 0, c.Len())
|
||||
for i := 0; i < c.Len(); i++ {
|
||||
data = append(data, int32(c.values[i]))
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnInt32) ValueByIdx(idx int) (int32, error) {
|
||||
var r int32 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnInt32) AppendValue(i interface{}) error {
|
||||
v, ok := i.(int32)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected int32, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnInt32) Data() []int32 {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnInt32 auto generated constructor
|
||||
func NewColumnInt32(name string, values []int32) *ColumnInt32 {
|
||||
return &ColumnInt32{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnInt64 generated columns type for Int64
|
||||
type ColumnInt64 struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values []int64
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnInt64) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnInt64) Type() entity.FieldType {
|
||||
return entity.FieldTypeInt64
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnInt64) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnInt64) Get(idx int) (interface{}, error) {
|
||||
var r int64 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnInt64) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: c.name,
|
||||
}
|
||||
data := make([]int64, 0, c.Len())
|
||||
for i := 0; i < c.Len(); i++ {
|
||||
data = append(data, int64(c.values[i]))
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnInt64) ValueByIdx(idx int) (int64, error) {
|
||||
var r int64 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnInt64) AppendValue(i interface{}) error {
|
||||
v, ok := i.(int64)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected int64, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnInt64) Data() []int64 {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnInt64 auto generated constructor
|
||||
func NewColumnInt64(name string, values []int64) *ColumnInt64 {
|
||||
return &ColumnInt64{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnFloat generated columns type for Float
|
||||
type ColumnFloat struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values []float32
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnFloat) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnFloat) Type() entity.FieldType {
|
||||
return entity.FieldTypeFloat
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnFloat) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnFloat) Get(idx int) (interface{}, error) {
|
||||
var r float32 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnFloat) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Float,
|
||||
FieldName: c.name,
|
||||
}
|
||||
data := make([]float32, 0, c.Len())
|
||||
for i := 0; i < c.Len(); i++ {
|
||||
data = append(data, float32(c.values[i]))
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnFloat) ValueByIdx(idx int) (float32, error) {
|
||||
var r float32 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnFloat) AppendValue(i interface{}) error {
|
||||
v, ok := i.(float32)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected float32, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnFloat) Data() []float32 {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnFloat auto generated constructor
|
||||
func NewColumnFloat(name string, values []float32) *ColumnFloat {
|
||||
return &ColumnFloat{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnDouble generated columns type for Double
|
||||
type ColumnDouble struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values []float64
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnDouble) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnDouble) Type() entity.FieldType {
|
||||
return entity.FieldTypeDouble
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnDouble) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnDouble) Get(idx int) (interface{}, error) {
|
||||
var r float64 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnDouble) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Double,
|
||||
FieldName: c.name,
|
||||
}
|
||||
data := make([]float64, 0, c.Len())
|
||||
for i := 0; i < c.Len(); i++ {
|
||||
data = append(data, float64(c.values[i]))
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_DoubleData{
|
||||
DoubleData: &schemapb.DoubleArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnDouble) ValueByIdx(idx int) (float64, error) {
|
||||
var r float64 // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnDouble) AppendValue(i interface{}) error {
|
||||
v, ok := i.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected float64, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnDouble) Data() []float64 {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnDouble auto generated constructor
|
||||
func NewColumnDouble(name string, values []float64) *ColumnDouble {
|
||||
return &ColumnDouble{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnString generated columns type for String
|
||||
type ColumnString struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values []string
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnString) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnString) Type() entity.FieldType {
|
||||
return entity.FieldTypeString
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnString) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnString) Get(idx int) (interface{}, error) {
|
||||
var r string // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnString) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_String,
|
||||
FieldName: c.name,
|
||||
}
|
||||
data := make([]string, 0, c.Len())
|
||||
for i := 0; i < c.Len(); i++ {
|
||||
data = append(data, string(c.values[i]))
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnString) ValueByIdx(idx int) (string, error) {
|
||||
var r string // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnString) AppendValue(i interface{}) error {
|
||||
v, ok := i.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected string, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnString) Data() []string {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnString auto generated constructor
|
||||
func NewColumnString(name string, values []string) *ColumnString {
|
||||
return &ColumnString{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
855
client/column/scalar_gen_test.go
Normal file
855
client/column/scalar_gen_test.go
Normal file
@ -0,0 +1,855 @@
|
||||
// Code generated by go generate; DO NOT EDIT
|
||||
// This file is generated by go generated
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestColumnBool(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
columnName := fmt.Sprintf("column_Bool_%d", rand.Int())
|
||||
columnLen := 8 + rand.Intn(10)
|
||||
|
||||
v := make([]bool, columnLen)
|
||||
column := NewColumnBool(columnName, v)
|
||||
|
||||
t.Run("test meta", func(t *testing.T) {
|
||||
ft := entity.FieldTypeBool
|
||||
assert.Equal(t, "Bool", ft.Name())
|
||||
assert.Equal(t, "bool", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
assert.Equal(t, "Bool", pbName)
|
||||
assert.Equal(t, "bool", pbType)
|
||||
})
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeBool, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.EqualValues(t, v, column.Data())
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
})
|
||||
|
||||
t.Run("test column value by idx", func(t *testing.T) {
|
||||
_, err := column.ValueByIdx(-1)
|
||||
assert.NotNil(t, err)
|
||||
_, err = column.ValueByIdx(columnLen)
|
||||
assert.NotNil(t, err)
|
||||
for i := 0; i < columnLen; i++ {
|
||||
v, err := column.ValueByIdx(i)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, column.values[i], v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldDataBoolColumn(t *testing.T) {
|
||||
len := rand.Intn(10) + 8
|
||||
name := fmt.Sprintf("fd_Bool_%d", rand.Int())
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Bool,
|
||||
FieldName: name,
|
||||
}
|
||||
|
||||
t.Run("normal usage", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: make([]bool, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, len)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeBool, column.Type())
|
||||
|
||||
var ev bool
|
||||
err = column.AppendValue(ev)
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = column.AppendValue(struct{}{})
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil data", func(t *testing.T) {
|
||||
fd.Field = nil
|
||||
_, err := FieldDataColumn(fd, 0, len)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("get all data", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: make([]bool, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeBool, column.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestColumnInt8(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
columnName := fmt.Sprintf("column_Int8_%d", rand.Int())
|
||||
columnLen := 8 + rand.Intn(10)
|
||||
|
||||
v := make([]int8, columnLen)
|
||||
column := NewColumnInt8(columnName, v)
|
||||
|
||||
t.Run("test meta", func(t *testing.T) {
|
||||
ft := entity.FieldTypeInt8
|
||||
assert.Equal(t, "Int8", ft.Name())
|
||||
assert.Equal(t, "int8", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
assert.Equal(t, "Int", pbName)
|
||||
assert.Equal(t, "int32", pbType)
|
||||
})
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeInt8, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.EqualValues(t, v, column.Data())
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
})
|
||||
|
||||
t.Run("test column value by idx", func(t *testing.T) {
|
||||
_, err := column.ValueByIdx(-1)
|
||||
assert.NotNil(t, err)
|
||||
_, err = column.ValueByIdx(columnLen)
|
||||
assert.NotNil(t, err)
|
||||
for i := 0; i < columnLen; i++ {
|
||||
v, err := column.ValueByIdx(i)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, column.values[i], v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldDataInt8Column(t *testing.T) {
|
||||
len := rand.Intn(10) + 8
|
||||
name := fmt.Sprintf("fd_Int8_%d", rand.Int())
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int8,
|
||||
FieldName: name,
|
||||
}
|
||||
|
||||
t.Run("normal usage", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: make([]int32, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, len)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeInt8, column.Type())
|
||||
|
||||
var ev int8
|
||||
err = column.AppendValue(ev)
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = column.AppendValue(struct{}{})
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil data", func(t *testing.T) {
|
||||
fd.Field = nil
|
||||
_, err := FieldDataColumn(fd, 0, len)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("get all data", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: make([]int32, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeInt8, column.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestColumnInt16(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
columnName := fmt.Sprintf("column_Int16_%d", rand.Int())
|
||||
columnLen := 8 + rand.Intn(10)
|
||||
|
||||
v := make([]int16, columnLen)
|
||||
column := NewColumnInt16(columnName, v)
|
||||
|
||||
t.Run("test meta", func(t *testing.T) {
|
||||
ft := entity.FieldTypeInt16
|
||||
assert.Equal(t, "Int16", ft.Name())
|
||||
assert.Equal(t, "int16", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
assert.Equal(t, "Int", pbName)
|
||||
assert.Equal(t, "int32", pbType)
|
||||
})
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeInt16, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.EqualValues(t, v, column.Data())
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
})
|
||||
|
||||
t.Run("test column value by idx", func(t *testing.T) {
|
||||
_, err := column.ValueByIdx(-1)
|
||||
assert.NotNil(t, err)
|
||||
_, err = column.ValueByIdx(columnLen)
|
||||
assert.NotNil(t, err)
|
||||
for i := 0; i < columnLen; i++ {
|
||||
v, err := column.ValueByIdx(i)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, column.values[i], v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldDataInt16Column(t *testing.T) {
|
||||
len := rand.Intn(10) + 8
|
||||
name := fmt.Sprintf("fd_Int16_%d", rand.Int())
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int16,
|
||||
FieldName: name,
|
||||
}
|
||||
|
||||
t.Run("normal usage", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: make([]int32, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, len)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeInt16, column.Type())
|
||||
|
||||
var ev int16
|
||||
err = column.AppendValue(ev)
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = column.AppendValue(struct{}{})
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil data", func(t *testing.T) {
|
||||
fd.Field = nil
|
||||
_, err := FieldDataColumn(fd, 0, len)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("get all data", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: make([]int32, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeInt16, column.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestColumnInt32(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
columnName := fmt.Sprintf("column_Int32_%d", rand.Int())
|
||||
columnLen := 8 + rand.Intn(10)
|
||||
|
||||
v := make([]int32, columnLen)
|
||||
column := NewColumnInt32(columnName, v)
|
||||
|
||||
t.Run("test meta", func(t *testing.T) {
|
||||
ft := entity.FieldTypeInt32
|
||||
assert.Equal(t, "Int32", ft.Name())
|
||||
assert.Equal(t, "int32", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
assert.Equal(t, "Int", pbName)
|
||||
assert.Equal(t, "int32", pbType)
|
||||
})
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeInt32, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.EqualValues(t, v, column.Data())
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
})
|
||||
|
||||
t.Run("test column value by idx", func(t *testing.T) {
|
||||
_, err := column.ValueByIdx(-1)
|
||||
assert.NotNil(t, err)
|
||||
_, err = column.ValueByIdx(columnLen)
|
||||
assert.NotNil(t, err)
|
||||
for i := 0; i < columnLen; i++ {
|
||||
v, err := column.ValueByIdx(i)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, column.values[i], v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldDataInt32Column(t *testing.T) {
|
||||
len := rand.Intn(10) + 8
|
||||
name := fmt.Sprintf("fd_Int32_%d", rand.Int())
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int32,
|
||||
FieldName: name,
|
||||
}
|
||||
|
||||
t.Run("normal usage", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: make([]int32, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, len)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeInt32, column.Type())
|
||||
|
||||
var ev int32
|
||||
err = column.AppendValue(ev)
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = column.AppendValue(struct{}{})
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil data", func(t *testing.T) {
|
||||
fd.Field = nil
|
||||
_, err := FieldDataColumn(fd, 0, len)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("get all data", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{
|
||||
Data: make([]int32, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeInt32, column.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestColumnInt64(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
columnName := fmt.Sprintf("column_Int64_%d", rand.Int())
|
||||
columnLen := 8 + rand.Intn(10)
|
||||
|
||||
v := make([]int64, columnLen)
|
||||
column := NewColumnInt64(columnName, v)
|
||||
|
||||
t.Run("test meta", func(t *testing.T) {
|
||||
ft := entity.FieldTypeInt64
|
||||
assert.Equal(t, "Int64", ft.Name())
|
||||
assert.Equal(t, "int64", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
assert.Equal(t, "Long", pbName)
|
||||
assert.Equal(t, "int64", pbType)
|
||||
})
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeInt64, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.EqualValues(t, v, column.Data())
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
})
|
||||
|
||||
t.Run("test column value by idx", func(t *testing.T) {
|
||||
_, err := column.ValueByIdx(-1)
|
||||
assert.NotNil(t, err)
|
||||
_, err = column.ValueByIdx(columnLen)
|
||||
assert.NotNil(t, err)
|
||||
for i := 0; i < columnLen; i++ {
|
||||
v, err := column.ValueByIdx(i)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, column.values[i], v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldDataInt64Column(t *testing.T) {
|
||||
len := rand.Intn(10) + 8
|
||||
name := fmt.Sprintf("fd_Int64_%d", rand.Int())
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: name,
|
||||
}
|
||||
|
||||
t.Run("normal usage", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: make([]int64, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, len)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeInt64, column.Type())
|
||||
|
||||
var ev int64
|
||||
err = column.AppendValue(ev)
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = column.AppendValue(struct{}{})
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil data", func(t *testing.T) {
|
||||
fd.Field = nil
|
||||
_, err := FieldDataColumn(fd, 0, len)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("get all data", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: make([]int64, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeInt64, column.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestColumnFloat(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
columnName := fmt.Sprintf("column_Float_%d", rand.Int())
|
||||
columnLen := 8 + rand.Intn(10)
|
||||
|
||||
v := make([]float32, columnLen)
|
||||
column := NewColumnFloat(columnName, v)
|
||||
|
||||
t.Run("test meta", func(t *testing.T) {
|
||||
ft := entity.FieldTypeFloat
|
||||
assert.Equal(t, "Float", ft.Name())
|
||||
assert.Equal(t, "float32", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
assert.Equal(t, "Float", pbName)
|
||||
assert.Equal(t, "float32", pbType)
|
||||
})
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeFloat, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.EqualValues(t, v, column.Data())
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
})
|
||||
|
||||
t.Run("test column value by idx", func(t *testing.T) {
|
||||
_, err := column.ValueByIdx(-1)
|
||||
assert.NotNil(t, err)
|
||||
_, err = column.ValueByIdx(columnLen)
|
||||
assert.NotNil(t, err)
|
||||
for i := 0; i < columnLen; i++ {
|
||||
v, err := column.ValueByIdx(i)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, column.values[i], v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldDataFloatColumn(t *testing.T) {
|
||||
len := rand.Intn(10) + 8
|
||||
name := fmt.Sprintf("fd_Float_%d", rand.Int())
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Float,
|
||||
FieldName: name,
|
||||
}
|
||||
|
||||
t.Run("normal usage", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{
|
||||
Data: make([]float32, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, len)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeFloat, column.Type())
|
||||
|
||||
var ev float32
|
||||
err = column.AppendValue(ev)
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = column.AppendValue(struct{}{})
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil data", func(t *testing.T) {
|
||||
fd.Field = nil
|
||||
_, err := FieldDataColumn(fd, 0, len)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("get all data", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{
|
||||
Data: make([]float32, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeFloat, column.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestColumnDouble(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
columnName := fmt.Sprintf("column_Double_%d", rand.Int())
|
||||
columnLen := 8 + rand.Intn(10)
|
||||
|
||||
v := make([]float64, columnLen)
|
||||
column := NewColumnDouble(columnName, v)
|
||||
|
||||
t.Run("test meta", func(t *testing.T) {
|
||||
ft := entity.FieldTypeDouble
|
||||
assert.Equal(t, "Double", ft.Name())
|
||||
assert.Equal(t, "float64", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
assert.Equal(t, "Double", pbName)
|
||||
assert.Equal(t, "float64", pbType)
|
||||
})
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeDouble, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.EqualValues(t, v, column.Data())
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
})
|
||||
|
||||
t.Run("test column value by idx", func(t *testing.T) {
|
||||
_, err := column.ValueByIdx(-1)
|
||||
assert.NotNil(t, err)
|
||||
_, err = column.ValueByIdx(columnLen)
|
||||
assert.NotNil(t, err)
|
||||
for i := 0; i < columnLen; i++ {
|
||||
v, err := column.ValueByIdx(i)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, column.values[i], v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldDataDoubleColumn(t *testing.T) {
|
||||
len := rand.Intn(10) + 8
|
||||
name := fmt.Sprintf("fd_Double_%d", rand.Int())
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Double,
|
||||
FieldName: name,
|
||||
}
|
||||
|
||||
t.Run("normal usage", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_DoubleData{
|
||||
DoubleData: &schemapb.DoubleArray{
|
||||
Data: make([]float64, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, len)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeDouble, column.Type())
|
||||
|
||||
var ev float64
|
||||
err = column.AppendValue(ev)
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = column.AppendValue(struct{}{})
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil data", func(t *testing.T) {
|
||||
fd.Field = nil
|
||||
_, err := FieldDataColumn(fd, 0, len)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("get all data", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_DoubleData{
|
||||
DoubleData: &schemapb.DoubleArray{
|
||||
Data: make([]float64, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeDouble, column.Type())
|
||||
})
|
||||
}
|
||||
|
||||
func TestColumnString(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
columnName := fmt.Sprintf("column_String_%d", rand.Int())
|
||||
columnLen := 8 + rand.Intn(10)
|
||||
|
||||
v := make([]string, columnLen)
|
||||
column := NewColumnString(columnName, v)
|
||||
|
||||
t.Run("test meta", func(t *testing.T) {
|
||||
ft := entity.FieldTypeString
|
||||
assert.Equal(t, "String", ft.Name())
|
||||
assert.Equal(t, "string", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
assert.Equal(t, "String", pbName)
|
||||
assert.Equal(t, "string", pbType)
|
||||
})
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeString, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.EqualValues(t, v, column.Data())
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
})
|
||||
|
||||
t.Run("test column value by idx", func(t *testing.T) {
|
||||
_, err := column.ValueByIdx(-1)
|
||||
assert.NotNil(t, err)
|
||||
_, err = column.ValueByIdx(columnLen)
|
||||
assert.NotNil(t, err)
|
||||
for i := 0; i < columnLen; i++ {
|
||||
v, err := column.ValueByIdx(i)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, column.values[i], v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldDataStringColumn(t *testing.T) {
|
||||
len := rand.Intn(10) + 8
|
||||
name := fmt.Sprintf("fd_String_%d", rand.Int())
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_String,
|
||||
FieldName: name,
|
||||
}
|
||||
|
||||
t.Run("normal usage", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: make([]string, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, len)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeString, column.Type())
|
||||
|
||||
var ev string
|
||||
err = column.AppendValue(ev)
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = column.AppendValue(struct{}{})
|
||||
assert.Equal(t, len+1, column.Len())
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil data", func(t *testing.T) {
|
||||
fd.Field = nil
|
||||
_, err := FieldDataColumn(fd, 0, len)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("get all data", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: make([]string, len),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, len, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeString, column.Type())
|
||||
})
|
||||
}
|
||||
125
client/column/sparse.go
Normal file
125
client/column/sparse.go
Normal file
@ -0,0 +1,125 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
var _ (Column) = (*ColumnSparseFloatVector)(nil)
|
||||
|
||||
type ColumnSparseFloatVector struct {
|
||||
ColumnBase
|
||||
|
||||
vectors []entity.SparseEmbedding
|
||||
name string
|
||||
}
|
||||
|
||||
// Name returns column name.
|
||||
func (c *ColumnSparseFloatVector) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column FieldType.
|
||||
func (c *ColumnSparseFloatVector) Type() entity.FieldType {
|
||||
return entity.FieldTypeSparseVector
|
||||
}
|
||||
|
||||
// Len returns column values length.
|
||||
func (c *ColumnSparseFloatVector) Len() int {
|
||||
return len(c.vectors)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnSparseFloatVector) Get(idx int) (interface{}, error) {
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return nil, errors.New("index out of range")
|
||||
}
|
||||
return c.vectors[idx], nil
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnSparseFloatVector) ValueByIdx(idx int) (entity.SparseEmbedding, error) {
|
||||
var r entity.SparseEmbedding // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.vectors[idx], nil
|
||||
}
|
||||
|
||||
func (c *ColumnSparseFloatVector) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_SparseFloatVector,
|
||||
FieldName: c.name,
|
||||
}
|
||||
|
||||
dim := int(0)
|
||||
data := make([][]byte, 0, len(c.vectors))
|
||||
for _, vector := range c.vectors {
|
||||
row := make([]byte, 8*vector.Len())
|
||||
for idx := 0; idx < vector.Len(); idx++ {
|
||||
pos, value, _ := vector.Get(idx)
|
||||
binary.LittleEndian.PutUint32(row[idx*8:], pos)
|
||||
binary.LittleEndian.PutUint32(row[pos*8+4:], math.Float32bits(value))
|
||||
}
|
||||
data = append(data, row)
|
||||
if vector.Dim() > dim {
|
||||
dim = vector.Dim()
|
||||
}
|
||||
}
|
||||
|
||||
fd.Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_SparseFloatVector{
|
||||
SparseFloatVector: &schemapb.SparseFloatArray{
|
||||
Dim: int64(dim),
|
||||
Contents: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
func (c *ColumnSparseFloatVector) AppendValue(i interface{}) error {
|
||||
v, ok := i.(entity.SparseEmbedding)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expect SparseEmbedding interface, got %T", i)
|
||||
}
|
||||
c.vectors = append(c.vectors, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ColumnSparseFloatVector) Data() []entity.SparseEmbedding {
|
||||
return c.vectors
|
||||
}
|
||||
|
||||
func NewColumnSparseVectors(name string, values []entity.SparseEmbedding) *ColumnSparseFloatVector {
|
||||
return &ColumnSparseFloatVector{
|
||||
name: name,
|
||||
vectors: values,
|
||||
}
|
||||
}
|
||||
81
client/column/sparse_test.go
Normal file
81
client/column/sparse_test.go
Normal file
@ -0,0 +1,81 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestColumnSparseEmbedding(t *testing.T) {
|
||||
columnName := fmt.Sprintf("column_sparse_embedding_%d", rand.Int())
|
||||
columnLen := 8 + rand.Intn(10)
|
||||
|
||||
v := make([]entity.SparseEmbedding, 0, columnLen)
|
||||
for i := 0; i < columnLen; i++ {
|
||||
length := 1 + rand.Intn(5)
|
||||
positions := make([]uint32, length)
|
||||
values := make([]float32, length)
|
||||
for j := 0; j < length; j++ {
|
||||
positions[j] = uint32(j)
|
||||
values[j] = rand.Float32()
|
||||
}
|
||||
se, err := entity.NewSliceSparseEmbedding(positions, values)
|
||||
require.NoError(t, err)
|
||||
v = append(v, se)
|
||||
}
|
||||
column := NewColumnSparseVectors(columnName, v)
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeSparseVector, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.EqualValues(t, v, column.Data())
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
})
|
||||
|
||||
t.Run("test column value by idx", func(t *testing.T) {
|
||||
_, err := column.ValueByIdx(-1)
|
||||
assert.Error(t, err)
|
||||
_, err = column.ValueByIdx(columnLen)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = column.Get(-1)
|
||||
assert.Error(t, err)
|
||||
_, err = column.Get(columnLen)
|
||||
assert.Error(t, err)
|
||||
|
||||
for i := 0; i < columnLen; i++ {
|
||||
v, err := column.ValueByIdx(i)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, column.vectors[i], v)
|
||||
getV, err := column.Get(i)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, v, getV)
|
||||
}
|
||||
})
|
||||
}
|
||||
119
client/column/varchar.go
Normal file
119
client/column/varchar.go
Normal file
@ -0,0 +1,119 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
// ColumnVarChar generated columns type for VarChar
|
||||
type ColumnVarChar struct {
|
||||
ColumnBase
|
||||
name string
|
||||
values []string
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnVarChar) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnVarChar) Type() entity.FieldType {
|
||||
return entity.FieldTypeVarChar
|
||||
}
|
||||
|
||||
// Len returns column values length
|
||||
func (c *ColumnVarChar) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Get returns value at index as interface{}.
|
||||
func (c *ColumnVarChar) Get(idx int) (interface{}, error) {
|
||||
if idx < 0 || idx > c.Len() {
|
||||
return "", errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// GetAsString returns value at idx.
|
||||
func (c *ColumnVarChar) GetAsString(idx int) (string, error) {
|
||||
if idx < 0 || idx > c.Len() {
|
||||
return "", errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnVarChar) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldName: c.name,
|
||||
}
|
||||
data := make([]string, 0, c.Len())
|
||||
for i := 0; i < c.Len(); i++ {
|
||||
data = append(data, string(c.values[i]))
|
||||
}
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// ValueByIdx returns value of the provided index
|
||||
// error occurs when index out of range
|
||||
func (c *ColumnVarChar) ValueByIdx(idx int) (string, error) {
|
||||
var r string // use default value
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return r, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnVarChar) AppendValue(i interface{}) error {
|
||||
v, ok := i.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected string, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnVarChar) Data() []string {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// NewColumnVarChar auto generated constructor
|
||||
func NewColumnVarChar(name string, values []string) *ColumnVarChar {
|
||||
return &ColumnVarChar{
|
||||
name: name,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
134
client/column/varchar_test.go
Normal file
134
client/column/varchar_test.go
Normal file
@ -0,0 +1,134 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
func TestColumnVarChar(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
columnName := fmt.Sprintf("column_VarChar_%d", rand.Int())
|
||||
columnLen := 8 + rand.Intn(10)
|
||||
|
||||
v := make([]string, columnLen)
|
||||
column := NewColumnVarChar(columnName, v)
|
||||
|
||||
t.Run("test meta", func(t *testing.T) {
|
||||
ft := entity.FieldTypeVarChar
|
||||
assert.Equal(t, "VarChar", ft.Name())
|
||||
assert.Equal(t, "string", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
assert.Equal(t, "VarChar", pbName)
|
||||
assert.Equal(t, "string", pbType)
|
||||
})
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeVarChar, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.EqualValues(t, v, column.Data())
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
})
|
||||
|
||||
t.Run("test column value by idx", func(t *testing.T) {
|
||||
_, err := column.ValueByIdx(-1)
|
||||
assert.NotNil(t, err)
|
||||
_, err = column.ValueByIdx(columnLen)
|
||||
assert.NotNil(t, err)
|
||||
for i := 0; i < columnLen; i++ {
|
||||
v, err := column.ValueByIdx(i)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, column.values[i], v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldDataVarCharColumn(t *testing.T) {
|
||||
colLen := rand.Intn(10) + 8
|
||||
name := fmt.Sprintf("fd_VarChar_%d", rand.Int())
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldName: name,
|
||||
}
|
||||
|
||||
t.Run("normal usage", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: make([]string, colLen),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, colLen)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, colLen, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeVarChar, column.Type())
|
||||
|
||||
var ev string
|
||||
err = column.AppendValue(ev)
|
||||
assert.Equal(t, colLen+1, column.Len())
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = column.AppendValue(struct{}{})
|
||||
assert.Equal(t, colLen+1, column.Len())
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil data", func(t *testing.T) {
|
||||
fd.Field = nil
|
||||
_, err := FieldDataColumn(fd, 0, colLen)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("get all data", func(t *testing.T) {
|
||||
fd.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: make([]string, colLen),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
column, err := FieldDataColumn(fd, 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, column)
|
||||
|
||||
assert.Equal(t, name, column.Name())
|
||||
assert.Equal(t, colLen, column.Len())
|
||||
assert.Equal(t, entity.FieldTypeVarChar, column.Type())
|
||||
})
|
||||
}
|
||||
358
client/column/vector_gen.go
Normal file
358
client/column/vector_gen.go
Normal file
@ -0,0 +1,358 @@
|
||||
// Code generated by go generate; DO NOT EDIT
|
||||
// This file is generated by go generated
|
||||
package column
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
// ColumnBinaryVector generated columns type for BinaryVector
|
||||
type ColumnBinaryVector struct {
|
||||
ColumnBase
|
||||
name string
|
||||
dim int
|
||||
values [][]byte
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnBinaryVector) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnBinaryVector) Type() entity.FieldType {
|
||||
return entity.FieldTypeBinaryVector
|
||||
}
|
||||
|
||||
// Len returns column data length
|
||||
func (c *ColumnBinaryVector) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Dim returns vector dimension
|
||||
func (c *ColumnBinaryVector) Dim() int {
|
||||
return c.dim
|
||||
}
|
||||
|
||||
// Get returns values at index as interface{}.
|
||||
func (c *ColumnBinaryVector) Get(idx int) (interface{}, error) {
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return nil, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnBinaryVector) AppendValue(i interface{}) error {
|
||||
v, ok := i.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected []byte, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnBinaryVector) Data() [][]byte {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnBinaryVector) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_BinaryVector,
|
||||
FieldName: c.name,
|
||||
}
|
||||
|
||||
data := make([]byte, 0, len(c.values)*c.dim)
|
||||
|
||||
for _, vector := range c.values {
|
||||
data = append(data, vector...)
|
||||
}
|
||||
|
||||
fd.Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(c.dim),
|
||||
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: data,
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// NewColumnBinaryVector auto generated constructor
|
||||
func NewColumnBinaryVector(name string, dim int, values [][]byte) *ColumnBinaryVector {
|
||||
return &ColumnBinaryVector{
|
||||
name: name,
|
||||
dim: dim,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnFloatVector generated columns type for FloatVector
|
||||
type ColumnFloatVector struct {
|
||||
ColumnBase
|
||||
name string
|
||||
dim int
|
||||
values [][]float32
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnFloatVector) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnFloatVector) Type() entity.FieldType {
|
||||
return entity.FieldTypeFloatVector
|
||||
}
|
||||
|
||||
// Len returns column data length
|
||||
func (c *ColumnFloatVector) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Dim returns vector dimension
|
||||
func (c *ColumnFloatVector) Dim() int {
|
||||
return c.dim
|
||||
}
|
||||
|
||||
// Get returns values at index as interface{}.
|
||||
func (c *ColumnFloatVector) Get(idx int) (interface{}, error) {
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return nil, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnFloatVector) AppendValue(i interface{}) error {
|
||||
v, ok := i.([]float32)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected []float32, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnFloatVector) Data() [][]float32 {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnFloatVector) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: c.name,
|
||||
}
|
||||
|
||||
data := make([]float32, 0, len(c.values)*c.dim)
|
||||
|
||||
for _, vector := range c.values {
|
||||
data = append(data, vector...)
|
||||
}
|
||||
|
||||
fd.Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(c.dim),
|
||||
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// NewColumnFloatVector auto generated constructor
|
||||
func NewColumnFloatVector(name string, dim int, values [][]float32) *ColumnFloatVector {
|
||||
return &ColumnFloatVector{
|
||||
name: name,
|
||||
dim: dim,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnFloat16Vector generated columns type for Float16Vector
|
||||
type ColumnFloat16Vector struct {
|
||||
ColumnBase
|
||||
name string
|
||||
dim int
|
||||
values [][]byte
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnFloat16Vector) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnFloat16Vector) Type() entity.FieldType {
|
||||
return entity.FieldTypeFloat16Vector
|
||||
}
|
||||
|
||||
// Len returns column data length
|
||||
func (c *ColumnFloat16Vector) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Dim returns vector dimension
|
||||
func (c *ColumnFloat16Vector) Dim() int {
|
||||
return c.dim
|
||||
}
|
||||
|
||||
// Get returns values at index as interface{}.
|
||||
func (c *ColumnFloat16Vector) Get(idx int) (interface{}, error) {
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return nil, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnFloat16Vector) AppendValue(i interface{}) error {
|
||||
v, ok := i.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected []byte, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnFloat16Vector) Data() [][]byte {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnFloat16Vector) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Float16Vector,
|
||||
FieldName: c.name,
|
||||
}
|
||||
|
||||
data := make([]byte, 0, len(c.values)*c.dim)
|
||||
|
||||
for _, vector := range c.values {
|
||||
data = append(data, vector...)
|
||||
}
|
||||
|
||||
fd.Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(c.dim),
|
||||
|
||||
Data: &schemapb.VectorField_Float16Vector{
|
||||
Float16Vector: data,
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// NewColumnFloat16Vector auto generated constructor
|
||||
func NewColumnFloat16Vector(name string, dim int, values [][]byte) *ColumnFloat16Vector {
|
||||
return &ColumnFloat16Vector{
|
||||
name: name,
|
||||
dim: dim,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnBFloat16Vector generated columns type for BFloat16Vector
|
||||
type ColumnBFloat16Vector struct {
|
||||
ColumnBase
|
||||
name string
|
||||
dim int
|
||||
values [][]byte
|
||||
}
|
||||
|
||||
// Name returns column name
|
||||
func (c *ColumnBFloat16Vector) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
// Type returns column entity.FieldType
|
||||
func (c *ColumnBFloat16Vector) Type() entity.FieldType {
|
||||
return entity.FieldTypeBFloat16Vector
|
||||
}
|
||||
|
||||
// Len returns column data length
|
||||
func (c *ColumnBFloat16Vector) Len() int {
|
||||
return len(c.values)
|
||||
}
|
||||
|
||||
// Dim returns vector dimension
|
||||
func (c *ColumnBFloat16Vector) Dim() int {
|
||||
return c.dim
|
||||
}
|
||||
|
||||
// Get returns values at index as interface{}.
|
||||
func (c *ColumnBFloat16Vector) Get(idx int) (interface{}, error) {
|
||||
if idx < 0 || idx >= c.Len() {
|
||||
return nil, errors.New("index out of range")
|
||||
}
|
||||
return c.values[idx], nil
|
||||
}
|
||||
|
||||
// AppendValue append value into column
|
||||
func (c *ColumnBFloat16Vector) AppendValue(i interface{}) error {
|
||||
v, ok := i.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid type, expected []byte, got %T", i)
|
||||
}
|
||||
c.values = append(c.values, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Data returns column data
|
||||
func (c *ColumnBFloat16Vector) Data() [][]byte {
|
||||
return c.values
|
||||
}
|
||||
|
||||
// FieldData return column data mapped to schemapb.FieldData
|
||||
func (c *ColumnBFloat16Vector) FieldData() *schemapb.FieldData {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_BFloat16Vector,
|
||||
FieldName: c.name,
|
||||
}
|
||||
|
||||
data := make([]byte, 0, len(c.values)*c.dim)
|
||||
|
||||
for _, vector := range c.values {
|
||||
data = append(data, vector...)
|
||||
}
|
||||
|
||||
fd.Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(c.dim),
|
||||
|
||||
Data: &schemapb.VectorField_Bfloat16Vector{
|
||||
Bfloat16Vector: data,
|
||||
},
|
||||
},
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
// NewColumnBFloat16Vector auto generated constructor
|
||||
func NewColumnBFloat16Vector(name string, dim int, values [][]byte) *ColumnBFloat16Vector {
|
||||
return &ColumnBFloat16Vector{
|
||||
name: name,
|
||||
dim: dim,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
264
client/column/vector_gen_test.go
Normal file
264
client/column/vector_gen_test.go
Normal file
@ -0,0 +1,264 @@
|
||||
// Code generated by go generate; DO NOT EDIT
|
||||
// This file is generated by go generated
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestColumnBinaryVector(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
columnName := fmt.Sprintf("column_BinaryVector_%d", rand.Int())
|
||||
columnLen := 12 + rand.Intn(10)
|
||||
dim := ([]int{64, 128, 256, 512})[rand.Intn(4)]
|
||||
|
||||
v := make([][]byte, 0, columnLen)
|
||||
dlen := dim
|
||||
dlen /= 8
|
||||
|
||||
for i := 0; i < columnLen; i++ {
|
||||
entry := make([]byte, dlen)
|
||||
v = append(v, entry)
|
||||
}
|
||||
column := NewColumnBinaryVector(columnName, dim, v)
|
||||
|
||||
t.Run("test meta", func(t *testing.T) {
|
||||
ft := entity.FieldTypeBinaryVector
|
||||
assert.Equal(t, "BinaryVector", ft.Name())
|
||||
assert.Equal(t, "[]byte", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
assert.Equal(t, "[]byte", pbName)
|
||||
assert.Equal(t, "", pbType)
|
||||
})
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeBinaryVector, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.Equal(t, dim, column.Dim())
|
||||
assert.Equal(t, v, column.Data())
|
||||
|
||||
var ev []byte
|
||||
err := column.AppendValue(ev)
|
||||
assert.Equal(t, columnLen+1, column.Len())
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = column.AppendValue(struct{}{})
|
||||
assert.Equal(t, columnLen+1, column.Len())
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
|
||||
c, err := FieldDataVector(fd)
|
||||
assert.NotNil(t, c)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test column field data error", func(t *testing.T) {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_BinaryVector,
|
||||
FieldName: columnName,
|
||||
}
|
||||
_, err := FieldDataVector(fd)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestColumnFloatVector(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
columnName := fmt.Sprintf("column_FloatVector_%d", rand.Int())
|
||||
columnLen := 12 + rand.Intn(10)
|
||||
dim := ([]int{64, 128, 256, 512})[rand.Intn(4)]
|
||||
|
||||
v := make([][]float32, 0, columnLen)
|
||||
dlen := dim
|
||||
|
||||
for i := 0; i < columnLen; i++ {
|
||||
entry := make([]float32, dlen)
|
||||
v = append(v, entry)
|
||||
}
|
||||
column := NewColumnFloatVector(columnName, dim, v)
|
||||
|
||||
t.Run("test meta", func(t *testing.T) {
|
||||
ft := entity.FieldTypeFloatVector
|
||||
assert.Equal(t, "FloatVector", ft.Name())
|
||||
assert.Equal(t, "[]float32", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
assert.Equal(t, "[]float32", pbName)
|
||||
assert.Equal(t, "", pbType)
|
||||
})
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeFloatVector, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.Equal(t, dim, column.Dim())
|
||||
assert.Equal(t, v, column.Data())
|
||||
|
||||
var ev []float32
|
||||
err := column.AppendValue(ev)
|
||||
assert.Equal(t, columnLen+1, column.Len())
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = column.AppendValue(struct{}{})
|
||||
assert.Equal(t, columnLen+1, column.Len())
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
|
||||
c, err := FieldDataVector(fd)
|
||||
assert.NotNil(t, c)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test column field data error", func(t *testing.T) {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: columnName,
|
||||
}
|
||||
_, err := FieldDataVector(fd)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestColumnFloat16Vector(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
columnName := fmt.Sprintf("column_Float16Vector_%d", rand.Int())
|
||||
columnLen := 12 + rand.Intn(10)
|
||||
dim := ([]int{64, 128, 256, 512})[rand.Intn(4)]
|
||||
|
||||
v := make([][]byte, 0, columnLen)
|
||||
dlen := dim
|
||||
|
||||
dlen *= 2
|
||||
|
||||
for i := 0; i < columnLen; i++ {
|
||||
entry := make([]byte, dlen)
|
||||
v = append(v, entry)
|
||||
}
|
||||
column := NewColumnFloat16Vector(columnName, dim, v)
|
||||
|
||||
t.Run("test meta", func(t *testing.T) {
|
||||
ft := entity.FieldTypeFloat16Vector
|
||||
assert.Equal(t, "Float16Vector", ft.Name())
|
||||
assert.Equal(t, "[]byte", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
assert.Equal(t, "[]byte", pbName)
|
||||
assert.Equal(t, "", pbType)
|
||||
})
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeFloat16Vector, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.Equal(t, dim, column.Dim())
|
||||
assert.Equal(t, v, column.Data())
|
||||
|
||||
var ev []byte
|
||||
err := column.AppendValue(ev)
|
||||
assert.Equal(t, columnLen+1, column.Len())
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = column.AppendValue(struct{}{})
|
||||
assert.Equal(t, columnLen+1, column.Len())
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
|
||||
c, err := FieldDataVector(fd)
|
||||
assert.NotNil(t, c)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test column field data error", func(t *testing.T) {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Float16Vector,
|
||||
FieldName: columnName,
|
||||
}
|
||||
_, err := FieldDataVector(fd)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestColumnBFloat16Vector(t *testing.T) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
columnName := fmt.Sprintf("column_BFloat16Vector_%d", rand.Int())
|
||||
columnLen := 12 + rand.Intn(10)
|
||||
dim := ([]int{64, 128, 256, 512})[rand.Intn(4)]
|
||||
|
||||
v := make([][]byte, 0, columnLen)
|
||||
dlen := dim
|
||||
|
||||
dlen *= 2
|
||||
|
||||
for i := 0; i < columnLen; i++ {
|
||||
entry := make([]byte, dlen)
|
||||
v = append(v, entry)
|
||||
}
|
||||
column := NewColumnBFloat16Vector(columnName, dim, v)
|
||||
|
||||
t.Run("test meta", func(t *testing.T) {
|
||||
ft := entity.FieldTypeBFloat16Vector
|
||||
assert.Equal(t, "BFloat16Vector", ft.Name())
|
||||
assert.Equal(t, "[]byte", ft.String())
|
||||
pbName, pbType := ft.PbFieldType()
|
||||
assert.Equal(t, "[]byte", pbName)
|
||||
assert.Equal(t, "", pbType)
|
||||
})
|
||||
|
||||
t.Run("test column attribute", func(t *testing.T) {
|
||||
assert.Equal(t, columnName, column.Name())
|
||||
assert.Equal(t, entity.FieldTypeBFloat16Vector, column.Type())
|
||||
assert.Equal(t, columnLen, column.Len())
|
||||
assert.Equal(t, dim, column.Dim())
|
||||
assert.Equal(t, v, column.Data())
|
||||
|
||||
var ev []byte
|
||||
err := column.AppendValue(ev)
|
||||
assert.Equal(t, columnLen+1, column.Len())
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = column.AppendValue(struct{}{})
|
||||
assert.Equal(t, columnLen+1, column.Len())
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("test column field data", func(t *testing.T) {
|
||||
fd := column.FieldData()
|
||||
assert.NotNil(t, fd)
|
||||
assert.Equal(t, fd.GetFieldName(), columnName)
|
||||
|
||||
c, err := FieldDataVector(fd)
|
||||
assert.NotNil(t, c)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("test column field data error", func(t *testing.T) {
|
||||
fd := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_BFloat16Vector,
|
||||
FieldName: columnName,
|
||||
}
|
||||
_, err := FieldDataVector(fd)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
44
client/common.go
Normal file
44
client/common.go
Normal file
@ -0,0 +1,44 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/pkg/util/conc"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
// CollectionCache stores the cached collection schema information.
|
||||
type CollectionCache struct {
|
||||
sf conc.Singleflight[*entity.Collection]
|
||||
collections *typeutil.ConcurrentMap[string, *entity.Collection]
|
||||
fetcher func(context.Context, string) (*entity.Collection, error)
|
||||
}
|
||||
|
||||
func (c *CollectionCache) GetCollection(ctx context.Context, collName string) (*entity.Collection, error) {
|
||||
coll, ok := c.collections.Get(collName)
|
||||
if ok {
|
||||
return coll, nil
|
||||
}
|
||||
|
||||
coll, err, _ := c.sf.Do(collName, func() (*entity.Collection, error) {
|
||||
coll, err := c.fetcher(ctx, collName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.collections.Insert(collName, coll)
|
||||
return coll, nil
|
||||
})
|
||||
return coll, err
|
||||
}
|
||||
|
||||
func NewCollectionCache(fetcher func(context.Context, string) (*entity.Collection, error)) *CollectionCache {
|
||||
return &CollectionCache{
|
||||
collections: typeutil.NewConcurrentMap[string, *entity.Collection](),
|
||||
fetcher: fetcher,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) getCollection(ctx context.Context, collName string) (*entity.Collection, error) {
|
||||
return c.collCache.GetCollection(ctx, collName)
|
||||
}
|
||||
22
client/common/version.go
Normal file
22
client/common/version.go
Normal file
@ -0,0 +1,22 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package common
|
||||
|
||||
const (
|
||||
// SDKVersion const value for current version
|
||||
SDKVersion = `2.4.0-dev`
|
||||
)
|
||||
29
client/common/version_test.go
Normal file
29
client/common/version_test.go
Normal file
@ -0,0 +1,29 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/blang/semver/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestVersion(t *testing.T) {
|
||||
_, err := semver.Parse(SDKVersion)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
60
client/database.go
Normal file
60
client/database.go
Normal file
@ -0,0 +1,60 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
func (c *Client) ListDatabase(ctx context.Context, option ListDatabaseOption, callOptions ...grpc.CallOption) (databaseNames []string, err error) {
|
||||
req := option.Request()
|
||||
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.ListDatabases(ctx, req, callOptions...)
|
||||
err = merr.CheckRPCCall(resp, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
databaseNames = resp.GetDbNames()
|
||||
return nil
|
||||
})
|
||||
|
||||
return databaseNames, err
|
||||
}
|
||||
|
||||
func (c *Client) CreateDatabase(ctx context.Context, option CreateDatabaseOption, callOptions ...grpc.CallOption) error {
|
||||
req := option.Request()
|
||||
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.CreateDatabase(ctx, req, callOptions...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) DropDatabase(ctx context.Context, option DropDatabaseOption, callOptions ...grpc.CallOption) error {
|
||||
req := option.Request()
|
||||
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.DropDatabase(ctx, req, callOptions...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
}
|
||||
74
client/database_options.go
Normal file
74
client/database_options.go
Normal file
@ -0,0 +1,74 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
|
||||
// ListDatabaseOption is a builder interface for ListDatabase request.
|
||||
type ListDatabaseOption interface {
|
||||
Request() *milvuspb.ListDatabasesRequest
|
||||
}
|
||||
|
||||
type listDatabaseOption struct{}
|
||||
|
||||
func (opt *listDatabaseOption) Request() *milvuspb.ListDatabasesRequest {
|
||||
return &milvuspb.ListDatabasesRequest{}
|
||||
}
|
||||
|
||||
func NewListDatabaseOption() *listDatabaseOption {
|
||||
return &listDatabaseOption{}
|
||||
}
|
||||
|
||||
type CreateDatabaseOption interface {
|
||||
Request() *milvuspb.CreateDatabaseRequest
|
||||
}
|
||||
|
||||
type createDatabaseOption struct {
|
||||
dbName string
|
||||
}
|
||||
|
||||
func (opt *createDatabaseOption) Request() *milvuspb.CreateDatabaseRequest {
|
||||
return &milvuspb.CreateDatabaseRequest{
|
||||
DbName: opt.dbName,
|
||||
}
|
||||
}
|
||||
|
||||
func NewCreateDatabaseOption(dbName string) *createDatabaseOption {
|
||||
return &createDatabaseOption{
|
||||
dbName: dbName,
|
||||
}
|
||||
}
|
||||
|
||||
type DropDatabaseOption interface {
|
||||
Request() *milvuspb.DropDatabaseRequest
|
||||
}
|
||||
|
||||
type dropDatabaseOption struct {
|
||||
dbName string
|
||||
}
|
||||
|
||||
func (opt *dropDatabaseOption) Request() *milvuspb.DropDatabaseRequest {
|
||||
return &milvuspb.DropDatabaseRequest{
|
||||
DbName: opt.dbName,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDropDatabaseOption(dbName string) *dropDatabaseOption {
|
||||
return &dropDatabaseOption{
|
||||
dbName: dbName,
|
||||
}
|
||||
}
|
||||
92
client/database_test.go
Normal file
92
client/database_test.go
Normal file
@ -0,0 +1,92 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type DatabaseSuite struct {
|
||||
MockSuiteBase
|
||||
}
|
||||
|
||||
func (s *DatabaseSuite) TestListDatabases() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
s.mock.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{
|
||||
Status: merr.Success(),
|
||||
DbNames: []string{"default", "db1"},
|
||||
}, nil).Once()
|
||||
|
||||
names, err := s.client.ListDatabase(ctx, NewListDatabaseOption())
|
||||
s.NoError(err)
|
||||
s.ElementsMatch([]string{"default", "db1"}, names)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.ListDatabase(ctx, NewListDatabaseOption())
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *DatabaseSuite) TestCreateDatabase() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
dbName := fmt.Sprintf("dt_%s", s.randString(6))
|
||||
s.mock.EXPECT().CreateDatabase(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cdr *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) {
|
||||
s.Equal(dbName, cdr.GetDbName())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.CreateDatabase(ctx, NewCreateDatabaseOption(dbName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
dbName := fmt.Sprintf("dt_%s", s.randString(6))
|
||||
s.mock.EXPECT().CreateDatabase(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.CreateDatabase(ctx, NewCreateDatabaseOption(dbName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *DatabaseSuite) TestDropDatabase() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
dbName := fmt.Sprintf("dt_%s", s.randString(6))
|
||||
s.mock.EXPECT().DropDatabase(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ddr *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) {
|
||||
s.Equal(dbName, ddr.GetDbName())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.DropDatabase(ctx, NewDropDatabaseOption(dbName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
dbName := fmt.Sprintf("dt_%s", s.randString(6))
|
||||
s.mock.EXPECT().DropDatabase(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.DropDatabase(ctx, NewDropDatabaseOption(dbName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDatabase(t *testing.T) {
|
||||
suite.Run(t, new(DatabaseSuite))
|
||||
}
|
||||
18
client/doc.go
Normal file
18
client/doc.go
Normal file
@ -0,0 +1,18 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package milvusclient implements the official Go Milvus client for v2.
|
||||
package client
|
||||
56
client/entity/collection.go
Normal file
56
client/entity/collection.go
Normal file
@ -0,0 +1,56 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package entity
|
||||
|
||||
// DefaultShardNumber const value for using Milvus default shard number.
|
||||
const DefaultShardNumber int32 = 0
|
||||
|
||||
// DefaultConsistencyLevel const value for using Milvus default consistency level setting.
|
||||
const DefaultConsistencyLevel ConsistencyLevel = ClBounded
|
||||
|
||||
// Collection represent collection meta in Milvus
|
||||
type Collection struct {
|
||||
ID int64 // collection id
|
||||
Name string // collection name
|
||||
Schema *Schema // collection schema, with fields schema and primary key definition
|
||||
PhysicalChannels []string
|
||||
VirtualChannels []string
|
||||
Loaded bool
|
||||
ConsistencyLevel ConsistencyLevel
|
||||
ShardNum int32
|
||||
}
|
||||
|
||||
// Partition represent partition meta in Milvus
|
||||
type Partition struct {
|
||||
ID int64 // partition id
|
||||
Name string // partition name
|
||||
Loaded bool // partition loaded
|
||||
}
|
||||
|
||||
// ReplicaGroup represents a replica group
|
||||
type ReplicaGroup struct {
|
||||
ReplicaID int64
|
||||
NodeIDs []int64
|
||||
ShardReplicas []*ShardReplica
|
||||
}
|
||||
|
||||
// ShardReplica represents a shard in the ReplicaGroup
|
||||
type ShardReplica struct {
|
||||
LeaderID int64
|
||||
NodesIDs []int64
|
||||
DmChannelName string
|
||||
}
|
||||
96
client/entity/collection_attr.go
Normal file
96
client/entity/collection_attr.go
Normal file
@ -0,0 +1,96 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
// cakTTL const for collection attribute key TTL in seconds.
|
||||
cakTTL = `collection.ttl.seconds`
|
||||
// cakAutoCompaction const for collection attribute key autom compaction enabled.
|
||||
cakAutoCompaction = `collection.autocompaction.enabled`
|
||||
)
|
||||
|
||||
// CollectionAttribute is the interface for altering collection attributes.
|
||||
type CollectionAttribute interface {
|
||||
KeyValue() (string, string)
|
||||
Valid() error
|
||||
}
|
||||
|
||||
type collAttrBase struct {
|
||||
key string
|
||||
value string
|
||||
}
|
||||
|
||||
// KeyValue implements CollectionAttribute.
|
||||
func (ca collAttrBase) KeyValue() (string, string) {
|
||||
return ca.key, ca.value
|
||||
}
|
||||
|
||||
type ttlCollAttr struct {
|
||||
collAttrBase
|
||||
}
|
||||
|
||||
// Valid implements CollectionAttribute.
|
||||
// checks ttl seconds is valid positive integer.
|
||||
func (ca collAttrBase) Valid() error {
|
||||
val, err := strconv.ParseInt(ca.value, 10, 64)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "ttl is not a valid positive integer")
|
||||
}
|
||||
|
||||
if val < 0 {
|
||||
return errors.New("ttl needs to be a positive integer")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CollectionTTL returns collection attribute to set collection ttl in seconds.
|
||||
func CollectionTTL(ttl int64) ttlCollAttr {
|
||||
ca := ttlCollAttr{}
|
||||
ca.key = cakTTL
|
||||
ca.value = strconv.FormatInt(ttl, 10)
|
||||
return ca
|
||||
}
|
||||
|
||||
type autoCompactionCollAttr struct {
|
||||
collAttrBase
|
||||
}
|
||||
|
||||
// Valid implements CollectionAttribute.
|
||||
// checks collection auto compaction is valid bool.
|
||||
func (ca autoCompactionCollAttr) Valid() error {
|
||||
_, err := strconv.ParseBool(ca.value)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "auto compaction setting is not valid boolean")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CollectionAutoCompactionEnabled returns collection attribute to set collection auto compaction enabled.
|
||||
func CollectionAutoCompactionEnabled(enabled bool) autoCompactionCollAttr {
|
||||
ca := autoCompactionCollAttr{}
|
||||
ca.key = cakAutoCompaction
|
||||
ca.value = strconv.FormatBool(enabled)
|
||||
return ca
|
||||
}
|
||||
136
client/entity/collection_attr_test.go
Normal file
136
client/entity/collection_attr_test.go
Normal file
@ -0,0 +1,136 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type CollectionTTLSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *CollectionTTLSuite) TestValid() {
|
||||
type testCase struct {
|
||||
input string
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{input: "a", expectErr: true},
|
||||
{input: "1000", expectErr: false},
|
||||
{input: "0", expectErr: false},
|
||||
{input: "-10", expectErr: true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
s.Run(tc.input, func() {
|
||||
ca := ttlCollAttr{}
|
||||
ca.value = tc.input
|
||||
err := ca.Valid()
|
||||
if tc.expectErr {
|
||||
s.Error(err)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CollectionTTLSuite) TestCollectionTTL() {
|
||||
type testCase struct {
|
||||
input int64
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{input: 1000, expectErr: false},
|
||||
{input: 0, expectErr: false},
|
||||
{input: -10, expectErr: true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
s.Run(fmt.Sprintf("%d", tc.input), func() {
|
||||
ca := CollectionTTL(tc.input)
|
||||
key, value := ca.KeyValue()
|
||||
s.Equal(cakTTL, key)
|
||||
s.Equal(strconv.FormatInt(tc.input, 10), value)
|
||||
err := ca.Valid()
|
||||
if tc.expectErr {
|
||||
s.Error(err)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectionTTL(t *testing.T) {
|
||||
suite.Run(t, new(CollectionTTLSuite))
|
||||
}
|
||||
|
||||
type CollectionAutoCompactionSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *CollectionAutoCompactionSuite) TestValid() {
|
||||
type testCase struct {
|
||||
input string
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{input: "a", expectErr: true},
|
||||
{input: "true", expectErr: false},
|
||||
{input: "false", expectErr: false},
|
||||
{input: "", expectErr: true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
s.Run(tc.input, func() {
|
||||
ca := autoCompactionCollAttr{}
|
||||
ca.value = tc.input
|
||||
err := ca.Valid()
|
||||
if tc.expectErr {
|
||||
s.Error(err)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CollectionAutoCompactionSuite) TestCollectionAutoCompactionEnabled() {
|
||||
cases := []bool{true, false}
|
||||
|
||||
for _, tc := range cases {
|
||||
s.Run(fmt.Sprintf("%v", tc), func() {
|
||||
ca := CollectionAutoCompactionEnabled(tc)
|
||||
key, value := ca.KeyValue()
|
||||
s.Equal(cakAutoCompaction, key)
|
||||
s.Equal(strconv.FormatBool(tc), value)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollectionAutoCompaction(t *testing.T) {
|
||||
suite.Run(t, new(CollectionAutoCompactionSuite))
|
||||
}
|
||||
32
client/entity/common.go
Normal file
32
client/entity/common.go
Normal file
@ -0,0 +1,32 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package entity
|
||||
|
||||
// MetricType metric type
|
||||
type MetricType string
|
||||
|
||||
// Metric Constants
|
||||
const (
|
||||
L2 MetricType = "L2"
|
||||
IP MetricType = "IP"
|
||||
COSINE MetricType = "COSINE"
|
||||
HAMMING MetricType = "HAMMING"
|
||||
JACCARD MetricType = "JACCARD"
|
||||
TANIMOTO MetricType = "TANIMOTO"
|
||||
SUBSTRUCTURE MetricType = "SUBSTRUCTURE"
|
||||
SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE"
|
||||
)
|
||||
171
client/entity/field_type.go
Normal file
171
client/entity/field_type.go
Normal file
@ -0,0 +1,171 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package entity
|
||||
|
||||
// FieldType field data type alias type
|
||||
// used in go:generate trick, DO NOT modify names & string
|
||||
type FieldType int32
|
||||
|
||||
// Name returns field type name
|
||||
func (t FieldType) Name() string {
|
||||
switch t {
|
||||
case FieldTypeBool:
|
||||
return "Bool"
|
||||
case FieldTypeInt8:
|
||||
return "Int8"
|
||||
case FieldTypeInt16:
|
||||
return "Int16"
|
||||
case FieldTypeInt32:
|
||||
return "Int32"
|
||||
case FieldTypeInt64:
|
||||
return "Int64"
|
||||
case FieldTypeFloat:
|
||||
return "Float"
|
||||
case FieldTypeDouble:
|
||||
return "Double"
|
||||
case FieldTypeString:
|
||||
return "String"
|
||||
case FieldTypeVarChar:
|
||||
return "VarChar"
|
||||
case FieldTypeArray:
|
||||
return "Array"
|
||||
case FieldTypeJSON:
|
||||
return "JSON"
|
||||
case FieldTypeBinaryVector:
|
||||
return "BinaryVector"
|
||||
case FieldTypeFloatVector:
|
||||
return "FloatVector"
|
||||
case FieldTypeFloat16Vector:
|
||||
return "Float16Vector"
|
||||
case FieldTypeBFloat16Vector:
|
||||
return "BFloat16Vector"
|
||||
default:
|
||||
return "undefined"
|
||||
}
|
||||
}
|
||||
|
||||
// String returns field type
|
||||
func (t FieldType) String() string {
|
||||
switch t {
|
||||
case FieldTypeBool:
|
||||
return "bool"
|
||||
case FieldTypeInt8:
|
||||
return "int8"
|
||||
case FieldTypeInt16:
|
||||
return "int16"
|
||||
case FieldTypeInt32:
|
||||
return "int32"
|
||||
case FieldTypeInt64:
|
||||
return "int64"
|
||||
case FieldTypeFloat:
|
||||
return "float32"
|
||||
case FieldTypeDouble:
|
||||
return "float64"
|
||||
case FieldTypeString:
|
||||
return "string"
|
||||
case FieldTypeVarChar:
|
||||
return "string"
|
||||
case FieldTypeArray:
|
||||
return "Array"
|
||||
case FieldTypeJSON:
|
||||
return "JSON"
|
||||
case FieldTypeBinaryVector:
|
||||
return "[]byte"
|
||||
case FieldTypeFloatVector:
|
||||
return "[]float32"
|
||||
case FieldTypeFloat16Vector:
|
||||
return "[]byte"
|
||||
case FieldTypeBFloat16Vector:
|
||||
return "[]byte"
|
||||
default:
|
||||
return "undefined"
|
||||
}
|
||||
}
|
||||
|
||||
// PbFieldType represents FieldType corresponding schema pb type
|
||||
func (t FieldType) PbFieldType() (string, string) {
|
||||
switch t {
|
||||
case FieldTypeBool:
|
||||
return "Bool", "bool"
|
||||
case FieldTypeInt8:
|
||||
fallthrough
|
||||
case FieldTypeInt16:
|
||||
fallthrough
|
||||
case FieldTypeInt32:
|
||||
return "Int", "int32"
|
||||
case FieldTypeInt64:
|
||||
return "Long", "int64"
|
||||
case FieldTypeFloat:
|
||||
return "Float", "float32"
|
||||
case FieldTypeDouble:
|
||||
return "Double", "float64"
|
||||
case FieldTypeString:
|
||||
return "String", "string"
|
||||
case FieldTypeVarChar:
|
||||
return "VarChar", "string"
|
||||
case FieldTypeJSON:
|
||||
return "JSON", "JSON"
|
||||
case FieldTypeBinaryVector:
|
||||
return "[]byte", ""
|
||||
case FieldTypeFloatVector:
|
||||
return "[]float32", ""
|
||||
case FieldTypeFloat16Vector:
|
||||
return "[]byte", ""
|
||||
case FieldTypeBFloat16Vector:
|
||||
return "[]byte", ""
|
||||
default:
|
||||
return "undefined", ""
|
||||
}
|
||||
}
|
||||
|
||||
// Match schema definition
|
||||
const (
|
||||
// FieldTypeNone zero value place holder
|
||||
FieldTypeNone FieldType = 0 // zero value place holder
|
||||
// FieldTypeBool field type boolean
|
||||
FieldTypeBool FieldType = 1
|
||||
// FieldTypeInt8 field type int8
|
||||
FieldTypeInt8 FieldType = 2
|
||||
// FieldTypeInt16 field type int16
|
||||
FieldTypeInt16 FieldType = 3
|
||||
// FieldTypeInt32 field type int32
|
||||
FieldTypeInt32 FieldType = 4
|
||||
// FieldTypeInt64 field type int64
|
||||
FieldTypeInt64 FieldType = 5
|
||||
// FieldTypeFloat field type float
|
||||
FieldTypeFloat FieldType = 10
|
||||
// FieldTypeDouble field type double
|
||||
FieldTypeDouble FieldType = 11
|
||||
// FieldTypeString field type string
|
||||
FieldTypeString FieldType = 20
|
||||
// FieldTypeVarChar field type varchar
|
||||
FieldTypeVarChar FieldType = 21 // variable-length strings with a specified maximum length
|
||||
// FieldTypeArray field type Array
|
||||
FieldTypeArray FieldType = 22
|
||||
// FieldTypeJSON field type JSON
|
||||
FieldTypeJSON FieldType = 23
|
||||
// FieldTypeBinaryVector field type binary vector
|
||||
FieldTypeBinaryVector FieldType = 100
|
||||
// FieldTypeFloatVector field type float vector
|
||||
FieldTypeFloatVector FieldType = 101
|
||||
// FieldTypeBinaryVector field type float16 vector
|
||||
FieldTypeFloat16Vector FieldType = 102
|
||||
// FieldTypeBinaryVector field type bf16 vector
|
||||
FieldTypeBFloat16Vector FieldType = 103
|
||||
// FieldTypeBinaryVector field type sparse vector
|
||||
FieldTypeSparseVector FieldType = 104
|
||||
)
|
||||
341
client/entity/schema.go
Normal file
341
client/entity/schema.go
Normal file
@ -0,0 +1,341 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
const (
|
||||
// TypeParamDim is the const for field type param dimension
|
||||
TypeParamDim = "dim"
|
||||
|
||||
// TypeParamMaxLength is the const for varchar type maximal length
|
||||
TypeParamMaxLength = "max_length"
|
||||
|
||||
// TypeParamMaxCapacity is the const for array type max capacity
|
||||
TypeParamMaxCapacity = `max_capacity`
|
||||
|
||||
// ClStrong strong consistency level
|
||||
ClStrong ConsistencyLevel = ConsistencyLevel(commonpb.ConsistencyLevel_Strong)
|
||||
// ClBounded bounded consistency level with default tolerance of 5 seconds
|
||||
ClBounded ConsistencyLevel = ConsistencyLevel(commonpb.ConsistencyLevel_Bounded)
|
||||
// ClSession session consistency level
|
||||
ClSession ConsistencyLevel = ConsistencyLevel(commonpb.ConsistencyLevel_Session)
|
||||
// ClEvenually eventually consistency level
|
||||
ClEventually ConsistencyLevel = ConsistencyLevel(commonpb.ConsistencyLevel_Eventually)
|
||||
// ClCustomized customized consistency level and users pass their own `guarantee_timestamp`.
|
||||
ClCustomized ConsistencyLevel = ConsistencyLevel(commonpb.ConsistencyLevel_Customized)
|
||||
)
|
||||
|
||||
// ConsistencyLevel enum type for collection Consistency Level
|
||||
type ConsistencyLevel commonpb.ConsistencyLevel
|
||||
|
||||
// CommonConsistencyLevel returns corresponding commonpb.ConsistencyLevel
|
||||
func (cl ConsistencyLevel) CommonConsistencyLevel() commonpb.ConsistencyLevel {
|
||||
return commonpb.ConsistencyLevel(cl)
|
||||
}
|
||||
|
||||
// Schema represents schema info of collection in milvus
|
||||
type Schema struct {
|
||||
CollectionName string
|
||||
Description string
|
||||
AutoID bool
|
||||
Fields []*Field
|
||||
EnableDynamicField bool
|
||||
}
|
||||
|
||||
// NewSchema creates an empty schema object.
|
||||
func NewSchema() *Schema {
|
||||
return &Schema{}
|
||||
}
|
||||
|
||||
// WithName sets the name value of schema, returns schema itself.
|
||||
func (s *Schema) WithName(name string) *Schema {
|
||||
s.CollectionName = name
|
||||
return s
|
||||
}
|
||||
|
||||
// WithDescription sets the description value of schema, returns schema itself.
|
||||
func (s *Schema) WithDescription(desc string) *Schema {
|
||||
s.Description = desc
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Schema) WithAutoID(autoID bool) *Schema {
|
||||
s.AutoID = autoID
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Schema) WithDynamicFieldEnabled(dynamicEnabled bool) *Schema {
|
||||
s.EnableDynamicField = dynamicEnabled
|
||||
return s
|
||||
}
|
||||
|
||||
// WithField adds a field into schema and returns schema itself.
|
||||
func (s *Schema) WithField(f *Field) *Schema {
|
||||
s.Fields = append(s.Fields, f)
|
||||
return s
|
||||
}
|
||||
|
||||
// ProtoMessage returns corresponding server.CollectionSchema
|
||||
func (s *Schema) ProtoMessage() *schemapb.CollectionSchema {
|
||||
r := &schemapb.CollectionSchema{
|
||||
Name: s.CollectionName,
|
||||
Description: s.Description,
|
||||
AutoID: s.AutoID,
|
||||
EnableDynamicField: s.EnableDynamicField,
|
||||
}
|
||||
r.Fields = make([]*schemapb.FieldSchema, 0, len(s.Fields))
|
||||
for _, field := range s.Fields {
|
||||
r.Fields = append(r.Fields, field.ProtoMessage())
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ReadProto parses proto Collection Schema
|
||||
func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema {
|
||||
s.Description = p.GetDescription()
|
||||
s.CollectionName = p.GetName()
|
||||
s.Fields = make([]*Field, 0, len(p.GetFields()))
|
||||
for _, fp := range p.GetFields() {
|
||||
if fp.GetAutoID() {
|
||||
s.AutoID = true
|
||||
}
|
||||
s.Fields = append(s.Fields, NewField().ReadProto(fp))
|
||||
}
|
||||
s.EnableDynamicField = p.GetEnableDynamicField()
|
||||
return s
|
||||
}
|
||||
|
||||
// PKFieldName returns pk field name for this schemapb.
|
||||
func (s *Schema) PKFieldName() string {
|
||||
for _, field := range s.Fields {
|
||||
if field.PrimaryKey {
|
||||
return field.Name
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Field represent field schema in milvus
|
||||
type Field struct {
|
||||
ID int64 // field id, generated when collection is created, input value is ignored
|
||||
Name string // field name
|
||||
PrimaryKey bool // is primary key
|
||||
AutoID bool // is auto id
|
||||
Description string
|
||||
DataType FieldType
|
||||
TypeParams map[string]string
|
||||
IndexParams map[string]string
|
||||
IsDynamic bool
|
||||
IsPartitionKey bool
|
||||
ElementType FieldType
|
||||
}
|
||||
|
||||
// ProtoMessage generates corresponding FieldSchema
|
||||
func (f *Field) ProtoMessage() *schemapb.FieldSchema {
|
||||
return &schemapb.FieldSchema{
|
||||
FieldID: f.ID,
|
||||
Name: f.Name,
|
||||
Description: f.Description,
|
||||
IsPrimaryKey: f.PrimaryKey,
|
||||
AutoID: f.AutoID,
|
||||
DataType: schemapb.DataType(f.DataType),
|
||||
TypeParams: MapKvPairs(f.TypeParams),
|
||||
IndexParams: MapKvPairs(f.IndexParams),
|
||||
IsDynamic: f.IsDynamic,
|
||||
IsPartitionKey: f.IsPartitionKey,
|
||||
ElementType: schemapb.DataType(f.ElementType),
|
||||
}
|
||||
}
|
||||
|
||||
// NewField creates a new Field with map initialized.
|
||||
func NewField() *Field {
|
||||
return &Field{
|
||||
TypeParams: make(map[string]string),
|
||||
IndexParams: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Field) WithName(name string) *Field {
|
||||
f.Name = name
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithDescription(desc string) *Field {
|
||||
f.Description = desc
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithDataType(dataType FieldType) *Field {
|
||||
f.DataType = dataType
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithIsPrimaryKey(isPrimaryKey bool) *Field {
|
||||
f.PrimaryKey = isPrimaryKey
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithIsAutoID(isAutoID bool) *Field {
|
||||
f.AutoID = isAutoID
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithIsDynamic(isDynamic bool) *Field {
|
||||
f.IsDynamic = isDynamic
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithIsPartitionKey(isPartitionKey bool) *Field {
|
||||
f.IsPartitionKey = isPartitionKey
|
||||
return f
|
||||
}
|
||||
|
||||
/*
|
||||
func (f *Field) WithDefaultValueBool(defaultValue bool) *Field {
|
||||
f.DefaultValue = &schemapb.ValueField{
|
||||
Data: &schemapb.ValueField_BoolData{
|
||||
BoolData: defaultValue,
|
||||
},
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithDefaultValueInt(defaultValue int32) *Field {
|
||||
f.DefaultValue = &schemapb.ValueField{
|
||||
Data: &schemapb.ValueField_IntData{
|
||||
IntData: defaultValue,
|
||||
},
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithDefaultValueLong(defaultValue int64) *Field {
|
||||
f.DefaultValue = &schemapb.ValueField{
|
||||
Data: &schemapb.ValueField_LongData{
|
||||
LongData: defaultValue,
|
||||
},
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithDefaultValueFloat(defaultValue float32) *Field {
|
||||
f.DefaultValue = &schemapb.ValueField{
|
||||
Data: &schemapb.ValueField_FloatData{
|
||||
FloatData: defaultValue,
|
||||
},
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithDefaultValueDouble(defaultValue float64) *Field {
|
||||
f.DefaultValue = &schemapb.ValueField{
|
||||
Data: &schemapb.ValueField_DoubleData{
|
||||
DoubleData: defaultValue,
|
||||
},
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithDefaultValueString(defaultValue string) *Field {
|
||||
f.DefaultValue = &schemapb.ValueField{
|
||||
Data: &schemapb.ValueField_StringData{
|
||||
StringData: defaultValue,
|
||||
},
|
||||
}
|
||||
return f
|
||||
}*/
|
||||
|
||||
func (f *Field) WithTypeParams(key string, value string) *Field {
|
||||
if f.TypeParams == nil {
|
||||
f.TypeParams = make(map[string]string)
|
||||
}
|
||||
f.TypeParams[key] = value
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithDim(dim int64) *Field {
|
||||
if f.TypeParams == nil {
|
||||
f.TypeParams = make(map[string]string)
|
||||
}
|
||||
f.TypeParams[TypeParamDim] = strconv.FormatInt(dim, 10)
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithMaxLength(maxLen int64) *Field {
|
||||
if f.TypeParams == nil {
|
||||
f.TypeParams = make(map[string]string)
|
||||
}
|
||||
f.TypeParams[TypeParamMaxLength] = strconv.FormatInt(maxLen, 10)
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithElementType(eleType FieldType) *Field {
|
||||
f.ElementType = eleType
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *Field) WithMaxCapacity(maxCap int64) *Field {
|
||||
if f.TypeParams == nil {
|
||||
f.TypeParams = make(map[string]string)
|
||||
}
|
||||
f.TypeParams[TypeParamMaxCapacity] = strconv.FormatInt(maxCap, 10)
|
||||
return f
|
||||
}
|
||||
|
||||
// ReadProto parses FieldSchema
|
||||
func (f *Field) ReadProto(p *schemapb.FieldSchema) *Field {
|
||||
f.ID = p.GetFieldID()
|
||||
f.Name = p.GetName()
|
||||
f.PrimaryKey = p.GetIsPrimaryKey()
|
||||
f.AutoID = p.GetAutoID()
|
||||
f.Description = p.GetDescription()
|
||||
f.DataType = FieldType(p.GetDataType())
|
||||
f.TypeParams = KvPairsMap(p.GetTypeParams())
|
||||
f.IndexParams = KvPairsMap(p.GetIndexParams())
|
||||
f.IsDynamic = p.GetIsDynamic()
|
||||
f.IsPartitionKey = p.GetIsPartitionKey()
|
||||
f.ElementType = FieldType(p.GetElementType())
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
// MapKvPairs converts map into commonpb.KeyValuePair slice
|
||||
func MapKvPairs(m map[string]string) []*commonpb.KeyValuePair {
|
||||
pairs := make([]*commonpb.KeyValuePair, 0, len(m))
|
||||
for k, v := range m {
|
||||
pairs = append(pairs, &commonpb.KeyValuePair{
|
||||
Key: k,
|
||||
Value: v,
|
||||
})
|
||||
}
|
||||
return pairs
|
||||
}
|
||||
|
||||
// KvPairsMap converts commonpb.KeyValuePair slices into map
|
||||
func KvPairsMap(kvps []*commonpb.KeyValuePair) map[string]string {
|
||||
m := make(map[string]string)
|
||||
for _, kvp := range kvps {
|
||||
m[kvp.Key] = kvp.Value
|
||||
}
|
||||
return m
|
||||
}
|
||||
138
client/entity/schema_test.go
Normal file
138
client/entity/schema_test.go
Normal file
@ -0,0 +1,138 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
)
|
||||
|
||||
func TestCL_CommonCL(t *testing.T) {
|
||||
cls := []ConsistencyLevel{
|
||||
ClStrong,
|
||||
ClBounded,
|
||||
ClSession,
|
||||
ClEventually,
|
||||
}
|
||||
for _, cl := range cls {
|
||||
assert.EqualValues(t, commonpb.ConsistencyLevel(cl), cl.CommonConsistencyLevel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldSchema(t *testing.T) {
|
||||
fields := []*Field{
|
||||
NewField().WithName("int_field").WithDataType(FieldTypeInt64).WithIsAutoID(true).WithIsPrimaryKey(true).WithDescription("int_field desc"),
|
||||
NewField().WithName("string_field").WithDataType(FieldTypeString).WithIsAutoID(false).WithIsPrimaryKey(true).WithIsDynamic(false).WithTypeParams("max_len", "32").WithDescription("string_field desc"),
|
||||
NewField().WithName("partition_key").WithDataType(FieldTypeInt32).WithIsPartitionKey(true),
|
||||
NewField().WithName("array_field").WithDataType(FieldTypeArray).WithElementType(FieldTypeBool).WithMaxCapacity(128),
|
||||
/*
|
||||
NewField().WithName("default_value_bool").WithDataType(FieldTypeBool).WithDefaultValueBool(true),
|
||||
NewField().WithName("default_value_int").WithDataType(FieldTypeInt32).WithDefaultValueInt(1),
|
||||
NewField().WithName("default_value_long").WithDataType(FieldTypeInt64).WithDefaultValueLong(1),
|
||||
NewField().WithName("default_value_float").WithDataType(FieldTypeFloat).WithDefaultValueFloat(1),
|
||||
NewField().WithName("default_value_double").WithDataType(FieldTypeDouble).WithDefaultValueDouble(1),
|
||||
NewField().WithName("default_value_string").WithDataType(FieldTypeString).WithDefaultValueString("a"),*/
|
||||
}
|
||||
|
||||
for _, field := range fields {
|
||||
fieldSchema := field.ProtoMessage()
|
||||
assert.Equal(t, field.ID, fieldSchema.GetFieldID())
|
||||
assert.Equal(t, field.Name, fieldSchema.GetName())
|
||||
assert.EqualValues(t, field.DataType, fieldSchema.GetDataType())
|
||||
assert.Equal(t, field.AutoID, fieldSchema.GetAutoID())
|
||||
assert.Equal(t, field.PrimaryKey, fieldSchema.GetIsPrimaryKey())
|
||||
assert.Equal(t, field.IsPartitionKey, fieldSchema.GetIsPartitionKey())
|
||||
assert.Equal(t, field.IsDynamic, fieldSchema.GetIsDynamic())
|
||||
assert.Equal(t, field.Description, fieldSchema.GetDescription())
|
||||
assert.Equal(t, field.TypeParams, KvPairsMap(fieldSchema.GetTypeParams()))
|
||||
assert.EqualValues(t, field.ElementType, fieldSchema.GetElementType())
|
||||
// marshal & unmarshal, still equals
|
||||
nf := &Field{}
|
||||
nf = nf.ReadProto(fieldSchema)
|
||||
assert.Equal(t, field.ID, nf.ID)
|
||||
assert.Equal(t, field.Name, nf.Name)
|
||||
assert.EqualValues(t, field.DataType, nf.DataType)
|
||||
assert.Equal(t, field.AutoID, nf.AutoID)
|
||||
assert.Equal(t, field.PrimaryKey, nf.PrimaryKey)
|
||||
assert.Equal(t, field.Description, nf.Description)
|
||||
assert.Equal(t, field.IsDynamic, nf.IsDynamic)
|
||||
assert.Equal(t, field.IsPartitionKey, nf.IsPartitionKey)
|
||||
assert.EqualValues(t, field.TypeParams, nf.TypeParams)
|
||||
assert.EqualValues(t, field.ElementType, nf.ElementType)
|
||||
}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
(&Field{}).WithTypeParams("a", "b")
|
||||
})
|
||||
}
|
||||
|
||||
type SchemaSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *SchemaSuite) TestBasic() {
|
||||
cases := []struct {
|
||||
tag string
|
||||
input *Schema
|
||||
pkName string
|
||||
}{
|
||||
{
|
||||
"test_collection",
|
||||
NewSchema().WithName("test_collection_1").WithDescription("test_collection_1 desc").WithAutoID(false).
|
||||
WithField(NewField().WithName("ID").WithDataType(FieldTypeInt64).WithIsPrimaryKey(true)).
|
||||
WithField(NewField().WithName("vector").WithDataType(FieldTypeFloatVector).WithDim(128)),
|
||||
"ID",
|
||||
},
|
||||
{
|
||||
"dynamic_schema",
|
||||
NewSchema().WithName("dynamic_schema").WithDescription("dynamic_schema desc").WithAutoID(true).WithDynamicFieldEnabled(true).
|
||||
WithField(NewField().WithName("ID").WithDataType(FieldTypeVarChar).WithMaxLength(256)).
|
||||
WithField(NewField().WithName("$meta").WithIsDynamic(true)),
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
s.Run(c.tag, func() {
|
||||
sch := c.input
|
||||
p := sch.ProtoMessage()
|
||||
s.Equal(sch.CollectionName, p.GetName())
|
||||
s.Equal(sch.AutoID, p.GetAutoID())
|
||||
s.Equal(sch.Description, p.GetDescription())
|
||||
s.Equal(sch.EnableDynamicField, p.GetEnableDynamicField())
|
||||
s.Equal(len(sch.Fields), len(p.GetFields()))
|
||||
|
||||
nsch := &Schema{}
|
||||
nsch = nsch.ReadProto(p)
|
||||
|
||||
s.Equal(sch.CollectionName, nsch.CollectionName)
|
||||
s.Equal(sch.Description, nsch.Description)
|
||||
s.Equal(sch.EnableDynamicField, nsch.EnableDynamicField)
|
||||
s.Equal(len(sch.Fields), len(nsch.Fields))
|
||||
s.Equal(c.pkName, sch.PKFieldName())
|
||||
s.Equal(c.pkName, nsch.PKFieldName())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchema(t *testing.T) {
|
||||
suite.Run(t, new(SchemaSuite))
|
||||
}
|
||||
124
client/entity/sparse.go
Normal file
124
client/entity/sparse.go
Normal file
@ -0,0 +1,124 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math"
|
||||
"sort"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
)
|
||||
|
||||
type SparseEmbedding interface {
|
||||
Dim() int // the dimension
|
||||
Len() int // the actual items in this vector
|
||||
Get(idx int) (pos uint32, value float32, ok bool)
|
||||
Serialize() []byte
|
||||
}
|
||||
|
||||
var (
|
||||
_ SparseEmbedding = sliceSparseEmbedding{}
|
||||
_ Vector = sliceSparseEmbedding{}
|
||||
)
|
||||
|
||||
type sliceSparseEmbedding struct {
|
||||
positions []uint32
|
||||
values []float32
|
||||
dim int
|
||||
len int
|
||||
}
|
||||
|
||||
func (e sliceSparseEmbedding) Dim() int {
|
||||
return e.dim
|
||||
}
|
||||
|
||||
func (e sliceSparseEmbedding) Len() int {
|
||||
return e.len
|
||||
}
|
||||
|
||||
func (e sliceSparseEmbedding) FieldType() FieldType {
|
||||
return FieldTypeSparseVector
|
||||
}
|
||||
|
||||
func (e sliceSparseEmbedding) Get(idx int) (uint32, float32, bool) {
|
||||
if idx < 0 || idx >= int(e.len) {
|
||||
return 0, 0, false
|
||||
}
|
||||
return e.positions[idx], e.values[idx], true
|
||||
}
|
||||
|
||||
func (e sliceSparseEmbedding) Serialize() []byte {
|
||||
row := make([]byte, 8*e.Len())
|
||||
for idx := 0; idx < e.Len(); idx++ {
|
||||
pos, value, _ := e.Get(idx)
|
||||
binary.LittleEndian.PutUint32(row[idx*8:], pos)
|
||||
binary.LittleEndian.PutUint32(row[pos*8+4:], math.Float32bits(value))
|
||||
}
|
||||
return row
|
||||
}
|
||||
|
||||
// Less implements sort.Interce
|
||||
func (e sliceSparseEmbedding) Less(i, j int) bool {
|
||||
return e.positions[i] < e.positions[j]
|
||||
}
|
||||
|
||||
func (e sliceSparseEmbedding) Swap(i, j int) {
|
||||
e.positions[i], e.positions[j] = e.positions[j], e.positions[i]
|
||||
e.values[i], e.values[j] = e.values[j], e.values[i]
|
||||
}
|
||||
|
||||
func deserializeSliceSparceEmbedding(bs []byte) (sliceSparseEmbedding, error) {
|
||||
length := len(bs)
|
||||
if length%8 != 0 {
|
||||
return sliceSparseEmbedding{}, errors.New("not valid sparse embedding bytes")
|
||||
}
|
||||
|
||||
length = length / 8
|
||||
|
||||
result := sliceSparseEmbedding{
|
||||
positions: make([]uint32, length),
|
||||
values: make([]float32, length),
|
||||
len: length,
|
||||
}
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
result.positions[i] = binary.LittleEndian.Uint32(bs[i*8 : i*8+4])
|
||||
result.values[i] = math.Float32frombits(binary.LittleEndian.Uint32(bs[i*8+4 : i*8+8]))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func NewSliceSparseEmbedding(positions []uint32, values []float32) (SparseEmbedding, error) {
|
||||
if len(positions) != len(values) {
|
||||
return nil, errors.New("invalid sparse embedding input, positions shall have same number of values")
|
||||
}
|
||||
|
||||
se := sliceSparseEmbedding{
|
||||
positions: positions,
|
||||
values: values,
|
||||
len: len(positions),
|
||||
}
|
||||
|
||||
sort.Sort(se)
|
||||
|
||||
if se.len > 0 {
|
||||
se.dim = int(se.positions[se.len-1]) + 1
|
||||
}
|
||||
|
||||
return se, nil
|
||||
}
|
||||
68
client/entity/sparse_test.go
Normal file
68
client/entity/sparse_test.go
Normal file
@ -0,0 +1,68 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSliceSparseEmbedding(t *testing.T) {
|
||||
t.Run("normal_case", func(t *testing.T) {
|
||||
length := 1 + rand.Intn(5)
|
||||
positions := make([]uint32, length)
|
||||
values := make([]float32, length)
|
||||
for i := 0; i < length; i++ {
|
||||
positions[i] = uint32(i)
|
||||
values[i] = rand.Float32()
|
||||
}
|
||||
se, err := NewSliceSparseEmbedding(positions, values)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.EqualValues(t, length, se.Dim())
|
||||
assert.EqualValues(t, length, se.Len())
|
||||
|
||||
bs := se.Serialize()
|
||||
nv, err := deserializeSliceSparceEmbedding(bs)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
pos, val, ok := se.Get(i)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, positions[i], pos)
|
||||
assert.Equal(t, values[i], val)
|
||||
|
||||
npos, nval, ok := nv.Get(i)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, positions[i], npos)
|
||||
assert.Equal(t, values[i], nval)
|
||||
}
|
||||
|
||||
_, _, ok := se.Get(-1)
|
||||
assert.False(t, ok)
|
||||
_, _, ok = se.Get(length)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("position values not match", func(t *testing.T) {
|
||||
_, err := NewSliceSparseEmbedding([]uint32{1}, []float32{})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
106
client/entity/vectors.go
Normal file
106
client/entity/vectors.go
Normal file
@ -0,0 +1,106 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math"
|
||||
)
|
||||
|
||||
// Vector interface vector used int search
|
||||
type Vector interface {
|
||||
Dim() int
|
||||
Serialize() []byte
|
||||
FieldType() FieldType
|
||||
}
|
||||
|
||||
// FloatVector float32 vector wrapper.
|
||||
type FloatVector []float32
|
||||
|
||||
// Dim returns vector dimension.
|
||||
func (fv FloatVector) Dim() int {
|
||||
return len(fv)
|
||||
}
|
||||
|
||||
// entity.FieldType returns coresponding field type.
|
||||
func (fv FloatVector) FieldType() FieldType {
|
||||
return FieldTypeFloatVector
|
||||
}
|
||||
|
||||
// Serialize serializes vector into byte slice, used in search placeholder
|
||||
// LittleEndian is used for convention
|
||||
func (fv FloatVector) Serialize() []byte {
|
||||
data := make([]byte, 0, 4*len(fv)) // float32 occupies 4 bytes
|
||||
buf := make([]byte, 4)
|
||||
for _, f := range fv {
|
||||
binary.LittleEndian.PutUint32(buf, math.Float32bits(f))
|
||||
data = append(data, buf...)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// FloatVector float32 vector wrapper.
|
||||
type Float16Vector []byte
|
||||
|
||||
// Dim returns vector dimension.
|
||||
func (fv Float16Vector) Dim() int {
|
||||
return len(fv) / 2
|
||||
}
|
||||
|
||||
// entity.FieldType returns coresponding field type.
|
||||
func (fv Float16Vector) FieldType() FieldType {
|
||||
return FieldTypeFloat16Vector
|
||||
}
|
||||
|
||||
func (fv Float16Vector) Serialize() []byte {
|
||||
return fv
|
||||
}
|
||||
|
||||
// FloatVector float32 vector wrapper.
|
||||
type BFloat16Vector []byte
|
||||
|
||||
// Dim returns vector dimension.
|
||||
func (fv BFloat16Vector) Dim() int {
|
||||
return len(fv) / 2
|
||||
}
|
||||
|
||||
// entity.FieldType returns coresponding field type.
|
||||
func (fv BFloat16Vector) FieldType() FieldType {
|
||||
return FieldTypeBFloat16Vector
|
||||
}
|
||||
|
||||
func (fv BFloat16Vector) Serialize() []byte {
|
||||
return fv
|
||||
}
|
||||
|
||||
// BinaryVector []byte vector wrapper
|
||||
type BinaryVector []byte
|
||||
|
||||
// Dim return vector dimension, note that binary vector is bits count
|
||||
func (bv BinaryVector) Dim() int {
|
||||
return 8 * len(bv)
|
||||
}
|
||||
|
||||
// Serialize just return bytes
|
||||
func (bv BinaryVector) Serialize() []byte {
|
||||
return bv
|
||||
}
|
||||
|
||||
// entity.FieldType returns coresponding field type.
|
||||
func (bv BinaryVector) FieldType() FieldType {
|
||||
return FieldTypeBinaryVector
|
||||
}
|
||||
51
client/entity/vectors_test.go
Normal file
51
client/entity/vectors_test.go
Normal file
@ -0,0 +1,51 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package entity
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestVectors(t *testing.T) {
|
||||
dim := rand.Intn(127) + 1
|
||||
|
||||
t.Run("test float vector", func(t *testing.T) {
|
||||
raw := make([]float32, dim)
|
||||
for i := 0; i < dim; i++ {
|
||||
raw[i] = rand.Float32()
|
||||
}
|
||||
|
||||
fv := FloatVector(raw)
|
||||
|
||||
assert.Equal(t, dim, fv.Dim())
|
||||
assert.Equal(t, dim*4, len(fv.Serialize()))
|
||||
})
|
||||
|
||||
t.Run("test binary vector", func(t *testing.T) {
|
||||
raw := make([]byte, dim)
|
||||
_, err := rand.Read(raw)
|
||||
assert.Nil(t, err)
|
||||
|
||||
bv := BinaryVector(raw)
|
||||
|
||||
assert.Equal(t, dim*8, bv.Dim())
|
||||
assert.ElementsMatch(t, raw, bv.Serialize())
|
||||
})
|
||||
}
|
||||
125
client/go.mod
Normal file
125
client/go.mod
Normal file
@ -0,0 +1,125 @@
|
||||
module github.com/milvus-io/milvus/client/v2
|
||||
|
||||
go 1.21.8
|
||||
|
||||
require (
|
||||
github.com/blang/semver/v4 v4.0.0
|
||||
github.com/cockroachdb/errors v1.9.1
|
||||
github.com/gogo/status v1.1.0
|
||||
github.com/golang/protobuf v1.5.3
|
||||
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240228061649-a922b16f2a46
|
||||
github.com/milvus-io/milvus/pkg v0.0.2-0.20240317152703-17b4938985f3
|
||||
github.com/samber/lo v1.27.0
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/tidwall/gjson v1.17.1
|
||||
go.uber.org/atomic v1.10.0
|
||||
google.golang.org/grpc v1.54.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/benbjohnson/clock v1.1.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.2.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/cilium/ebpf v0.11.0 // indirect
|
||||
github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect
|
||||
github.com/cockroachdb/redact v1.1.3 // indirect
|
||||
github.com/containerd/cgroups/v3 v3.0.3 // indirect
|
||||
github.com/coreos/go-semver v0.3.0 // indirect
|
||||
github.com/coreos/go-systemd/v22 v22.3.2 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/docker/go-units v0.4.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.0 // indirect
|
||||
github.com/form3tech-oss/jwt-go v3.2.3+incompatible // indirect
|
||||
github.com/fsnotify/fsnotify v1.4.9 // indirect
|
||||
github.com/getsentry/sentry-go v0.12.0 // indirect
|
||||
github.com/go-logr/logr v1.3.0 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/godbus/dbus/v5 v5.0.4 // indirect
|
||||
github.com/gogo/googleapis v1.4.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/gorilla/websocket v1.4.2 // indirect
|
||||
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/jonboulle/clockwork v0.2.2 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/magiconair/properties v1.8.5 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
|
||||
github.com/mitchellh/mapstructure v1.4.1 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/opencontainers/runtime-spec v1.0.2 // indirect
|
||||
github.com/panjf2000/ants/v2 v2.7.2 // indirect
|
||||
github.com/pelletier/go-toml v1.9.3 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/prometheus/client_golang v1.14.0 // indirect
|
||||
github.com/prometheus/client_model v0.3.0 // indirect
|
||||
github.com/prometheus/common v0.42.0 // indirect
|
||||
github.com/prometheus/procfs v0.9.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.10.0 // indirect
|
||||
github.com/shirou/gopsutil/v3 v3.22.9 // indirect
|
||||
github.com/sirupsen/logrus v1.9.0 // indirect
|
||||
github.com/soheilhy/cmux v0.1.5 // indirect
|
||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||
github.com/spf13/afero v1.6.0 // indirect
|
||||
github.com/spf13/cast v1.3.1 // indirect
|
||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/spf13/viper v1.8.1 // indirect
|
||||
github.com/stretchr/objx v0.5.0 // indirect
|
||||
github.com/subosito/gotenv v1.2.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.10 // indirect
|
||||
github.com/tklauser/numcpus v0.4.0 // indirect
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect
|
||||
github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect
|
||||
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.2 // indirect
|
||||
go.etcd.io/bbolt v1.3.6 // indirect
|
||||
go.etcd.io/etcd/api/v3 v3.5.5 // indirect
|
||||
go.etcd.io/etcd/client/pkg/v3 v3.5.5 // indirect
|
||||
go.etcd.io/etcd/client/v2 v2.305.5 // indirect
|
||||
go.etcd.io/etcd/client/v3 v3.5.5 // indirect
|
||||
go.etcd.io/etcd/pkg/v3 v3.5.5 // indirect
|
||||
go.etcd.io/etcd/raft/v3 v3.5.5 // indirect
|
||||
go.etcd.io/etcd/server/v3 v3.5.5 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.38.0 // indirect
|
||||
go.opentelemetry.io/otel v1.13.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.13.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.13.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.13.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v0.35.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.13.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.13.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v0.19.0 // indirect
|
||||
go.uber.org/automaxprocs v1.5.2 // indirect
|
||||
go.uber.org/multierr v1.7.0 // indirect
|
||||
go.uber.org/zap v1.20.0 // indirect
|
||||
golang.org/x/crypto v0.16.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 // indirect
|
||||
golang.org/x/net v0.19.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.15.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
golang.org/x/time v0.3.0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633 // indirect
|
||||
google.golang.org/protobuf v1.31.0 // indirect
|
||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||
gopkg.in/ini.v1 v1.62.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
k8s.io/apimachinery v0.28.6 // indirect
|
||||
sigs.k8s.io/yaml v1.3.0 // indirect
|
||||
)
|
||||
1121
client/go.sum
Normal file
1121
client/go.sum
Normal file
File diff suppressed because it is too large
Load Diff
159
client/index.go
Normal file
159
client/index.go
Normal file
@ -0,0 +1,159 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/client/v2/index"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type CreateIndexTask struct {
|
||||
client *Client
|
||||
collectionName string
|
||||
fieldName string
|
||||
indexName string
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
func (t *CreateIndexTask) Await(ctx context.Context) error {
|
||||
ticker := time.NewTicker(t.interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
finished := false
|
||||
err := t.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{
|
||||
CollectionName: t.collectionName,
|
||||
FieldName: t.fieldName,
|
||||
IndexName: t.indexName,
|
||||
})
|
||||
err = merr.CheckRPCCall(resp, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, info := range resp.GetIndexDescriptions() {
|
||||
if (t.indexName == "" && info.GetFieldName() == t.fieldName) || t.indexName == info.GetIndexName() {
|
||||
switch info.GetState() {
|
||||
case commonpb.IndexState_Finished:
|
||||
finished = true
|
||||
return nil
|
||||
case commonpb.IndexState_Failed:
|
||||
return fmt.Errorf("create index failed, reason: %s", info.GetIndexStateFailReason())
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if finished {
|
||||
return nil
|
||||
}
|
||||
ticker.Reset(t.interval)
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) CreateIndex(ctx context.Context, option CreateIndexOption, callOptions ...grpc.CallOption) (*CreateIndexTask, error) {
|
||||
req := option.Request()
|
||||
var task *CreateIndexTask
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.CreateIndex(ctx, req, callOptions...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
task = &CreateIndexTask{
|
||||
client: c,
|
||||
collectionName: req.GetCollectionName(),
|
||||
fieldName: req.GetFieldName(),
|
||||
indexName: req.GetIndexName(),
|
||||
interval: time.Millisecond * 100,
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return task, err
|
||||
}
|
||||
|
||||
func (c *Client) ListIndexes(ctx context.Context, opt ListIndexOption, callOptions ...grpc.CallOption) ([]string, error) {
|
||||
req := opt.Request()
|
||||
|
||||
var indexes []string
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.DescribeIndex(ctx, req, callOptions...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, idxDef := range resp.GetIndexDescriptions() {
|
||||
if opt.Matches(idxDef) {
|
||||
indexes = append(indexes, idxDef.GetIndexName())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return indexes, err
|
||||
}
|
||||
|
||||
func (c *Client) DescribeIndex(ctx context.Context, opt DescribeIndexOption, callOptions ...grpc.CallOption) (index.Index, error) {
|
||||
req := opt.Request()
|
||||
var idx index.Index
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.DescribeIndex(ctx, req, callOptions...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(resp.GetIndexDescriptions()) == 0 {
|
||||
return merr.WrapErrIndexNotFound(req.GetIndexName())
|
||||
}
|
||||
for _, idxDef := range resp.GetIndexDescriptions() {
|
||||
if idxDef.GetIndexName() == req.GetIndexName() {
|
||||
idx = index.NewGenericIndex(idxDef.GetIndexName(), entity.KvPairsMap(idxDef.GetParams()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return idx, err
|
||||
}
|
||||
|
||||
func (c *Client) DropIndex(ctx context.Context, opt DropIndexOption, callOptions ...grpc.CallOption) error {
|
||||
req := opt.Request()
|
||||
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.DropIndex(ctx, req, callOptions...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
}
|
||||
61
client/index/common.go
Normal file
61
client/index/common.go
Normal file
@ -0,0 +1,61 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package index
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
// index param field tag
|
||||
const (
|
||||
IndexTypeKey = `index_type`
|
||||
MetricTypeKey = `metric_type`
|
||||
ParamsKey = `params`
|
||||
)
|
||||
|
||||
// IndexState export index state
|
||||
type IndexState commonpb.IndexState
|
||||
|
||||
// IndexType index type
|
||||
type IndexType string
|
||||
|
||||
// MetricType alias for `entity.MetricsType`.
|
||||
type MetricType = entity.MetricType
|
||||
|
||||
// Index Constants
|
||||
const (
|
||||
Flat IndexType = "FLAT" // faiss
|
||||
BinFlat IndexType = "BIN_FLAT"
|
||||
IvfFlat IndexType = "IVF_FLAT" // faiss
|
||||
BinIvfFlat IndexType = "BIN_IVF_FLAT"
|
||||
IvfPQ IndexType = "IVF_PQ" // faiss
|
||||
IvfSQ8 IndexType = "IVF_SQ8"
|
||||
HNSW IndexType = "HNSW"
|
||||
IvfHNSW IndexType = "IVF_HNSW"
|
||||
AUTOINDEX IndexType = "AUTOINDEX"
|
||||
DISKANN IndexType = "DISKANN"
|
||||
SCANN IndexType = "SCANN"
|
||||
|
||||
GPUIvfFlat IndexType = "GPU_IVF_FLAT"
|
||||
GPUIvfPQ IndexType = "GPU_IVF_PQ"
|
||||
|
||||
GPUCagra IndexType = "GPU_CAGRA"
|
||||
GPUBruteForce IndexType = "GPU_BRUTE_FORCE"
|
||||
|
||||
Scalar IndexType = "SCALAR"
|
||||
)
|
||||
38
client/index/disk_ann.go
Normal file
38
client/index/disk_ann.go
Normal file
@ -0,0 +1,38 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package index
|
||||
|
||||
var _ Index = diskANNIndex{}
|
||||
|
||||
type diskANNIndex struct {
|
||||
baseIndex
|
||||
}
|
||||
|
||||
func (idx diskANNIndex) Params() map[string]string {
|
||||
return map[string]string{
|
||||
MetricTypeKey: string(idx.metricType),
|
||||
IndexTypeKey: string(DISKANN),
|
||||
}
|
||||
}
|
||||
|
||||
func NewDiskANNIndex(metricType MetricType) Index {
|
||||
return &diskANNIndex{
|
||||
baseIndex: baseIndex{
|
||||
metricType: metricType,
|
||||
},
|
||||
}
|
||||
}
|
||||
38
client/index/flat.go
Normal file
38
client/index/flat.go
Normal file
@ -0,0 +1,38 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package index
|
||||
|
||||
var _ Index = flatIndex{}
|
||||
|
||||
type flatIndex struct {
|
||||
baseIndex
|
||||
}
|
||||
|
||||
func (idx flatIndex) Params() map[string]string {
|
||||
return map[string]string{
|
||||
MetricTypeKey: string(idx.metricType),
|
||||
IndexTypeKey: string(Flat),
|
||||
}
|
||||
}
|
||||
|
||||
func NewFlatIndex(metricType MetricType) Index {
|
||||
return flatIndex{
|
||||
baseIndex: baseIndex{
|
||||
metricType: metricType,
|
||||
},
|
||||
}
|
||||
}
|
||||
53
client/index/hnsw.go
Normal file
53
client/index/hnsw.go
Normal file
@ -0,0 +1,53 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package index
|
||||
|
||||
import "strconv"
|
||||
|
||||
const (
|
||||
hnswMKey = `M`
|
||||
hsnwEfConstruction = `efConstruction`
|
||||
)
|
||||
|
||||
var _ Index = hnswIndex{}
|
||||
|
||||
type hnswIndex struct {
|
||||
baseIndex
|
||||
|
||||
m int
|
||||
efConstruction int // exploratory factor when building index
|
||||
}
|
||||
|
||||
func (idx hnswIndex) Params() map[string]string {
|
||||
return map[string]string{
|
||||
MetricTypeKey: string(idx.metricType),
|
||||
IndexTypeKey: string(HNSW),
|
||||
hnswMKey: strconv.Itoa(idx.m),
|
||||
hsnwEfConstruction: strconv.Itoa(idx.efConstruction),
|
||||
}
|
||||
}
|
||||
|
||||
func NewHNSWIndex(metricType MetricType, M int, efConstruction int) Index {
|
||||
return hnswIndex{
|
||||
baseIndex: baseIndex{
|
||||
metricType: metricType,
|
||||
indexType: HNSW,
|
||||
},
|
||||
m: M,
|
||||
efConstruction: efConstruction,
|
||||
}
|
||||
}
|
||||
79
client/index/index.go
Normal file
79
client/index/index.go
Normal file
@ -0,0 +1,79 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package index
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Index represent index definition in milvus.
|
||||
type Index interface {
|
||||
Name() string
|
||||
IndexType() IndexType
|
||||
Params() map[string]string
|
||||
}
|
||||
|
||||
type baseIndex struct {
|
||||
name string
|
||||
metricType MetricType
|
||||
indexType IndexType
|
||||
params map[string]string
|
||||
}
|
||||
|
||||
func (idx baseIndex) Name() string {
|
||||
return idx.name
|
||||
}
|
||||
|
||||
func (idx baseIndex) IndexType() IndexType {
|
||||
return idx.indexType
|
||||
}
|
||||
|
||||
func (idx baseIndex) Params() map[string]string {
|
||||
return idx.params
|
||||
}
|
||||
|
||||
func (idx baseIndex) getExtraParams(params map[string]any) string {
|
||||
bs, _ := json.Marshal(params)
|
||||
return string(bs)
|
||||
}
|
||||
|
||||
var _ Index = GenericIndex{}
|
||||
|
||||
type GenericIndex struct {
|
||||
baseIndex
|
||||
params map[string]string
|
||||
}
|
||||
|
||||
// Params implements Index
|
||||
func (gi GenericIndex) Params() map[string]string {
|
||||
m := make(map[string]string)
|
||||
if gi.baseIndex.indexType != "" {
|
||||
m[IndexTypeKey] = string(gi.IndexType())
|
||||
}
|
||||
for k, v := range gi.params {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// NewGenericIndex create generic index instance
|
||||
func NewGenericIndex(name string, params map[string]string) Index {
|
||||
return GenericIndex{
|
||||
baseIndex: baseIndex{
|
||||
name: name,
|
||||
},
|
||||
params: params,
|
||||
}
|
||||
}
|
||||
17
client/index/index_test.go
Normal file
17
client/index/index_test.go
Normal file
@ -0,0 +1,17 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package index
|
||||
108
client/index/ivf.go
Normal file
108
client/index/ivf.go
Normal file
@ -0,0 +1,108 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package index
|
||||
|
||||
import "strconv"
|
||||
|
||||
const (
|
||||
ivfNlistKey = `nlist`
|
||||
ivfPQMKey = `m`
|
||||
ivfPQNbits = `nbits`
|
||||
)
|
||||
|
||||
var _ Index = ivfFlatIndex{}
|
||||
|
||||
type ivfFlatIndex struct {
|
||||
baseIndex
|
||||
|
||||
nlist int
|
||||
}
|
||||
|
||||
func (idx ivfFlatIndex) Params() map[string]string {
|
||||
return map[string]string{
|
||||
MetricTypeKey: string(idx.metricType),
|
||||
IndexTypeKey: string(IvfFlat),
|
||||
ivfNlistKey: strconv.Itoa(idx.nlist),
|
||||
}
|
||||
}
|
||||
|
||||
func NewIvfFlatIndex(metricType MetricType, nlist int) Index {
|
||||
return ivfFlatIndex{
|
||||
baseIndex: baseIndex{
|
||||
metricType: metricType,
|
||||
indexType: IvfFlat,
|
||||
},
|
||||
|
||||
nlist: nlist,
|
||||
}
|
||||
}
|
||||
|
||||
type ivfPQIndex struct {
|
||||
baseIndex
|
||||
|
||||
nlist int
|
||||
m int
|
||||
nbits int
|
||||
}
|
||||
|
||||
func (idx ivfPQIndex) Params() map[string]string {
|
||||
return map[string]string{
|
||||
MetricTypeKey: string(idx.metricType),
|
||||
IndexTypeKey: string(IvfPQ),
|
||||
ivfNlistKey: strconv.Itoa(idx.nlist),
|
||||
ivfPQMKey: strconv.Itoa(idx.m),
|
||||
ivfPQNbits: strconv.Itoa(idx.nbits),
|
||||
}
|
||||
}
|
||||
|
||||
func NewIvfPQIndex(metricType MetricType, nlist int, m int, nbits int) Index {
|
||||
return ivfPQIndex{
|
||||
baseIndex: baseIndex{
|
||||
metricType: metricType,
|
||||
indexType: IvfPQ,
|
||||
},
|
||||
|
||||
nlist: nlist,
|
||||
m: m,
|
||||
nbits: nbits,
|
||||
}
|
||||
}
|
||||
|
||||
type ivfSQ8Index struct {
|
||||
baseIndex
|
||||
|
||||
nlist int
|
||||
}
|
||||
|
||||
func (idx ivfSQ8Index) Params() map[string]string {
|
||||
return map[string]string{
|
||||
MetricTypeKey: string(idx.metricType),
|
||||
IndexTypeKey: string(IvfSQ8),
|
||||
ivfNlistKey: strconv.Itoa(idx.nlist),
|
||||
}
|
||||
}
|
||||
|
||||
func NewIvfSQ8Index(metricType MetricType, nlist int) Index {
|
||||
return ivfPQIndex{
|
||||
baseIndex: baseIndex{
|
||||
metricType: metricType,
|
||||
indexType: IvfSQ8,
|
||||
},
|
||||
|
||||
nlist: nlist,
|
||||
}
|
||||
}
|
||||
50
client/index/scann.go
Normal file
50
client/index/scann.go
Normal file
@ -0,0 +1,50 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package index
|
||||
|
||||
import "strconv"
|
||||
|
||||
const (
|
||||
scannNlistKey = `nlist`
|
||||
scannWithRawDataKey = `with_raw_data`
|
||||
)
|
||||
|
||||
type scannIndex struct {
|
||||
baseIndex
|
||||
|
||||
nlist int
|
||||
withRawData bool
|
||||
}
|
||||
|
||||
func (idx scannIndex) Params() map[string]string {
|
||||
return map[string]string{
|
||||
MetricTypeKey: string(idx.metricType),
|
||||
IndexTypeKey: string(IvfFlat),
|
||||
ivfNlistKey: strconv.Itoa(idx.nlist),
|
||||
}
|
||||
}
|
||||
|
||||
func NewSCANNIndex(metricType MetricType, nlist int) Index {
|
||||
return ivfFlatIndex{
|
||||
baseIndex: baseIndex{
|
||||
metricType: metricType,
|
||||
indexType: IvfFlat,
|
||||
},
|
||||
|
||||
nlist: nlist,
|
||||
}
|
||||
}
|
||||
137
client/index_options.go
Normal file
137
client/index_options.go
Normal file
@ -0,0 +1,137 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/client/v2/index"
|
||||
)
|
||||
|
||||
type CreateIndexOption interface {
|
||||
Request() *milvuspb.CreateIndexRequest
|
||||
}
|
||||
|
||||
type createIndexOption struct {
|
||||
collectionName string
|
||||
fieldName string
|
||||
indexName string
|
||||
indexDef index.Index
|
||||
}
|
||||
|
||||
func (opt *createIndexOption) Request() *milvuspb.CreateIndexRequest {
|
||||
return &milvuspb.CreateIndexRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
FieldName: opt.fieldName,
|
||||
IndexName: opt.indexName,
|
||||
ExtraParams: entity.MapKvPairs(opt.indexDef.Params()),
|
||||
}
|
||||
}
|
||||
|
||||
func (opt *createIndexOption) WithIndexName(indexName string) *createIndexOption {
|
||||
opt.indexName = indexName
|
||||
return opt
|
||||
}
|
||||
|
||||
func NewCreateIndexOption(collectionName string, fieldName string, index index.Index) *createIndexOption {
|
||||
return &createIndexOption{
|
||||
collectionName: collectionName,
|
||||
fieldName: fieldName,
|
||||
indexDef: index,
|
||||
}
|
||||
}
|
||||
|
||||
type ListIndexOption interface {
|
||||
Request() *milvuspb.DescribeIndexRequest
|
||||
Matches(*milvuspb.IndexDescription) bool
|
||||
}
|
||||
|
||||
var _ ListIndexOption = (*listIndexOption)(nil)
|
||||
|
||||
type listIndexOption struct {
|
||||
collectionName string
|
||||
fieldName string
|
||||
}
|
||||
|
||||
func (opt *listIndexOption) WithFieldName(fieldName string) *listIndexOption {
|
||||
opt.fieldName = fieldName
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *listIndexOption) Matches(idxDef *milvuspb.IndexDescription) bool {
|
||||
return opt.fieldName == "" || idxDef.GetFieldName() == opt.fieldName
|
||||
}
|
||||
|
||||
func (opt *listIndexOption) Request() *milvuspb.DescribeIndexRequest {
|
||||
return &milvuspb.DescribeIndexRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
FieldName: opt.fieldName,
|
||||
}
|
||||
}
|
||||
|
||||
func NewListIndexOption(collectionName string) *listIndexOption {
|
||||
return &listIndexOption{
|
||||
collectionName: collectionName,
|
||||
}
|
||||
}
|
||||
|
||||
type DescribeIndexOption interface {
|
||||
Request() *milvuspb.DescribeIndexRequest
|
||||
}
|
||||
|
||||
type describeIndexOption struct {
|
||||
collectionName string
|
||||
fieldName string
|
||||
indexName string
|
||||
}
|
||||
|
||||
func (opt *describeIndexOption) Request() *milvuspb.DescribeIndexRequest {
|
||||
return &milvuspb.DescribeIndexRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
IndexName: opt.indexName,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDescribeIndexOption(collectionName string, indexName string) *describeIndexOption {
|
||||
return &describeIndexOption{
|
||||
collectionName: collectionName,
|
||||
indexName: indexName,
|
||||
}
|
||||
}
|
||||
|
||||
type DropIndexOption interface {
|
||||
Request() *milvuspb.DropIndexRequest
|
||||
}
|
||||
|
||||
type dropIndexOption struct {
|
||||
collectionName string
|
||||
indexName string
|
||||
}
|
||||
|
||||
func (opt *dropIndexOption) Request() *milvuspb.DropIndexRequest {
|
||||
return &milvuspb.DropIndexRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
IndexName: opt.indexName,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDropIndexOption(collectionName string, indexName string) *dropIndexOption {
|
||||
return &dropIndexOption{
|
||||
collectionName: collectionName,
|
||||
indexName: indexName,
|
||||
}
|
||||
}
|
||||
221
client/index_test.go
Normal file
221
client/index_test.go
Normal file
@ -0,0 +1,221 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/client/v2/index"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type IndexSuite struct {
|
||||
MockSuiteBase
|
||||
}
|
||||
|
||||
func (s *IndexSuite) TestCreateIndex() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
fieldName := fmt.Sprintf("field_%s", s.randString(4))
|
||||
indexName := fmt.Sprintf("idx_%s", s.randString(6))
|
||||
|
||||
done := atomic.NewBool(false)
|
||||
|
||||
s.mock.EXPECT().CreateIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cir *milvuspb.CreateIndexRequest) (*commonpb.Status, error) {
|
||||
s.Equal(collectionName, cir.GetCollectionName())
|
||||
s.Equal(fieldName, cir.GetFieldName())
|
||||
s.Equal(indexName, cir.GetIndexName())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dir *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
|
||||
state := commonpb.IndexState_InProgress
|
||||
if done.Load() {
|
||||
state = commonpb.IndexState_Finished
|
||||
}
|
||||
return &milvuspb.DescribeIndexResponse{
|
||||
Status: merr.Success(),
|
||||
IndexDescriptions: []*milvuspb.IndexDescription{
|
||||
{
|
||||
FieldName: fieldName,
|
||||
IndexName: indexName,
|
||||
State: state,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
defer s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Unset()
|
||||
|
||||
task, err := s.client.CreateIndex(ctx, NewCreateIndexOption(collectionName, fieldName, index.NewHNSWIndex(entity.L2, 32, 128)).WithIndexName(indexName))
|
||||
s.NoError(err)
|
||||
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := task.Await(ctx)
|
||||
s.NoError(err)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
s.FailNow("task done before index state set to finish")
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
|
||||
done.Store(true)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
s.FailNow("task not done after index set finished")
|
||||
}
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
fieldName := fmt.Sprintf("field_%s", s.randString(4))
|
||||
indexName := fmt.Sprintf("idx_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.CreateIndex(ctx, NewCreateIndexOption(collectionName, fieldName, index.NewHNSWIndex(entity.L2, 32, 128)).WithIndexName(indexName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *IndexSuite) TestListIndexes() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dir *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
|
||||
s.Equal(collectionName, dir.GetCollectionName())
|
||||
return &milvuspb.DescribeIndexResponse{
|
||||
Status: merr.Success(),
|
||||
IndexDescriptions: []*milvuspb.IndexDescription{
|
||||
{IndexName: "test_idx"},
|
||||
},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
names, err := s.client.ListIndexes(ctx, NewListIndexOption(collectionName))
|
||||
s.NoError(err)
|
||||
s.ElementsMatch([]string{"test_idx"}, names)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.ListIndexes(ctx, NewListIndexOption(collectionName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *IndexSuite) TestDescribeIndex() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
indexName := fmt.Sprintf("idx_%s", s.randString(6))
|
||||
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dir *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
|
||||
s.Equal(collectionName, dir.GetCollectionName())
|
||||
s.Equal(indexName, dir.GetIndexName())
|
||||
return &milvuspb.DescribeIndexResponse{
|
||||
Status: merr.Success(),
|
||||
IndexDescriptions: []*milvuspb.IndexDescription{
|
||||
{IndexName: indexName, Params: []*commonpb.KeyValuePair{
|
||||
{Key: index.IndexTypeKey, Value: string(index.HNSW)},
|
||||
}},
|
||||
},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
index, err := s.client.DescribeIndex(ctx, NewDescribeIndexOption(collectionName, indexName))
|
||||
s.NoError(err)
|
||||
s.Equal(indexName, index.Name())
|
||||
})
|
||||
|
||||
s.Run("no_index_found", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
indexName := fmt.Sprintf("idx_%s", s.randString(6))
|
||||
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dir *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) {
|
||||
s.Equal(collectionName, dir.GetCollectionName())
|
||||
s.Equal(indexName, dir.GetIndexName())
|
||||
return &milvuspb.DescribeIndexResponse{
|
||||
Status: merr.Success(),
|
||||
IndexDescriptions: []*milvuspb.IndexDescription{},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
_, err := s.client.DescribeIndex(ctx, NewDescribeIndexOption(collectionName, indexName))
|
||||
s.Error(err)
|
||||
s.ErrorIs(err, merr.ErrIndexNotFound)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
indexName := fmt.Sprintf("idx_%s", s.randString(6))
|
||||
s.mock.EXPECT().DescribeIndex(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.DescribeIndex(ctx, NewDescribeIndexOption(collectionName, indexName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *IndexSuite) TestDropIndexOption() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
indexName := fmt.Sprintf("idx_%s", s.randString(6))
|
||||
opt := NewDropIndexOption(collectionName, indexName)
|
||||
req := opt.Request()
|
||||
|
||||
s.Equal(collectionName, req.GetCollectionName())
|
||||
s.Equal(indexName, req.GetIndexName())
|
||||
}
|
||||
|
||||
func (s *IndexSuite) TestDropIndex() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.Run("success", func() {
|
||||
s.mock.EXPECT().DropIndex(mock.Anything, mock.Anything).Return(merr.Success(), nil).Once()
|
||||
|
||||
err := s.client.DropIndex(ctx, NewDropIndexOption("testCollection", "testIndex"))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
s.mock.EXPECT().DropIndex(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.DropIndex(ctx, NewDropIndexOption("testCollection", "testIndex"))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIndex(t *testing.T) {
|
||||
suite.Run(t, new(IndexSuite))
|
||||
}
|
||||
171
client/maintenance.go
Normal file
171
client/maintenance.go
Normal file
@ -0,0 +1,171 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
type LoadTask struct {
|
||||
client *Client
|
||||
collectionName string
|
||||
partitionNames []string
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
func (t *LoadTask) Await(ctx context.Context) error {
|
||||
ticker := time.NewTicker(t.interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
loaded := false
|
||||
t.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{
|
||||
CollectionName: t.collectionName,
|
||||
PartitionNames: t.partitionNames,
|
||||
})
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
loaded = resp.GetProgress() == 100
|
||||
return nil
|
||||
})
|
||||
if loaded {
|
||||
return nil
|
||||
}
|
||||
ticker.Reset(t.interval)
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) LoadCollection(ctx context.Context, option LoadCollectionOption, callOptions ...grpc.CallOption) (LoadTask, error) {
|
||||
req := option.Request()
|
||||
|
||||
var task LoadTask
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.LoadCollection(ctx, req, callOptions...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
task = LoadTask{
|
||||
client: c,
|
||||
collectionName: req.GetCollectionName(),
|
||||
interval: option.CheckInterval(),
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return task, err
|
||||
}
|
||||
|
||||
func (c *Client) LoadPartitions(ctx context.Context, option LoadPartitionsOption, callOptions ...grpc.CallOption) (LoadTask, error) {
|
||||
req := option.Request()
|
||||
|
||||
var task LoadTask
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.LoadPartitions(ctx, req, callOptions...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
task = LoadTask{
|
||||
client: c,
|
||||
collectionName: req.GetCollectionName(),
|
||||
partitionNames: req.GetPartitionNames(),
|
||||
interval: option.CheckInterval(),
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return task, err
|
||||
}
|
||||
|
||||
type FlushTask struct {
|
||||
client *Client
|
||||
collectionName string
|
||||
segmentIDs []int64
|
||||
flushTs uint64
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
func (t *FlushTask) Await(ctx context.Context) error {
|
||||
ticker := time.NewTicker(t.interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
flushed := false
|
||||
t.client.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{
|
||||
CollectionName: t.collectionName,
|
||||
SegmentIDs: t.segmentIDs,
|
||||
FlushTs: t.flushTs,
|
||||
})
|
||||
err = merr.CheckRPCCall(resp, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
flushed = resp.GetFlushed()
|
||||
|
||||
return nil
|
||||
})
|
||||
if flushed {
|
||||
return nil
|
||||
}
|
||||
ticker.Reset(t.interval)
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Flush(ctx context.Context, option FlushOption, callOptions ...grpc.CallOption) (*FlushTask, error) {
|
||||
req := option.Request()
|
||||
collectionName := option.CollectionName()
|
||||
var task *FlushTask
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.Flush(ctx, req, callOptions...)
|
||||
err = merr.CheckRPCCall(resp, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
task = &FlushTask{
|
||||
client: c,
|
||||
collectionName: collectionName,
|
||||
segmentIDs: resp.GetCollSegIDs()[collectionName].GetData(),
|
||||
flushTs: resp.GetCollFlushTs()[collectionName],
|
||||
interval: option.CheckInterval(),
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return task, err
|
||||
}
|
||||
125
client/maintenance_options.go
Normal file
125
client/maintenance_options.go
Normal file
@ -0,0 +1,125 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
)
|
||||
|
||||
type LoadCollectionOption interface {
|
||||
Request() *milvuspb.LoadCollectionRequest
|
||||
CheckInterval() time.Duration
|
||||
}
|
||||
|
||||
type loadCollectionOption struct {
|
||||
collectionName string
|
||||
interval time.Duration
|
||||
replicaNum int
|
||||
}
|
||||
|
||||
func (opt *loadCollectionOption) Request() *milvuspb.LoadCollectionRequest {
|
||||
return &milvuspb.LoadCollectionRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
ReplicaNumber: int32(opt.replicaNum),
|
||||
}
|
||||
}
|
||||
|
||||
func (opt *loadCollectionOption) CheckInterval() time.Duration {
|
||||
return opt.interval
|
||||
}
|
||||
|
||||
func (opt *loadCollectionOption) WithReplica(num int) *loadCollectionOption {
|
||||
opt.replicaNum = num
|
||||
return opt
|
||||
}
|
||||
|
||||
func NewLoadCollectionOption(collectionName string) *loadCollectionOption {
|
||||
return &loadCollectionOption{
|
||||
collectionName: collectionName,
|
||||
replicaNum: 1,
|
||||
interval: time.Millisecond * 200,
|
||||
}
|
||||
}
|
||||
|
||||
type LoadPartitionsOption interface {
|
||||
Request() *milvuspb.LoadPartitionsRequest
|
||||
CheckInterval() time.Duration
|
||||
}
|
||||
|
||||
var _ LoadPartitionsOption = (*loadPartitionsOption)(nil)
|
||||
|
||||
type loadPartitionsOption struct {
|
||||
collectionName string
|
||||
partitionNames []string
|
||||
interval time.Duration
|
||||
replicaNum int
|
||||
}
|
||||
|
||||
func (opt *loadPartitionsOption) Request() *milvuspb.LoadPartitionsRequest {
|
||||
return &milvuspb.LoadPartitionsRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
PartitionNames: opt.partitionNames,
|
||||
ReplicaNumber: int32(opt.replicaNum),
|
||||
}
|
||||
}
|
||||
|
||||
func (opt *loadPartitionsOption) CheckInterval() time.Duration {
|
||||
return opt.interval
|
||||
}
|
||||
|
||||
func NewLoadPartitionsOption(collectionName string, partitionsNames []string) *loadPartitionsOption {
|
||||
return &loadPartitionsOption{
|
||||
collectionName: collectionName,
|
||||
partitionNames: partitionsNames,
|
||||
replicaNum: 1,
|
||||
interval: time.Millisecond * 200,
|
||||
}
|
||||
}
|
||||
|
||||
type FlushOption interface {
|
||||
Request() *milvuspb.FlushRequest
|
||||
CollectionName() string
|
||||
CheckInterval() time.Duration
|
||||
}
|
||||
|
||||
type flushOption struct {
|
||||
collectionName string
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
func (opt *flushOption) Request() *milvuspb.FlushRequest {
|
||||
return &milvuspb.FlushRequest{
|
||||
CollectionNames: []string{opt.collectionName},
|
||||
}
|
||||
}
|
||||
|
||||
func (opt *flushOption) CollectionName() string {
|
||||
return opt.collectionName
|
||||
}
|
||||
|
||||
func (opt *flushOption) CheckInterval() time.Duration {
|
||||
return opt.interval
|
||||
}
|
||||
|
||||
func NewFlushOption(collName string) *flushOption {
|
||||
return &flushOption{
|
||||
collectionName: collName,
|
||||
interval: time.Millisecond * 200,
|
||||
}
|
||||
}
|
||||
229
client/maintenance_test.go
Normal file
229
client/maintenance_test.go
Normal file
@ -0,0 +1,229 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type MaintenanceSuite struct {
|
||||
MockSuiteBase
|
||||
}
|
||||
|
||||
func (s *MaintenanceSuite) TestLoadCollection() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
|
||||
done := atomic.NewBool(false)
|
||||
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lcr *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) {
|
||||
s.Equal(collectionName, lcr.GetCollectionName())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
|
||||
s.Equal(collectionName, glpr.GetCollectionName())
|
||||
|
||||
progress := int64(50)
|
||||
if done.Load() {
|
||||
progress = 100
|
||||
}
|
||||
|
||||
return &milvuspb.GetLoadingProgressResponse{
|
||||
Status: merr.Success(),
|
||||
Progress: progress,
|
||||
}, nil
|
||||
})
|
||||
defer s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Unset()
|
||||
|
||||
task, err := s.client.LoadCollection(ctx, NewLoadCollectionOption(collectionName))
|
||||
s.NoError(err)
|
||||
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := task.Await(ctx)
|
||||
s.NoError(err)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
s.FailNow("task done before index state set to finish")
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
|
||||
done.Store(true)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
s.FailNow("task not done after index set finished")
|
||||
}
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.LoadCollection(ctx, NewLoadCollectionOption(collectionName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *MaintenanceSuite) TestLoadPartitions() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
|
||||
done := atomic.NewBool(false)
|
||||
s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lpr *milvuspb.LoadPartitionsRequest) (*commonpb.Status, error) {
|
||||
s.Equal(collectionName, lpr.GetCollectionName())
|
||||
s.ElementsMatch([]string{partitionName}, lpr.GetPartitionNames())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
|
||||
s.Equal(collectionName, glpr.GetCollectionName())
|
||||
s.ElementsMatch([]string{partitionName}, glpr.GetPartitionNames())
|
||||
|
||||
progress := int64(50)
|
||||
if done.Load() {
|
||||
progress = 100
|
||||
}
|
||||
|
||||
return &milvuspb.GetLoadingProgressResponse{
|
||||
Status: merr.Success(),
|
||||
Progress: progress,
|
||||
}, nil
|
||||
})
|
||||
defer s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Unset()
|
||||
|
||||
task, err := s.client.LoadPartitions(ctx, NewLoadPartitionsOption(collectionName, []string{partitionName}))
|
||||
s.NoError(err)
|
||||
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := task.Await(ctx)
|
||||
s.NoError(err)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
s.FailNow("task done before index state set to finish")
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
|
||||
done.Store(true)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
s.FailNow("task not done after index set finished")
|
||||
}
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.LoadPartitions(ctx, NewLoadPartitionsOption(collectionName, []string{partitionName}))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *MaintenanceSuite) TestFlush() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
|
||||
done := atomic.NewBool(false)
|
||||
s.mock.EXPECT().Flush(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, fr *milvuspb.FlushRequest) (*milvuspb.FlushResponse, error) {
|
||||
s.ElementsMatch([]string{collectionName}, fr.GetCollectionNames())
|
||||
return &milvuspb.FlushResponse{
|
||||
Status: merr.Success(),
|
||||
CollSegIDs: map[string]*schemapb.LongArray{
|
||||
collectionName: {Data: []int64{1, 2, 3}},
|
||||
},
|
||||
CollFlushTs: map[string]uint64{collectionName: 321},
|
||||
}, nil
|
||||
}).Once()
|
||||
s.mock.EXPECT().GetFlushState(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, gfsr *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) {
|
||||
s.Equal(collectionName, gfsr.GetCollectionName())
|
||||
s.ElementsMatch([]int64{1, 2, 3}, gfsr.GetSegmentIDs())
|
||||
s.EqualValues(321, gfsr.GetFlushTs())
|
||||
return &milvuspb.GetFlushStateResponse{
|
||||
Status: merr.Success(),
|
||||
Flushed: done.Load(),
|
||||
}, nil
|
||||
})
|
||||
defer s.mock.EXPECT().GetFlushState(mock.Anything, mock.Anything).Unset()
|
||||
|
||||
task, err := s.client.Flush(ctx, NewFlushOption(collectionName))
|
||||
s.NoError(err)
|
||||
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := task.Await(ctx)
|
||||
s.NoError(err)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
s.FailNow("task done before index state set to finish")
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
|
||||
done.Store(true)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
s.FailNow("task not done after index set finished")
|
||||
}
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().Flush(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.Flush(ctx, NewFlushOption(collectionName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMaintenance(t *testing.T) {
|
||||
suite.Run(t, new(MaintenanceSuite))
|
||||
}
|
||||
4717
client/mock_milvus_server_test.go
Normal file
4717
client/mock_milvus_server_test.go
Normal file
File diff suppressed because it is too large
Load Diff
77
client/partition.go
Normal file
77
client/partition.go
Normal file
@ -0,0 +1,77 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// CreatePartition is the API for creating a partition for a collection.
|
||||
func (c *Client) CreatePartition(ctx context.Context, opt CreatePartitionOption, callOptions ...grpc.CallOption) error {
|
||||
req := opt.Request()
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.CreatePartition(ctx, req, callOptions...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) DropPartition(ctx context.Context, opt DropPartitionOption, callOptions ...grpc.CallOption) error {
|
||||
req := opt.Request()
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.DropPartition(ctx, req, callOptions...)
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) HasPartition(ctx context.Context, opt HasPartitionOption, callOptions ...grpc.CallOption) (has bool, err error) {
|
||||
req := opt.Request()
|
||||
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.HasPartition(ctx, req, callOptions...)
|
||||
err = merr.CheckRPCCall(resp, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
has = resp.GetValue()
|
||||
return nil
|
||||
})
|
||||
return has, err
|
||||
}
|
||||
|
||||
func (c *Client) ListPartitions(ctx context.Context, opt ListPartitionsOption, callOptions ...grpc.CallOption) (partitionNames []string, err error) {
|
||||
req := opt.Request()
|
||||
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.ShowPartitions(ctx, req, callOptions...)
|
||||
err = merr.CheckRPCCall(resp, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
partitionNames = resp.GetPartitionNames()
|
||||
return nil
|
||||
})
|
||||
return partitionNames, err
|
||||
}
|
||||
119
client/partition_options.go
Normal file
119
client/partition_options.go
Normal file
@ -0,0 +1,119 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
|
||||
// CreatePartitionOption is the interface builds Create Partition request.
|
||||
type CreatePartitionOption interface {
|
||||
// Request is the method returns the composed request.
|
||||
Request() *milvuspb.CreatePartitionRequest
|
||||
}
|
||||
|
||||
type createPartitionOpt struct {
|
||||
collectionName string
|
||||
partitionName string
|
||||
}
|
||||
|
||||
func (opt *createPartitionOpt) Request() *milvuspb.CreatePartitionRequest {
|
||||
return &milvuspb.CreatePartitionRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
PartitionName: opt.partitionName,
|
||||
}
|
||||
}
|
||||
|
||||
func NewCreatePartitionOption(collectionName string, partitionName string) *createPartitionOpt {
|
||||
return &createPartitionOpt{
|
||||
collectionName: collectionName,
|
||||
partitionName: partitionName,
|
||||
}
|
||||
}
|
||||
|
||||
// DropPartitionOption is the interface that builds Drop Partition request.
|
||||
type DropPartitionOption interface {
|
||||
// Request is the method returns the composed request.
|
||||
Request() *milvuspb.DropPartitionRequest
|
||||
}
|
||||
|
||||
type dropPartitionOpt struct {
|
||||
collectionName string
|
||||
partitionName string
|
||||
}
|
||||
|
||||
func (opt *dropPartitionOpt) Request() *milvuspb.DropPartitionRequest {
|
||||
return &milvuspb.DropPartitionRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
PartitionName: opt.partitionName,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDropPartitionOption(collectionName string, partitionName string) *dropPartitionOpt {
|
||||
return &dropPartitionOpt{
|
||||
collectionName: collectionName,
|
||||
partitionName: partitionName,
|
||||
}
|
||||
}
|
||||
|
||||
// HasPartitionOption is the interface builds HasPartition request.
|
||||
type HasPartitionOption interface {
|
||||
// Request is the method returns the composed request.
|
||||
Request() *milvuspb.HasPartitionRequest
|
||||
}
|
||||
|
||||
var _ HasPartitionOption = (*hasPartitionOpt)(nil)
|
||||
|
||||
type hasPartitionOpt struct {
|
||||
collectionName string
|
||||
partitionName string
|
||||
}
|
||||
|
||||
func (opt *hasPartitionOpt) Request() *milvuspb.HasPartitionRequest {
|
||||
return &milvuspb.HasPartitionRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
PartitionName: opt.partitionName,
|
||||
}
|
||||
}
|
||||
|
||||
func NewHasPartitionOption(collectionName string, partitionName string) *hasPartitionOpt {
|
||||
return &hasPartitionOpt{
|
||||
collectionName: collectionName,
|
||||
partitionName: partitionName,
|
||||
}
|
||||
}
|
||||
|
||||
// ListPartitionsOption is the interface builds List Partition request.
|
||||
type ListPartitionsOption interface {
|
||||
// Request is the method returns the composed request.
|
||||
Request() *milvuspb.ShowPartitionsRequest
|
||||
}
|
||||
|
||||
type listPartitionsOpt struct {
|
||||
collectionName string
|
||||
}
|
||||
|
||||
func (opt *listPartitionsOpt) Request() *milvuspb.ShowPartitionsRequest {
|
||||
return &milvuspb.ShowPartitionsRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
Type: milvuspb.ShowType_All,
|
||||
}
|
||||
}
|
||||
|
||||
func NewListPartitionOption(collectionName string) *listPartitionsOpt {
|
||||
return &listPartitionsOpt{
|
||||
collectionName: collectionName,
|
||||
}
|
||||
}
|
||||
166
client/partition_test.go
Normal file
166
client/partition_test.go
Normal file
@ -0,0 +1,166 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type PartitionSuite struct {
|
||||
MockSuiteBase
|
||||
}
|
||||
|
||||
func (s *PartitionSuite) TestListPartitions() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().ShowPartitions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, spr *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) {
|
||||
s.Equal(collectionName, spr.GetCollectionName())
|
||||
return &milvuspb.ShowPartitionsResponse{
|
||||
Status: merr.Success(),
|
||||
PartitionNames: []string{"_default", "part_1"},
|
||||
PartitionIDs: []int64{100, 101},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
names, err := s.client.ListPartitions(ctx, NewListPartitionOption(collectionName))
|
||||
s.NoError(err)
|
||||
s.ElementsMatch([]string{"_default", "part_1"}, names)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.ListPartitions(ctx, NewListPartitionOption(collectionName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *PartitionSuite) TestCreatePartition() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().CreatePartition(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cpr *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) {
|
||||
s.Equal(collectionName, cpr.GetCollectionName())
|
||||
s.Equal(partitionName, cpr.GetPartitionName())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.CreatePartition(ctx, NewCreatePartitionOption(collectionName, partitionName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().CreatePartition(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.CreatePartition(ctx, NewCreatePartitionOption(collectionName, partitionName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *PartitionSuite) TestHasPartition() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
|
||||
s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, hpr *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) {
|
||||
s.Equal(collectionName, hpr.GetCollectionName())
|
||||
s.Equal(partitionName, hpr.GetPartitionName())
|
||||
return &milvuspb.BoolResponse{Status: merr.Success()}, nil
|
||||
}).Once()
|
||||
|
||||
has, err := s.client.HasPartition(ctx, NewHasPartitionOption(collectionName, partitionName))
|
||||
s.NoError(err)
|
||||
s.False(has)
|
||||
|
||||
s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, hpr *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) {
|
||||
s.Equal(collectionName, hpr.GetCollectionName())
|
||||
s.Equal(partitionName, hpr.GetPartitionName())
|
||||
return &milvuspb.BoolResponse{
|
||||
Status: merr.Success(),
|
||||
Value: true,
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
has, err = s.client.HasPartition(ctx, NewHasPartitionOption(collectionName, partitionName))
|
||||
s.NoError(err)
|
||||
s.True(has)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
s.mock.EXPECT().HasPartition(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
_, err := s.client.HasPartition(ctx, NewHasPartitionOption(collectionName, partitionName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *PartitionSuite) TestDropPartition() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
s.mock.EXPECT().DropPartition(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dpr *milvuspb.DropPartitionRequest) (*commonpb.Status, error) {
|
||||
s.Equal(collectionName, dpr.GetCollectionName())
|
||||
s.Equal(partitionName, dpr.GetPartitionName())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.DropPartition(ctx, NewDropPartitionOption(collectionName, partitionName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
s.mock.EXPECT().DropPartition(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.DropPartition(ctx, NewDropPartitionOption(collectionName, partitionName))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPartition(t *testing.T) {
|
||||
suite.Run(t, new(PartitionSuite))
|
||||
}
|
||||
220
client/read.go
Normal file
220
client/read.go
Normal file
@ -0,0 +1,220 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/column"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
type ResultSets struct{}
|
||||
|
||||
type ResultSet struct {
|
||||
ResultCount int // the returning entry count
|
||||
GroupByValue any
|
||||
IDs column.Column // auto generated id, can be mapped to the columns from `Insert` API
|
||||
Fields DataSet // output field data
|
||||
Scores []float32 // distance to the target vector
|
||||
Err error // search error if any
|
||||
}
|
||||
|
||||
// DataSet is an alias type for column slice.
|
||||
type DataSet []column.Column
|
||||
|
||||
func (c *Client) Search(ctx context.Context, option SearchOption, callOptions ...grpc.CallOption) ([]ResultSet, error) {
|
||||
req := option.Request()
|
||||
collection, err := c.getCollection(ctx, req.GetCollectionName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resultSets []ResultSet
|
||||
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.Search(ctx, req, callOptions...)
|
||||
err = merr.CheckRPCCall(resp, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resultSets, err = c.handleSearchResult(collection.Schema, req.GetOutputFields(), int(req.GetNq()), resp)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return resultSets, err
|
||||
}
|
||||
|
||||
func (c *Client) handleSearchResult(schema *entity.Schema, outputFields []string, nq int, resp *milvuspb.SearchResults) ([]ResultSet, error) {
|
||||
var err error
|
||||
sr := make([]ResultSet, 0, nq)
|
||||
results := resp.GetResults()
|
||||
offset := 0
|
||||
fieldDataList := results.GetFieldsData()
|
||||
gb := results.GetGroupByFieldValue()
|
||||
var gbc column.Column
|
||||
if gb != nil {
|
||||
gbc, err = column.FieldDataColumn(gb, 0, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
for i := 0; i < int(results.GetNumQueries()); i++ {
|
||||
rc := int(results.GetTopks()[i]) // result entry count for current query
|
||||
entry := ResultSet{
|
||||
ResultCount: rc,
|
||||
Scores: results.GetScores()[offset : offset+rc],
|
||||
}
|
||||
if gbc != nil {
|
||||
entry.GroupByValue, _ = gbc.Get(i)
|
||||
}
|
||||
// parse result set if current nq is not empty
|
||||
if rc > 0 {
|
||||
entry.IDs, entry.Err = column.IDColumns(results.GetIds(), offset, offset+rc)
|
||||
if entry.Err != nil {
|
||||
offset += rc
|
||||
continue
|
||||
}
|
||||
entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc)
|
||||
sr = append(sr, entry)
|
||||
}
|
||||
|
||||
offset += rc
|
||||
}
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
func (c *Client) parseSearchResult(sch *entity.Schema, outputFields []string, fieldDataList []*schemapb.FieldData, _, from, to int) ([]column.Column, error) {
|
||||
var wildcard bool
|
||||
outputFields, wildcard = expandWildcard(sch, outputFields)
|
||||
// duplicated name will have only one column now
|
||||
outputSet := make(map[string]struct{})
|
||||
for _, output := range outputFields {
|
||||
outputSet[output] = struct{}{}
|
||||
}
|
||||
// fields := make(map[string]*schemapb.FieldData)
|
||||
columns := make([]column.Column, 0, len(outputFields))
|
||||
var dynamicColumn *column.ColumnJSONBytes
|
||||
for _, fieldData := range fieldDataList {
|
||||
col, err := column.FieldDataColumn(fieldData, from, to)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fieldData.GetIsDynamic() {
|
||||
var ok bool
|
||||
dynamicColumn, ok = col.(*column.ColumnJSONBytes)
|
||||
if !ok {
|
||||
return nil, errors.New("dynamic field not json")
|
||||
}
|
||||
|
||||
// return json column only explicitly specified in output fields and not in wildcard mode
|
||||
if _, ok := outputSet[fieldData.GetFieldName()]; !ok && !wildcard {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// remove processed field
|
||||
delete(outputSet, fieldData.GetFieldName())
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
|
||||
if len(outputSet) > 0 && dynamicColumn == nil {
|
||||
var extraFields []string
|
||||
for output := range outputSet {
|
||||
extraFields = append(extraFields, output)
|
||||
}
|
||||
return nil, errors.Newf("extra output fields %v found and result does not dynamic field", extraFields)
|
||||
}
|
||||
// add dynamic column for extra fields
|
||||
for outputField := range outputSet {
|
||||
column := column.NewColumnDynamic(dynamicColumn, outputField)
|
||||
columns = append(columns, column)
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (c *Client) Query(ctx context.Context, option QueryOption, callOptions ...grpc.CallOption) (ResultSet, error) {
|
||||
req := option.Request()
|
||||
var resultSet ResultSet
|
||||
|
||||
collection, err := c.getCollection(ctx, req.GetCollectionName())
|
||||
if err != nil {
|
||||
return resultSet, err
|
||||
}
|
||||
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.Query(ctx, req, callOptions...)
|
||||
err = merr.CheckRPCCall(resp, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
columns, err := c.parseSearchResult(collection.Schema, resp.GetOutputFields(), resp.GetFieldsData(), 0, 0, -1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resultSet = ResultSet{
|
||||
Fields: columns,
|
||||
}
|
||||
if len(columns) > 0 {
|
||||
resultSet.ResultCount = columns[0].Len()
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return resultSet, err
|
||||
}
|
||||
|
||||
func expandWildcard(schema *entity.Schema, outputFields []string) ([]string, bool) {
|
||||
wildcard := false
|
||||
for _, outputField := range outputFields {
|
||||
if outputField == "*" {
|
||||
wildcard = true
|
||||
}
|
||||
}
|
||||
if !wildcard {
|
||||
return outputFields, false
|
||||
}
|
||||
|
||||
set := make(map[string]struct{})
|
||||
result := make([]string, 0, len(schema.Fields))
|
||||
for _, field := range schema.Fields {
|
||||
result = append(result, field.Name)
|
||||
set[field.Name] = struct{}{}
|
||||
}
|
||||
|
||||
// add dynamic fields output
|
||||
for _, output := range outputFields {
|
||||
if output == "*" {
|
||||
continue
|
||||
}
|
||||
_, ok := set[output]
|
||||
if !ok {
|
||||
result = append(result, output)
|
||||
}
|
||||
}
|
||||
return result, true
|
||||
}
|
||||
250
client/read_options.go
Normal file
250
client/read_options.go
Normal file
@ -0,0 +1,250 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
const (
|
||||
spAnnsField = `anns_field`
|
||||
spTopK = `topk`
|
||||
spOffset = `offset`
|
||||
spParams = `params`
|
||||
spMetricsType = `metric_type`
|
||||
spRoundDecimal = `round_decimal`
|
||||
spIgnoreGrowing = `ignore_growing`
|
||||
spGroupBy = `group_by_field`
|
||||
)
|
||||
|
||||
type SearchOption interface {
|
||||
Request() *milvuspb.SearchRequest
|
||||
}
|
||||
|
||||
var _ SearchOption = (*searchOption)(nil)
|
||||
|
||||
type searchOption struct {
|
||||
collectionName string
|
||||
partitionNames []string
|
||||
topK int
|
||||
offset int
|
||||
outputFields []string
|
||||
consistencyLevel entity.ConsistencyLevel
|
||||
useDefaultConsistencyLevel bool
|
||||
ignoreGrowing bool
|
||||
expr string
|
||||
|
||||
// normal search request
|
||||
request *annRequest
|
||||
// TODO add sub request when support hybrid search
|
||||
}
|
||||
|
||||
type annRequest struct {
|
||||
vectors []entity.Vector
|
||||
|
||||
annField string
|
||||
metricsType entity.MetricType
|
||||
searchParam map[string]string
|
||||
groupByField string
|
||||
}
|
||||
|
||||
func (opt *searchOption) Request() *milvuspb.SearchRequest {
|
||||
// TODO check whether search is hybrid after logic merged
|
||||
return opt.prepareSearchRequest(opt.request)
|
||||
}
|
||||
|
||||
func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.SearchRequest {
|
||||
request := &milvuspb.SearchRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
PartitionNames: opt.partitionNames,
|
||||
Dsl: opt.expr,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
ConsistencyLevel: commonpb.ConsistencyLevel(opt.consistencyLevel),
|
||||
}
|
||||
if annRequest != nil {
|
||||
// nq
|
||||
request.Nq = int64(len(annRequest.vectors))
|
||||
|
||||
// search param
|
||||
bs, _ := json.Marshal(annRequest.searchParam)
|
||||
request.SearchParams = entity.MapKvPairs(map[string]string{
|
||||
spAnnsField: annRequest.annField,
|
||||
spTopK: strconv.Itoa(opt.topK),
|
||||
spOffset: strconv.Itoa(opt.offset),
|
||||
spParams: string(bs),
|
||||
spMetricsType: string(annRequest.metricsType),
|
||||
spRoundDecimal: "-1",
|
||||
spIgnoreGrowing: strconv.FormatBool(opt.ignoreGrowing),
|
||||
spGroupBy: annRequest.groupByField,
|
||||
})
|
||||
|
||||
// placeholder group
|
||||
request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors)
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
func (opt *searchOption) WithFilter(expr string) *searchOption {
|
||||
opt.expr = expr
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *searchOption) WithOffset(offset int) *searchOption {
|
||||
opt.offset = offset
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *searchOption) WithOutputFields(fieldNames []string) *searchOption {
|
||||
opt.outputFields = fieldNames
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *searchOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *searchOption {
|
||||
opt.consistencyLevel = consistencyLevel
|
||||
opt.useDefaultConsistencyLevel = false
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *searchOption) WithANNSField(annsField string) *searchOption {
|
||||
opt.request.annField = annsField
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *searchOption) WithPartitions(partitionNames []string) *searchOption {
|
||||
opt.partitionNames = partitionNames
|
||||
return opt
|
||||
}
|
||||
|
||||
func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) *searchOption {
|
||||
return &searchOption{
|
||||
collectionName: collectionName,
|
||||
topK: limit,
|
||||
request: &annRequest{
|
||||
vectors: vectors,
|
||||
},
|
||||
useDefaultConsistencyLevel: true,
|
||||
consistencyLevel: entity.ClBounded,
|
||||
}
|
||||
}
|
||||
|
||||
func vector2PlaceholderGroupBytes(vectors []entity.Vector) []byte {
|
||||
phg := &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{
|
||||
vector2Placeholder(vectors),
|
||||
},
|
||||
}
|
||||
|
||||
bs, _ := proto.Marshal(phg)
|
||||
return bs
|
||||
}
|
||||
|
||||
func vector2Placeholder(vectors []entity.Vector) *commonpb.PlaceholderValue {
|
||||
var placeHolderType commonpb.PlaceholderType
|
||||
ph := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Values: make([][]byte, 0, len(vectors)),
|
||||
}
|
||||
if len(vectors) == 0 {
|
||||
return ph
|
||||
}
|
||||
switch vectors[0].(type) {
|
||||
case entity.FloatVector:
|
||||
placeHolderType = commonpb.PlaceholderType_FloatVector
|
||||
case entity.BinaryVector:
|
||||
placeHolderType = commonpb.PlaceholderType_BinaryVector
|
||||
case entity.BFloat16Vector:
|
||||
placeHolderType = commonpb.PlaceholderType_BFloat16Vector
|
||||
case entity.Float16Vector:
|
||||
placeHolderType = commonpb.PlaceholderType_Float16Vector
|
||||
case entity.SparseEmbedding:
|
||||
placeHolderType = commonpb.PlaceholderType_SparseFloatVector
|
||||
}
|
||||
ph.Type = placeHolderType
|
||||
for _, vector := range vectors {
|
||||
ph.Values = append(ph.Values, vector.Serialize())
|
||||
}
|
||||
return ph
|
||||
}
|
||||
|
||||
type QueryOption interface {
|
||||
Request() *milvuspb.QueryRequest
|
||||
}
|
||||
|
||||
type queryOption struct {
|
||||
collectionName string
|
||||
partitionNames []string
|
||||
|
||||
limit int
|
||||
offset int
|
||||
outputFields []string
|
||||
consistencyLevel entity.ConsistencyLevel
|
||||
useDefaultConsistencyLevel bool
|
||||
ignoreGrowing bool
|
||||
expr string
|
||||
}
|
||||
|
||||
func (opt *queryOption) Request() *milvuspb.QueryRequest {
|
||||
return &milvuspb.QueryRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
PartitionNames: opt.partitionNames,
|
||||
OutputFields: opt.outputFields,
|
||||
|
||||
Expr: opt.expr,
|
||||
ConsistencyLevel: opt.consistencyLevel.CommonConsistencyLevel(),
|
||||
}
|
||||
}
|
||||
|
||||
func (opt *queryOption) WithFilter(expr string) *queryOption {
|
||||
opt.expr = expr
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *queryOption) WithOffset(offset int) *queryOption {
|
||||
opt.offset = offset
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *queryOption) WithOutputFields(fieldNames []string) *queryOption {
|
||||
opt.outputFields = fieldNames
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *queryOption) WithConsistencyLevel(consistencyLevel entity.ConsistencyLevel) *queryOption {
|
||||
opt.consistencyLevel = consistencyLevel
|
||||
opt.useDefaultConsistencyLevel = false
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *queryOption) WithPartitions(partitionNames []string) *queryOption {
|
||||
opt.partitionNames = partitionNames
|
||||
return opt
|
||||
}
|
||||
|
||||
func NewQueryOption(collectionName string) *queryOption {
|
||||
return &queryOption{
|
||||
collectionName: collectionName,
|
||||
useDefaultConsistencyLevel: true,
|
||||
consistencyLevel: entity.ClBounded,
|
||||
}
|
||||
}
|
||||
154
client/read_test.go
Normal file
154
client/read_test.go
Normal file
@ -0,0 +1,154 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ReadSuite struct {
|
||||
MockSuiteBase
|
||||
|
||||
schema *entity.Schema
|
||||
schemaDyn *entity.Schema
|
||||
}
|
||||
|
||||
func (s *ReadSuite) SetupSuite() {
|
||||
s.MockSuiteBase.SetupSuite()
|
||||
s.schema = entity.NewSchema().
|
||||
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
|
||||
WithField(entity.NewField().WithName("Vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
|
||||
|
||||
s.schemaDyn = entity.NewSchema().WithDynamicFieldEnabled(true).
|
||||
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
|
||||
WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
|
||||
}
|
||||
|
||||
func (s *ReadSuite) TestSearch() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
s.setupCache(collectionName, s.schema)
|
||||
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
|
||||
s.Equal(collectionName, sr.GetCollectionName())
|
||||
s.ElementsMatch([]string{partitionName}, sr.GetPartitionNames())
|
||||
|
||||
return &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: 1,
|
||||
TopK: 10,
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
s.getInt64FieldData("ID", []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
|
||||
},
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: make([]float32, 10),
|
||||
Topks: []int64{10},
|
||||
},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
_, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
|
||||
entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
||||
return rand.Float32()
|
||||
})),
|
||||
}).WithPartitions([]string{partitionName}))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("dynamic_schema", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
s.setupCache(collectionName, s.schemaDyn)
|
||||
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
|
||||
return &milvuspb.SearchResults{
|
||||
Status: merr.Success(),
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: 1,
|
||||
TopK: 2,
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
s.getInt64FieldData("ID", []int64{1, 2}),
|
||||
s.getJSONBytesFieldData("$meta", [][]byte{
|
||||
[]byte(`{"A": 123, "B": "456"}`),
|
||||
[]byte(`{"B": "abc", "A": 456}`),
|
||||
}, true),
|
||||
},
|
||||
Ids: &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
Scores: make([]float32, 2),
|
||||
Topks: []int64{2},
|
||||
},
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
_, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
|
||||
entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
||||
return rand.Float32()
|
||||
})),
|
||||
}).WithPartitions([]string{partitionName}))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
s.setupCache(collectionName, s.schemaDyn)
|
||||
|
||||
s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) {
|
||||
return nil, merr.WrapErrServiceInternal("mocked")
|
||||
}).Once()
|
||||
|
||||
_, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{
|
||||
entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
|
||||
return rand.Float32()
|
||||
})),
|
||||
}))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ReadSuite) TestQuery() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partitionName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
s.setupCache(collectionName, s.schema)
|
||||
|
||||
s.mock.EXPECT().Query(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, qr *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) {
|
||||
s.Equal(collectionName, qr.GetCollectionName())
|
||||
|
||||
return &milvuspb.QueryResults{}, nil
|
||||
}).Once()
|
||||
|
||||
_, err := s.client.Query(ctx, NewQueryOption(collectionName).WithPartitions([]string{partitionName}))
|
||||
s.NoError(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
suite.Run(t, new(ReadSuite))
|
||||
}
|
||||
74
client/write.go
Normal file
74
client/write.go
Normal file
@ -0,0 +1,74 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
func (c *Client) Insert(ctx context.Context, option InsertOption, callOptions ...grpc.CallOption) error {
|
||||
collection, err := c.getCollection(ctx, option.CollectionName())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req, err := option.InsertRequest(collection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.Insert(ctx, req, callOptions...)
|
||||
|
||||
return merr.CheckRPCCall(resp, err)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) Delete(ctx context.Context, option DeleteOption, callOptions ...grpc.CallOption) error {
|
||||
req := option.Request()
|
||||
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.Delete(ctx, req, callOptions...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Client) Upsert(ctx context.Context, option UpsertOption, callOptions ...grpc.CallOption) error {
|
||||
collection, err := c.getCollection(ctx, option.CollectionName())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req, err := option.UpsertRequest(collection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
resp, err := milvusService.Upsert(ctx, req, callOptions...)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
39
client/write_option_test.go
Normal file
39
client/write_option_test.go
Normal file
@ -0,0 +1,39 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type DeleteOptionSuite struct {
|
||||
MockSuiteBase
|
||||
}
|
||||
|
||||
func (s *DeleteOptionSuite) TestBasic() {
|
||||
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
opt := NewDeleteOption(collectionName)
|
||||
|
||||
s.Equal(collectionName, opt.Request().GetCollectionName())
|
||||
}
|
||||
|
||||
func TestDeleteOption(t *testing.T) {
|
||||
suite.Run(t, new(DeleteOptionSuite))
|
||||
}
|
||||
290
client/write_options.go
Normal file
290
client/write_options.go
Normal file
@ -0,0 +1,290 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/client/v2/column"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
type InsertOption interface {
|
||||
InsertRequest(coll *entity.Collection) (*milvuspb.InsertRequest, error)
|
||||
CollectionName() string
|
||||
}
|
||||
|
||||
type UpsertOption interface {
|
||||
UpsertRequest(coll *entity.Collection) (*milvuspb.UpsertRequest, error)
|
||||
CollectionName() string
|
||||
}
|
||||
|
||||
var (
|
||||
_ UpsertOption = (*columnBasedDataOption)(nil)
|
||||
_ InsertOption = (*columnBasedDataOption)(nil)
|
||||
)
|
||||
|
||||
type columnBasedDataOption struct {
|
||||
collName string
|
||||
partitionName string
|
||||
columns []column.Column
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema, columns ...column.Column) ([]*schemapb.FieldData, int, error) {
|
||||
// setup dynamic related var
|
||||
isDynamic := colSchema.EnableDynamicField
|
||||
|
||||
// check columns and field matches
|
||||
var rowSize int
|
||||
mNameField := make(map[string]*entity.Field)
|
||||
for _, field := range colSchema.Fields {
|
||||
mNameField[field.Name] = field
|
||||
}
|
||||
mNameColumn := make(map[string]column.Column)
|
||||
var dynamicColumns []column.Column
|
||||
for _, col := range columns {
|
||||
_, dup := mNameColumn[col.Name()]
|
||||
if dup {
|
||||
return nil, 0, fmt.Errorf("duplicated column %s found", col.Name())
|
||||
}
|
||||
l := col.Len()
|
||||
if rowSize == 0 {
|
||||
rowSize = l
|
||||
} else {
|
||||
if rowSize != l {
|
||||
return nil, 0, errors.New("column size not match")
|
||||
}
|
||||
}
|
||||
field, has := mNameField[col.Name()]
|
||||
if !has {
|
||||
if !isDynamic {
|
||||
return nil, 0, fmt.Errorf("field %s does not exist in collection %s", col.Name(), colSchema.CollectionName)
|
||||
}
|
||||
// add to dynamic column list for further processing
|
||||
dynamicColumns = append(dynamicColumns, col)
|
||||
continue
|
||||
}
|
||||
|
||||
mNameColumn[col.Name()] = col
|
||||
if col.Type() != field.DataType {
|
||||
return nil, 0, fmt.Errorf("param column %s has type %v but collection field definition is %v", col.Name(), col.FieldData(), field.DataType)
|
||||
}
|
||||
if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector {
|
||||
dim := 0
|
||||
switch column := col.(type) {
|
||||
case *column.ColumnFloatVector:
|
||||
dim = column.Dim()
|
||||
case *column.ColumnBinaryVector:
|
||||
dim = column.Dim()
|
||||
}
|
||||
if fmt.Sprintf("%d", dim) != field.TypeParams[entity.TypeParamDim] {
|
||||
return nil, 0, fmt.Errorf("params column %s vector dim %d not match collection definition, which has dim of %s", field.Name, dim, field.TypeParams[entity.TypeParamDim])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check all fixed field pass value
|
||||
for _, field := range colSchema.Fields {
|
||||
_, has := mNameColumn[field.Name]
|
||||
if !has &&
|
||||
!field.AutoID && !field.IsDynamic {
|
||||
return nil, 0, fmt.Errorf("field %s not passed", field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
fieldsData := make([]*schemapb.FieldData, 0, len(mNameColumn)+1)
|
||||
for _, fixedColumn := range mNameColumn {
|
||||
fieldsData = append(fieldsData, fixedColumn.FieldData())
|
||||
}
|
||||
if len(dynamicColumns) > 0 {
|
||||
// use empty column name here
|
||||
col, err := opt.mergeDynamicColumns("", rowSize, dynamicColumns)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
fieldsData = append(fieldsData, col)
|
||||
}
|
||||
|
||||
return fieldsData, rowSize, nil
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) mergeDynamicColumns(dynamicName string, rowSize int, columns []column.Column) (*schemapb.FieldData, error) {
|
||||
values := make([][]byte, 0, rowSize)
|
||||
for i := 0; i < rowSize; i++ {
|
||||
m := make(map[string]interface{})
|
||||
for _, column := range columns {
|
||||
// range guaranteed
|
||||
m[column.Name()], _ = column.Get(i)
|
||||
}
|
||||
bs, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
values = append(values, bs)
|
||||
}
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_JSON,
|
||||
FieldName: dynamicName,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_JsonData{
|
||||
JsonData: &schemapb.JSONArray{
|
||||
Data: values,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
IsDynamic: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) WithColumns(columns ...column.Column) *columnBasedDataOption {
|
||||
opt.columns = append(opt.columns, columns...)
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) WithBoolColumn(colName string, data []bool) *columnBasedDataOption {
|
||||
column := column.NewColumnBool(colName, data)
|
||||
return opt.WithColumns(column)
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) WithInt8Column(colName string, data []int8) *columnBasedDataOption {
|
||||
column := column.NewColumnInt8(colName, data)
|
||||
return opt.WithColumns(column)
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) WithInt16Column(colName string, data []int16) *columnBasedDataOption {
|
||||
column := column.NewColumnInt16(colName, data)
|
||||
return opt.WithColumns(column)
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) WithInt32Column(colName string, data []int32) *columnBasedDataOption {
|
||||
column := column.NewColumnInt32(colName, data)
|
||||
return opt.WithColumns(column)
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) WithInt64Column(colName string, data []int64) *columnBasedDataOption {
|
||||
column := column.NewColumnInt64(colName, data)
|
||||
return opt.WithColumns(column)
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) WithVarcharColumn(colName string, data []string) *columnBasedDataOption {
|
||||
column := column.NewColumnVarChar(colName, data)
|
||||
return opt.WithColumns(column)
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) WithFloatVectorColumn(colName string, dim int, data [][]float32) *columnBasedDataOption {
|
||||
column := column.NewColumnFloatVector(colName, dim, data)
|
||||
return opt.WithColumns(column)
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) WithBinaryVectorColumn(colName string, dim int, data [][]byte) *columnBasedDataOption {
|
||||
column := column.NewColumnBinaryVector(colName, dim, data)
|
||||
return opt.WithColumns(column)
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) WithPartition(partitionName string) *columnBasedDataOption {
|
||||
opt.partitionName = partitionName
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) CollectionName() string {
|
||||
return opt.collName
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) InsertRequest(coll *entity.Collection) (*milvuspb.InsertRequest, error) {
|
||||
fieldsData, rowNum, err := opt.processInsertColumns(coll.Schema, opt.columns...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &milvuspb.InsertRequest{
|
||||
CollectionName: opt.collName,
|
||||
PartitionName: opt.partitionName,
|
||||
FieldsData: fieldsData,
|
||||
NumRows: uint32(rowNum),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvuspb.UpsertRequest, error) {
|
||||
fieldsData, rowNum, err := opt.processInsertColumns(coll.Schema, opt.columns...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &milvuspb.UpsertRequest{
|
||||
CollectionName: opt.collName,
|
||||
PartitionName: opt.partitionName,
|
||||
FieldsData: fieldsData,
|
||||
NumRows: uint32(rowNum),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewColumnBasedInsertOption(collName string, columns ...column.Column) *columnBasedDataOption {
|
||||
return &columnBasedDataOption{
|
||||
columns: columns,
|
||||
collName: collName,
|
||||
// leave partition name empty, using default partition
|
||||
}
|
||||
}
|
||||
|
||||
type DeleteOption interface {
|
||||
Request() *milvuspb.DeleteRequest
|
||||
}
|
||||
|
||||
type deleteOption struct {
|
||||
collectionName string
|
||||
partitionName string
|
||||
expr string
|
||||
}
|
||||
|
||||
func (opt *deleteOption) Request() *milvuspb.DeleteRequest {
|
||||
return &milvuspb.DeleteRequest{
|
||||
CollectionName: opt.collectionName,
|
||||
PartitionName: opt.partitionName,
|
||||
Expr: opt.expr,
|
||||
}
|
||||
}
|
||||
|
||||
func (opt *deleteOption) WithExpr(expr string) *deleteOption {
|
||||
opt.expr = expr
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *deleteOption) WithInt64IDs(fieldName string, ids []int64) *deleteOption {
|
||||
opt.expr = fmt.Sprintf("%s in %s", fieldName, strings.Join(strings.Fields(fmt.Sprint(ids)), ","))
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *deleteOption) WithStringIDs(fieldName string, ids []string) *deleteOption {
|
||||
opt.expr = fmt.Sprintf("%s in [%s]", fieldName, strings.Join(lo.Map(ids, func(id string, _ int) string { return fmt.Sprintf("\"%s\"", id) }), ","))
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *deleteOption) WithPartition(partitionName string) *deleteOption {
|
||||
opt.partitionName = partitionName
|
||||
return opt
|
||||
}
|
||||
|
||||
func NewDeleteOption(collectionName string) *deleteOption {
|
||||
return &deleteOption{collectionName: collectionName}
|
||||
}
|
||||
330
client/write_test.go
Normal file
330
client/write_test.go
Normal file
@ -0,0 +1,330 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type WriteSuite struct {
|
||||
MockSuiteBase
|
||||
|
||||
schema *entity.Schema
|
||||
schemaDyn *entity.Schema
|
||||
}
|
||||
|
||||
func (s *WriteSuite) SetupSuite() {
|
||||
s.MockSuiteBase.SetupSuite()
|
||||
s.schema = entity.NewSchema().
|
||||
WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
|
||||
WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
|
||||
|
||||
s.schemaDyn = entity.NewSchema().WithDynamicFieldEnabled(true).
|
||||
WithField(entity.NewField().WithName("id").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
|
||||
WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128))
|
||||
}
|
||||
|
||||
func (s *WriteSuite) TestInsert() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
s.setupCache(collName, s.schema)
|
||||
|
||||
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ir *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) {
|
||||
s.Equal(collName, ir.GetCollectionName())
|
||||
s.Equal(partName, ir.GetPartitionName())
|
||||
s.Require().Len(ir.GetFieldsData(), 2)
|
||||
s.EqualValues(3, ir.GetNumRows())
|
||||
return &milvuspb.MutationResult{
|
||||
Status: merr.Success(),
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})).
|
||||
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("dynamic_schema", func() {
|
||||
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
s.setupCache(collName, s.schemaDyn)
|
||||
|
||||
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ir *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) {
|
||||
s.Equal(collName, ir.GetCollectionName())
|
||||
s.Equal(partName, ir.GetPartitionName())
|
||||
s.Require().Len(ir.GetFieldsData(), 3)
|
||||
s.EqualValues(3, ir.GetNumRows())
|
||||
return &milvuspb.MutationResult{
|
||||
Status: merr.Success(),
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})).
|
||||
WithVarcharColumn("extra", []string{"a", "b", "c"}).
|
||||
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("bad_input", func() {
|
||||
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
s.setupCache(collName, s.schema)
|
||||
|
||||
type badCase struct {
|
||||
tag string
|
||||
input InsertOption
|
||||
}
|
||||
|
||||
cases := []badCase{
|
||||
{
|
||||
tag: "missing_column",
|
||||
input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}),
|
||||
},
|
||||
{
|
||||
tag: "row_count_not_match",
|
||||
input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})),
|
||||
},
|
||||
{
|
||||
tag: "duplicated_columns",
|
||||
input: NewColumnBasedInsertOption(collName).
|
||||
WithInt64Column("id", []int64{1}).
|
||||
WithInt64Column("id", []int64{2}).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})),
|
||||
},
|
||||
{
|
||||
tag: "different_data_type",
|
||||
input: NewColumnBasedInsertOption(collName).
|
||||
WithVarcharColumn("id", []string{"1"}).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
s.Run(tc.tag, func() {
|
||||
err := s.client.Insert(ctx, tc.input)
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
s.setupCache(collName, s.schema)
|
||||
|
||||
s.mock.EXPECT().Insert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.Insert(ctx, NewColumnBasedInsertOption(collName).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})).
|
||||
WithInt64Column("id", []int64{1, 2, 3}))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *WriteSuite) TestUpsert() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
s.setupCache(collName, s.schema)
|
||||
|
||||
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ur *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
|
||||
s.Equal(collName, ur.GetCollectionName())
|
||||
s.Equal(partName, ur.GetPartitionName())
|
||||
s.Require().Len(ur.GetFieldsData(), 2)
|
||||
s.EqualValues(3, ur.GetNumRows())
|
||||
return &milvuspb.MutationResult{
|
||||
Status: merr.Success(),
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})).
|
||||
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("dynamic_schema", func() {
|
||||
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
s.setupCache(collName, s.schemaDyn)
|
||||
|
||||
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ur *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) {
|
||||
s.Equal(collName, ur.GetCollectionName())
|
||||
s.Equal(partName, ur.GetPartitionName())
|
||||
s.Require().Len(ur.GetFieldsData(), 3)
|
||||
s.EqualValues(3, ur.GetNumRows())
|
||||
return &milvuspb.MutationResult{
|
||||
Status: merr.Success(),
|
||||
}, nil
|
||||
}).Once()
|
||||
|
||||
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})).
|
||||
WithVarcharColumn("extra", []string{"a", "b", "c"}).
|
||||
WithInt64Column("id", []int64{1, 2, 3}).WithPartition(partName))
|
||||
s.NoError(err)
|
||||
})
|
||||
|
||||
s.Run("bad_input", func() {
|
||||
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
s.setupCache(collName, s.schema)
|
||||
|
||||
type badCase struct {
|
||||
tag string
|
||||
input UpsertOption
|
||||
}
|
||||
|
||||
cases := []badCase{
|
||||
{
|
||||
tag: "missing_column",
|
||||
input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}),
|
||||
},
|
||||
{
|
||||
tag: "row_count_not_match",
|
||||
input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})),
|
||||
},
|
||||
{
|
||||
tag: "duplicated_columns",
|
||||
input: NewColumnBasedInsertOption(collName).
|
||||
WithInt64Column("id", []int64{1}).
|
||||
WithInt64Column("id", []int64{2}).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})),
|
||||
},
|
||||
{
|
||||
tag: "different_data_type",
|
||||
input: NewColumnBasedInsertOption(collName).
|
||||
WithVarcharColumn("id", []string{"1"}).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(1, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
s.Run(tc.tag, func() {
|
||||
err := s.client.Upsert(ctx, tc.input)
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
s.Run("failure", func() {
|
||||
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
s.setupCache(collName, s.schema)
|
||||
|
||||
s.mock.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
|
||||
|
||||
err := s.client.Upsert(ctx, NewColumnBasedInsertOption(collName).
|
||||
WithFloatVectorColumn("vector", 128, lo.RepeatBy(3, func(i int) []float32 {
|
||||
return lo.RepeatBy(128, func(i int) float32 { return rand.Float32() })
|
||||
})).
|
||||
WithInt64Column("id", []int64{1, 2, 3}))
|
||||
s.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *WriteSuite) TestDelete() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.Run("success", func() {
|
||||
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
partName := fmt.Sprintf("part_%s", s.randString(6))
|
||||
|
||||
type testCase struct {
|
||||
tag string
|
||||
input DeleteOption
|
||||
expectExpr string
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
tag: "raw_expr",
|
||||
input: NewDeleteOption(collName).WithPartition(partName).WithExpr("id > 100"),
|
||||
expectExpr: "id > 100",
|
||||
},
|
||||
{
|
||||
tag: "int_ids",
|
||||
input: NewDeleteOption(collName).WithPartition(partName).WithInt64IDs("id", []int64{1, 2, 3}),
|
||||
expectExpr: "id in [1,2,3]",
|
||||
},
|
||||
{
|
||||
tag: "str_ids",
|
||||
input: NewDeleteOption(collName).WithPartition(partName).WithStringIDs("id", []string{"a", "b", "c"}),
|
||||
expectExpr: `id in ["a","b","c"]`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
s.Run(tc.tag, func() {
|
||||
s.mock.EXPECT().Delete(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, dr *milvuspb.DeleteRequest) (*milvuspb.MutationResult, error) {
|
||||
s.Equal(collName, dr.GetCollectionName())
|
||||
s.Equal(partName, dr.GetPartitionName())
|
||||
s.Equal(tc.expectExpr, dr.GetExpr())
|
||||
return &milvuspb.MutationResult{
|
||||
Status: merr.Success(),
|
||||
}, nil
|
||||
}).Once()
|
||||
err := s.client.Delete(ctx, tc.input)
|
||||
s.NoError(err)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
suite.Run(t, new(WriteSuite))
|
||||
}
|
||||
@ -58,6 +58,16 @@ for d in $(go list ./... | grep -v -e vendor -e kafka -e planparserv2/generated
|
||||
fi
|
||||
done
|
||||
popd
|
||||
# milvusclient
|
||||
pushd client
|
||||
for d in $(go list ./... | grep -v -e vendor -e kafka -e planparserv2/generated -e mocks); do
|
||||
$TEST_CMD -race -tags dynamic -v -coverpkg=./... -coverprofile=profile.out -covermode=atomic "$d"
|
||||
if [ -f profile.out ]; then
|
||||
grep -v kafka profile.out | grep -v planparserv2/generated | grep -v mocks | sed '1d' >> ../${FILE_COVERAGE_INFO}
|
||||
rm profile.out
|
||||
fi
|
||||
done
|
||||
popd
|
||||
endTime=`date +%s`
|
||||
|
||||
echo "Total time for go unittest:" $(($endTime-$beginTime)) "s"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user