mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
Add unittest for distributed/querycoord/indexnode (#8680)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
This commit is contained in:
parent
c05d4fa362
commit
693994c6db
@ -42,9 +42,15 @@ type Client struct {
|
||||
grpcClientMtx sync.RWMutex
|
||||
|
||||
addr string
|
||||
|
||||
getGrpcClient func() (indexpb.IndexNodeClient, error)
|
||||
}
|
||||
|
||||
func (c *Client) getGrpcClient() (indexpb.IndexNodeClient, error) {
|
||||
func (c *Client) setGetGrpcClientFunc() {
|
||||
c.getGrpcClient = c.getGrpcClientFunc
|
||||
}
|
||||
|
||||
func (c *Client) getGrpcClientFunc() (indexpb.IndexNodeClient, error) {
|
||||
c.grpcClientMtx.RLock()
|
||||
if c.grpcClient != nil {
|
||||
defer c.grpcClientMtx.RUnlock()
|
||||
@ -86,16 +92,19 @@ func NewClient(ctx context.Context, addr string) (*Client, error) {
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &Client{
|
||||
client := &Client{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
addr: addr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
client.setGetGrpcClientFunc()
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) Init() error {
|
||||
Params.Init()
|
||||
return c.connect(retry.Attempts(20))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) connect(retryOptions ...retry.Option) error {
|
||||
|
||||
@ -13,6 +13,7 @@ package grpcindexnodeclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode"
|
||||
@ -21,9 +22,102 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proxy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type MockIndexNodeClient struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *MockIndexNodeClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) {
|
||||
return &internalpb.ComponentStates{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockIndexNodeClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) {
|
||||
return &milvuspb.StringResponse{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockIndexNodeClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) {
|
||||
return &milvuspb.StringResponse{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockIndexNodeClient) CreateIndex(ctx context.Context, in *indexpb.CreateIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockIndexNodeClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) {
|
||||
return &milvuspb.GetMetricsResponse{}, m.err
|
||||
}
|
||||
|
||||
func Test_NewClient(t *testing.T) {
|
||||
proxy.Params.InitOnce()
|
||||
|
||||
ctx := context.Background()
|
||||
client, err := NewClient(ctx, "")
|
||||
assert.Nil(t, client)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
client, err = NewClient(ctx, "test")
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, client)
|
||||
|
||||
err = client.Init()
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = client.Start()
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = client.Register()
|
||||
assert.Nil(t, err)
|
||||
|
||||
checkFunc := func(retNotNil bool) {
|
||||
retCheck := func(notNil bool, ret interface{}, err error) {
|
||||
if notNil {
|
||||
assert.NotNil(t, ret)
|
||||
assert.Nil(t, err)
|
||||
} else {
|
||||
assert.Nil(t, ret)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
r1, err := client.GetComponentStates(ctx)
|
||||
retCheck(retNotNil, r1, err)
|
||||
|
||||
r2, err := client.GetTimeTickChannel(ctx)
|
||||
retCheck(retNotNil, r2, err)
|
||||
|
||||
r3, err := client.GetStatisticsChannel(ctx)
|
||||
retCheck(retNotNil, r3, err)
|
||||
|
||||
r4, err := client.CreateIndex(ctx, nil)
|
||||
retCheck(retNotNil, r4, err)
|
||||
|
||||
r5, err := client.GetMetrics(ctx, nil)
|
||||
retCheck(retNotNil, r5, err)
|
||||
}
|
||||
|
||||
client.getGrpcClient = func() (indexpb.IndexNodeClient, error) {
|
||||
return &MockIndexNodeClient{err: nil}, errors.New("dummy")
|
||||
}
|
||||
checkFunc(false)
|
||||
|
||||
client.getGrpcClient = func() (indexpb.IndexNodeClient, error) {
|
||||
return &MockIndexNodeClient{err: errors.New("dummy")}, nil
|
||||
}
|
||||
checkFunc(false)
|
||||
|
||||
client.getGrpcClient = func() (indexpb.IndexNodeClient, error) {
|
||||
return &MockIndexNodeClient{err: nil}, nil
|
||||
}
|
||||
checkFunc(true)
|
||||
|
||||
err = client.Stop()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestIndexNodeClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
@ -1,356 +0,0 @@
|
||||
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
package grpcquerycoordclient
|
||||
|
||||
//import (
|
||||
// "context"
|
||||
// "encoding/binary"
|
||||
// "math"
|
||||
// "path"
|
||||
// "strconv"
|
||||
// "testing"
|
||||
// "time"
|
||||
//
|
||||
// "github.com/stretchr/testify/assert"
|
||||
//
|
||||
// "github.com/milvus-io/milvus/internal/indexnode"
|
||||
// minioKV "github.com/milvus-io/milvus/internal/kv/minio"
|
||||
// "github.com/milvus-io/milvus/internal/msgstream"
|
||||
// "github.com/milvus-io/milvus/internal/msgstream/pulsarms"
|
||||
// "github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
// "github.com/milvus-io/milvus/internal/proto/etcdpb"
|
||||
// "github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
// "github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
// "github.com/milvus-io/milvus/internal/storage"
|
||||
// "github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
//)
|
||||
//
|
||||
////generate insert data
|
||||
//const msgLength = 100
|
||||
//const receiveBufSize = 1024
|
||||
//const pulsarBufSize = 1024
|
||||
//const DIM = 16
|
||||
//
|
||||
//type UniqueID = typeutil.UniqueID
|
||||
//
|
||||
//func genInsert(collectionID int64,
|
||||
// partitionID int64,
|
||||
// timeStart int,
|
||||
// numDmChannels int,
|
||||
// binlog bool) (*msgstream.MsgPack, *msgstream.MsgPack) {
|
||||
// msgs := make([]msgstream.TsMsg, 0)
|
||||
// for n := timeStart; n < timeStart+msgLength; n++ {
|
||||
// rowData := make([]byte, 0)
|
||||
// if binlog {
|
||||
// id := make([]byte, 8)
|
||||
// binary.BigEndian.PutUint64(id, uint64(n))
|
||||
// rowData = append(rowData, id...)
|
||||
// time := make([]byte, 8)
|
||||
// binary.BigEndian.PutUint64(time, uint64(n))
|
||||
// rowData = append(rowData, time...)
|
||||
// }
|
||||
// for i := 0; i < DIM; i++ {
|
||||
// vec := make([]byte, 4)
|
||||
// binary.BigEndian.PutUint32(vec, math.Float32bits(float32(n*i)))
|
||||
// rowData = append(rowData, vec...)
|
||||
// }
|
||||
// age := make([]byte, 4)
|
||||
// binary.BigEndian.PutUint32(age, 1)
|
||||
// rowData = append(rowData, age...)
|
||||
// blob := &commonpb.Blob{
|
||||
// Value: rowData,
|
||||
// }
|
||||
//
|
||||
// var insertMsg msgstream.TsMsg = &msgstream.InsertMsg{
|
||||
// BaseMsg: msgstream.BaseMsg{
|
||||
// HashValues: []uint32{uint32((n - 1) % numDmChannels)},
|
||||
// },
|
||||
// InsertRequest: internalpb.InsertRequest{
|
||||
// Base: &commonpb.MsgBase{
|
||||
// MsgType: commonpb.MsgType_kInsert,
|
||||
// MsgID: 0,
|
||||
// Timestamp: uint64(n),
|
||||
// SourceID: 0,
|
||||
// },
|
||||
// CollectionID: collectionID,
|
||||
// PartitionID: partitionID,
|
||||
// SegmentID: UniqueID(((n - 1) % numDmChannels) + ((n-1)/(numDmChannels*msgLength))*numDmChannels),
|
||||
// ChannelID: "0",
|
||||
// Timestamps: []uint64{uint64(n)},
|
||||
// RowIDs: []int64{int64(n)},
|
||||
// RowData: []*commonpb.Blob{blob},
|
||||
// },
|
||||
// }
|
||||
// //fmt.Println("hash value = ", insertMsg.(*msgstream.InsertMsg).HashValues, "segmentID = ", insertMsg.(*msgstream.InsertMsg).SegmentID)
|
||||
// msgs = append(msgs, insertMsg)
|
||||
// }
|
||||
//
|
||||
// insertMsgPack := &msgstream.MsgPack{
|
||||
// BeginTs: uint64(timeStart),
|
||||
// EndTs: uint64(timeStart + msgLength),
|
||||
// Msgs: msgs,
|
||||
// }
|
||||
//
|
||||
// // generate timeTick
|
||||
// timeTickMsg := &msgstream.TimeTickMsg{
|
||||
// BaseMsg: msgstream.BaseMsg{
|
||||
// BeginTimestamp: 0,
|
||||
// EndTimestamp: 0,
|
||||
// HashValues: []uint32{0},
|
||||
// },
|
||||
// TimeTickMsg: internalpb.TimeTickMsg{
|
||||
// Base: &commonpb.MsgBase{
|
||||
// MsgType: commonpb.MsgType_kTimeTick,
|
||||
// MsgID: 0,
|
||||
// Timestamp: uint64(timeStart + msgLength),
|
||||
// SourceID: 0,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
// timeTickMsgPack := &msgstream.MsgPack{
|
||||
// Msgs: []msgstream.TsMsg{timeTickMsg},
|
||||
// }
|
||||
// return insertMsgPack, timeTickMsgPack
|
||||
//}
|
||||
//
|
||||
//func genSchema(collectionID int64) *schemapb.CollectionSchema {
|
||||
// fieldID := schemapb.FieldSchema{
|
||||
// FieldID: UniqueID(0),
|
||||
// Name: "RowID",
|
||||
// IsPrimaryKey: false,
|
||||
// DataType: schemapb.DataType_INT64,
|
||||
// }
|
||||
//
|
||||
// fieldTime := schemapb.FieldSchema{
|
||||
// FieldID: UniqueID(1),
|
||||
// Name: "Timestamp",
|
||||
// IsPrimaryKey: false,
|
||||
// DataType: schemapb.DataType_INT64,
|
||||
// }
|
||||
//
|
||||
// fieldVec := schemapb.FieldSchema{
|
||||
// FieldID: UniqueID(100),
|
||||
// Name: "vec",
|
||||
// IsPrimaryKey: false,
|
||||
// DataType: schemapb.DataType_VECTOR_FLOAT,
|
||||
// TypeParams: []*commonpb.KeyValuePair{
|
||||
// {
|
||||
// Key: "dim",
|
||||
// Value: "16",
|
||||
// },
|
||||
// },
|
||||
// IndexParams: []*commonpb.KeyValuePair{
|
||||
// {
|
||||
// Key: "metric_type",
|
||||
// Value: "L2",
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// fieldInt := schemapb.FieldSchema{
|
||||
// FieldID: UniqueID(101),
|
||||
// Name: "age",
|
||||
// IsPrimaryKey: false,
|
||||
// DataType: schemapb.DataType_INT32,
|
||||
// }
|
||||
//
|
||||
// return &schemapb.CollectionSchema{
|
||||
// Name: "collection-" + strconv.FormatInt(collectionID, 10),
|
||||
// AutoID: true,
|
||||
// Fields: []*schemapb.FieldSchema{
|
||||
// &fieldID, &fieldTime, &fieldVec, &fieldInt,
|
||||
// },
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//func getMinioKV(ctx context.Context) (*minioKV.MinIOKV, error) {
|
||||
// minioAddress := "localhost:9000"
|
||||
// accessKeyID := "minioadmin"
|
||||
// secretAccessKey := "minioadmin"
|
||||
// useSSL := false
|
||||
// bucketName := "a-bucket"
|
||||
//
|
||||
// option := &minioKV.Option{
|
||||
// Address: minioAddress,
|
||||
// AccessKeyID: accessKeyID,
|
||||
// SecretAccessKeyID: secretAccessKey,
|
||||
// UseSSL: useSSL,
|
||||
// BucketName: bucketName,
|
||||
// CreateBucket: true,
|
||||
// }
|
||||
//
|
||||
// return minioKV.NewMinIOKV(ctx, option)
|
||||
//}
|
||||
//
|
||||
//func TestWriteBinLog(t *testing.T) {
|
||||
// const (
|
||||
// debug = true
|
||||
// consumeSubName = "test-load-collection-sub-name"
|
||||
// )
|
||||
// var ctx context.Context
|
||||
// if debug {
|
||||
// ctx = context.Background()
|
||||
// } else {
|
||||
// var cancel context.CancelFunc
|
||||
// ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
|
||||
// defer cancel()
|
||||
// }
|
||||
//
|
||||
// // produce msg
|
||||
// insertChannels := []string{"insert-0", "insert-1", "insert-2", "insert-3"}
|
||||
// pulsarAddress := "pulsar://127.0.0.1:6650"
|
||||
//
|
||||
// factory := pulsarms.NewFactory(pulsarAddress, receiveBufSize, pulsarBufSize)
|
||||
//
|
||||
// insertStream, _ := factory.NewTtMsgStream(ctx)
|
||||
// insertStream.AsProducer(insertChannels)
|
||||
// insertStream.AsConsumer(insertChannels, consumeSubName)
|
||||
// insertStream.Start()
|
||||
//
|
||||
// for i := 0; i < 12; i++ {
|
||||
// insertMsgPack, timeTickMsgPack := genInsert(1, 1, i*msgLength+1, 4, true)
|
||||
// err := insertStream.Produce(insertMsgPack)
|
||||
// assert.NoError(t, err)
|
||||
// err = insertStream.Broadcast(timeTickMsgPack)
|
||||
// assert.NoError(t, err)
|
||||
// }
|
||||
//
|
||||
// //consume msg
|
||||
// segmentData := make([]*storage.InsertData, 12)
|
||||
// idData := make([][]int64, 12)
|
||||
// timestamps := make([][]int64, 12)
|
||||
// fieldAgeData := make([][]int32, 12)
|
||||
// fieldVecData := make([][]float32, 12)
|
||||
// for i := 0; i < 12; i++ {
|
||||
// idData[i] = make([]int64, 0)
|
||||
// timestamps[i] = make([]int64, 0)
|
||||
// fieldAgeData[i] = make([]int32, 0)
|
||||
// fieldVecData[i] = make([]float32, 0)
|
||||
// }
|
||||
// for i := 0; i < 12; i++ {
|
||||
// msgPack := insertStream.Consume()
|
||||
//
|
||||
// for n := 0; n < msgLength; n++ {
|
||||
// segmentID := msgPack.Msgs[n].(*msgstream.InsertMsg).SegmentID
|
||||
// blob := msgPack.Msgs[n].(*msgstream.InsertMsg).RowData[0].Value
|
||||
// id := binary.BigEndian.Uint64(blob[0:8])
|
||||
// idData[segmentID] = append(idData[segmentID], int64(id))
|
||||
// t := binary.BigEndian.Uint64(blob[8:16])
|
||||
// timestamps[segmentID] = append(timestamps[segmentID], int64(t))
|
||||
// for i := 0; i < DIM; i++ {
|
||||
// bits := binary.BigEndian.Uint32(blob[16+4*i : 16+4*(i+1)])
|
||||
// floatVec := math.Float32frombits(bits)
|
||||
// fieldVecData[segmentID] = append(fieldVecData[segmentID], floatVec)
|
||||
// }
|
||||
// ageValue := binary.BigEndian.Uint32(blob[80:84])
|
||||
// fieldAgeData[segmentID] = append(fieldAgeData[segmentID], int32(ageValue))
|
||||
// }
|
||||
// }
|
||||
// for i := 0; i < 12; i++ {
|
||||
// insertData := &storage.InsertData{
|
||||
// Data: map[int64]storage.FieldData{
|
||||
// 0: &storage.Int64FieldData{
|
||||
// NumRows: msgLength,
|
||||
// Data: idData[i],
|
||||
// },
|
||||
// 1: &storage.Int64FieldData{
|
||||
// NumRows: msgLength,
|
||||
// Data: timestamps[i],
|
||||
// },
|
||||
// 100: &storage.FloatVectorFieldData{
|
||||
// NumRows: msgLength,
|
||||
// Data: fieldVecData[i],
|
||||
// Dim: DIM,
|
||||
// },
|
||||
// 101: &storage.Int32FieldData{
|
||||
// NumRows: msgLength,
|
||||
// Data: fieldAgeData[i],
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
// segmentData[i] = insertData
|
||||
// }
|
||||
//
|
||||
// //gen inCodec
|
||||
// collectionMeta := &etcdpb.CollectionMeta{
|
||||
// ID: 1,
|
||||
// Schema: genSchema(1),
|
||||
// CreateTime: 0,
|
||||
// PartitionIDs: []int64{1},
|
||||
// SegmentIDs: []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
||||
// }
|
||||
// inCodec := storage.NewInsertCodec(collectionMeta)
|
||||
// indexCodec := storage.NewIndexCodec()
|
||||
//
|
||||
// // get minio client
|
||||
// kv, err := getMinioKV(context.Background())
|
||||
// assert.Nil(t, err)
|
||||
//
|
||||
// // write binlog minio
|
||||
// collectionStr := strconv.FormatInt(1, 10)
|
||||
// for i := 0; i < 12; i++ {
|
||||
// binLogs, err := inCodec.Serialize(1, storage.UniqueID(i), segmentData[i])
|
||||
// assert.Nil(t, err)
|
||||
// assert.Equal(t, len(binLogs), 4)
|
||||
// keyPrefix := "distributed-query-test-binlog"
|
||||
// segmentStr := strconv.FormatInt(int64(i), 10)
|
||||
//
|
||||
// for _, blob := range binLogs {
|
||||
// key := path.Join(keyPrefix, collectionStr, segmentStr, blob.Key)
|
||||
// err = kv.Save(key, string(blob.Value[:]))
|
||||
// assert.Nil(t, err)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // gen index build's indexParams
|
||||
// indexParams := make(map[string]string)
|
||||
// indexParams["index_type"] = "IVF_PQ"
|
||||
// indexParams["index_mode"] = "cpu"
|
||||
// indexParams["dim"] = "16"
|
||||
// indexParams["k"] = "10"
|
||||
// indexParams["nlist"] = "100"
|
||||
// indexParams["nprobe"] = "10"
|
||||
// indexParams["m"] = "4"
|
||||
// indexParams["nbits"] = "8"
|
||||
// indexParams["metric_type"] = "L2"
|
||||
// indexParams["SLICE_SIZE"] = "400"
|
||||
//
|
||||
// var indexParamsKV []*commonpb.KeyValuePair
|
||||
// for key, value := range indexParams {
|
||||
// indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{
|
||||
// Key: key,
|
||||
// Value: value,
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// // generator index and write index to minio
|
||||
// for i := 0; i < 12; i++ {
|
||||
// typeParams := make(map[string]string)
|
||||
// typeParams["dim"] = "16"
|
||||
// index, err := indexnode.NewCIndex(typeParams, indexParams)
|
||||
// assert.Nil(t, err)
|
||||
// err = index.BuildFloatVecIndexWithoutIds(fieldVecData[i])
|
||||
// assert.Equal(t, err, nil)
|
||||
// binarySet, err := index.Serialize()
|
||||
// assert.Equal(t, len(binarySet), 1)
|
||||
// assert.Nil(t, err)
|
||||
// codecIndex, err := indexCodec.Serialize(binarySet, indexParams, "test_index", UniqueID(i))
|
||||
// assert.Equal(t, len(codecIndex), 2)
|
||||
// assert.Nil(t, err)
|
||||
// keyPrefix := "distributed-query-test-index"
|
||||
// segmentStr := strconv.FormatInt(int64(i), 10)
|
||||
// key1 := path.Join(keyPrefix, collectionStr, segmentStr, "IVF")
|
||||
// key2 := path.Join(keyPrefix, collectionStr, segmentStr, "indexParams")
|
||||
// kv.Save(key1, string(codecIndex[0].Value))
|
||||
// kv.Save(key2, string(codecIndex[1].Value))
|
||||
// }
|
||||
//}
|
||||
@ -45,9 +45,15 @@ type Client struct {
|
||||
|
||||
sess *sessionutil.Session
|
||||
addr string
|
||||
|
||||
getGrpcClient func() (querypb.QueryCoordClient, error)
|
||||
}
|
||||
|
||||
func (c *Client) getGrpcClient() (querypb.QueryCoordClient, error) {
|
||||
func (c *Client) setGetGrpcClientFunc() {
|
||||
c.getGrpcClient = c.getGrpcClientFunc
|
||||
}
|
||||
|
||||
func (c *Client) getGrpcClientFunc() (querypb.QueryCoordClient, error) {
|
||||
c.grpcClientMtx.RLock()
|
||||
if c.grpcClient != nil {
|
||||
defer c.grpcClientMtx.RUnlock()
|
||||
@ -107,16 +113,19 @@ func NewClient(ctx context.Context, metaRoot string, etcdEndpoints []string) (*C
|
||||
return nil, err
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Client{
|
||||
client := &Client{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
sess: sess,
|
||||
}, nil
|
||||
}
|
||||
|
||||
client.setGetGrpcClientFunc()
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) Init() error {
|
||||
Params.Init()
|
||||
return c.connect(retry.Attempts(20))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) connect(retryOptions ...retry.Option) error {
|
||||
|
||||
@ -11,262 +11,165 @@
|
||||
|
||||
package grpcquerycoordclient
|
||||
|
||||
//import (
|
||||
// "context"
|
||||
// "encoding/binary"
|
||||
// "fmt"
|
||||
// "log"
|
||||
// "math"
|
||||
// "testing"
|
||||
// "time"
|
||||
//
|
||||
// "github.com/golang/protobuf/proto"
|
||||
// "github.com/stretchr/testify/assert"
|
||||
//
|
||||
// "github.com/milvus-io/milvus/internal/msgstream"
|
||||
// "github.com/milvus-io/milvus/internal/msgstream/pulsarms"
|
||||
// "github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
// "github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
// "github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
// "github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
// qs "github.com/milvus-io/milvus/internal/querycoord"
|
||||
//)
|
||||
//
|
||||
//const (
|
||||
// debug = false
|
||||
// pulsarAddress = "pulsar://127.0.0.1:6650"
|
||||
//)
|
||||
//
|
||||
//func TestClient_LoadCollection(t *testing.T) {
|
||||
// var ctx context.Context
|
||||
// if debug {
|
||||
// ctx = context.Background()
|
||||
// } else {
|
||||
// var cancel context.CancelFunc
|
||||
// ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
|
||||
// defer cancel()
|
||||
// }
|
||||
//
|
||||
// //create queryCoord client
|
||||
// qs.Params.Init()
|
||||
// log.Println("QueryCoord address:", qs.Params.Address)
|
||||
// log.Println("Init Query service client ...")
|
||||
// client, err := NewClient(qs.Params.Address, 20*time.Second)
|
||||
// assert.Nil(t, err)
|
||||
// err = client.Init()
|
||||
// assert.Nil(t, err)
|
||||
// err = client.Start()
|
||||
// assert.Nil(t, err)
|
||||
//
|
||||
// insertChannels := []string{"insert-0", "insert-1", "insert-2", "insert-3"}
|
||||
// ddChannels := []string{"data-definition"}
|
||||
//
|
||||
// factory := pulsarms.NewFactory(pulsarAddress, receiveBufSize, pulsarBufSize)
|
||||
// insertStream, _ := factory.NewTtMsgStream(ctx)
|
||||
// insertStream.AsProducer(insertChannels)
|
||||
// insertStream.Start()
|
||||
//
|
||||
// ddStream, err := factory.NewTtMsgStream(ctx)
|
||||
// assert.NoError(t, err)
|
||||
// ddStream.AsProducer(ddChannels)
|
||||
// ddStream.Start()
|
||||
//
|
||||
// // showCollection
|
||||
// showCollectionRequest := &querypb.ShowCollectionsRequest{
|
||||
// DbID: 0,
|
||||
// }
|
||||
// showCollectionRes, err := client.ShowCollections(showCollectionRequest)
|
||||
// fmt.Println("showCollectionRes: ", showCollectionRes)
|
||||
// assert.Nil(t, err)
|
||||
//
|
||||
// //load collection
|
||||
// loadCollectionRequest := &querypb.LoadCollectionRequest{
|
||||
// CollectionID: 1,
|
||||
// Schema: genSchema(1),
|
||||
// }
|
||||
// loadCollectionRes, err := client.LoadCollection(loadCollectionRequest)
|
||||
// fmt.Println("loadCollectionRes: ", loadCollectionRes)
|
||||
// assert.Nil(t, err)
|
||||
//
|
||||
// // showCollection
|
||||
// showCollectionRes, err = client.ShowCollections(showCollectionRequest)
|
||||
// fmt.Println("showCollectionRes: ", showCollectionRes)
|
||||
// assert.Nil(t, err)
|
||||
//
|
||||
// //showSegmentInfo
|
||||
// getSegmentInfoRequest := &querypb.SegmentInfoRequest{
|
||||
// SegmentIDs: []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
||||
// }
|
||||
// getSegmentInfoRes, err := client.GetSegmentInfo(getSegmentInfoRequest)
|
||||
// fmt.Println("segment info : ", getSegmentInfoRes)
|
||||
// assert.Nil(t, err)
|
||||
//
|
||||
// // insert msg
|
||||
// for i := 0; i < 12; i++ {
|
||||
// insertMsgPack, timeTickMsgPack := genInsert(1, 1, i*msgLength+1, 4, false)
|
||||
// err := insertStream.Produce(insertMsgPack)
|
||||
// assert.NoError(t, err)
|
||||
// err = insertStream.Broadcast(timeTickMsgPack)
|
||||
// assert.NoError(t, err)
|
||||
// err = ddStream.Broadcast(timeTickMsgPack)
|
||||
// assert.NoError(t, err)
|
||||
// }
|
||||
//
|
||||
// getSegmentInfoRes, err = client.GetSegmentInfo(getSegmentInfoRequest)
|
||||
// assert.Nil(t, err)
|
||||
// fmt.Println("segment info : ", getSegmentInfoRes)
|
||||
//
|
||||
//}
|
||||
//
|
||||
//func TestClient_GetSegmentInfo(t *testing.T) {
|
||||
// if !debug {
|
||||
// _, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
|
||||
// defer cancel()
|
||||
// }
|
||||
//
|
||||
// //create queryCoord client
|
||||
// qs.Params.Init()
|
||||
// log.Println("QueryCoord address:", qs.Params.Address)
|
||||
// log.Println("Init Query Coord client ...")
|
||||
// client, err := NewClient(qs.Params.Address, 20*time.Second)
|
||||
// assert.Nil(t, err)
|
||||
// err = client.Init()
|
||||
// assert.Nil(t, err)
|
||||
// err = client.Start()
|
||||
// assert.Nil(t, err)
|
||||
//
|
||||
// //showSegmentInfo
|
||||
// getSegmentInfoRequest := &querypb.SegmentInfoRequest{
|
||||
// SegmentIDs: []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
||||
// }
|
||||
// getSegmentInfoRes, err := client.GetSegmentInfo(getSegmentInfoRequest)
|
||||
// assert.Nil(t, err)
|
||||
// fmt.Println("segment info : ", getSegmentInfoRes)
|
||||
//}
|
||||
//
|
||||
//func TestClient_LoadPartitions(t *testing.T) {
|
||||
// if !debug {
|
||||
// _, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
|
||||
// defer cancel()
|
||||
// }
|
||||
//
|
||||
// //create queryCoord client
|
||||
// qs.Params.Init()
|
||||
// log.Println("QueryCoord address:", qs.Params.Address)
|
||||
// log.Println("Init Query service client ...")
|
||||
// client, err := NewClient(qs.Params.Address, 20*time.Second)
|
||||
// assert.Nil(t, err)
|
||||
// err = client.Init()
|
||||
// assert.Nil(t, err)
|
||||
// err = client.Start()
|
||||
// assert.Nil(t, err)
|
||||
//
|
||||
// loadPartitionRequest := &querypb.LoadPartitionsRequest{
|
||||
// CollectionID: 1,
|
||||
// Schema: genSchema(1),
|
||||
// }
|
||||
// loadPartitionRes, err := client.LoadPartitions(loadPartitionRequest)
|
||||
// fmt.Println("loadCollectionRes: ", loadPartitionRes)
|
||||
// assert.Nil(t, err)
|
||||
//}
|
||||
//
|
||||
//func TestClient_GetChannels(t *testing.T) {
|
||||
// if !debug {
|
||||
// _, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
|
||||
// defer cancel()
|
||||
// }
|
||||
//
|
||||
// //create queryCoord client
|
||||
// qs.Params.Init()
|
||||
// log.Println("QueryCoord address:", qs.Params.Address)
|
||||
// log.Println("Init Query service client ...")
|
||||
// client, err := NewClient(qs.Params.Address, 20*time.Second)
|
||||
// assert.Nil(t, err)
|
||||
// err = client.Init()
|
||||
// assert.Nil(t, err)
|
||||
// err = client.Start()
|
||||
// assert.Nil(t, err)
|
||||
//
|
||||
// getTimeTickChannelRes, err := client.GetTimeTickChannel()
|
||||
// fmt.Println("loadCollectionRes: ", getTimeTickChannelRes)
|
||||
// assert.Nil(t, err)
|
||||
//}
|
||||
//
|
||||
//func sendSearchRequest(ctx context.Context, searchChannels []string) {
|
||||
// // test data generate
|
||||
// const msgLength = 10
|
||||
// const receiveBufSize = 1024
|
||||
// const DIM = 16
|
||||
//
|
||||
// var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
// // start search service
|
||||
// dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
|
||||
// var searchRawData1 []byte
|
||||
// var searchRawData2 []byte
|
||||
// for i, ele := range vec {
|
||||
// buf := make([]byte, 4)
|
||||
// binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2)))
|
||||
// searchRawData1 = append(searchRawData1, buf...)
|
||||
// }
|
||||
// for i, ele := range vec {
|
||||
// buf := make([]byte, 4)
|
||||
// binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*4)))
|
||||
// searchRawData2 = append(searchRawData2, buf...)
|
||||
// }
|
||||
// placeholderValue := milvuspb.PlaceholderValue{
|
||||
// Tag: "$0",
|
||||
// Type: milvuspb.PlaceholderType_VECTOR_FLOAT,
|
||||
// Values: [][]byte{searchRawData1, searchRawData2},
|
||||
// }
|
||||
//
|
||||
// placeholderGroup := milvuspb.PlaceholderGroup{
|
||||
// Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue},
|
||||
// }
|
||||
//
|
||||
// placeGroupByte, err := proto.Marshal(&placeholderGroup)
|
||||
// if err != nil {
|
||||
// log.Print("marshal placeholderGroup failed")
|
||||
// }
|
||||
//
|
||||
// query := milvuspb.SearchRequest{
|
||||
// Dsl: dslString,
|
||||
// PlaceholderGroup: placeGroupByte,
|
||||
// }
|
||||
//
|
||||
// queryByte, err := proto.Marshal(&query)
|
||||
// if err != nil {
|
||||
// log.Print("marshal query failed")
|
||||
// }
|
||||
//
|
||||
// blob := commonpb.Blob{
|
||||
// Value: queryByte,
|
||||
// }
|
||||
//
|
||||
// searchMsg := &msgstream.SearchMsg{
|
||||
// BaseMsg: msgstream.BaseMsg{
|
||||
// HashValues: []uint32{0},
|
||||
// },
|
||||
// SearchRequest: internalpb.SearchRequest{
|
||||
// Base: &commonpb.MsgBase{
|
||||
// MsgType: commonpb.MsgType_kSearch,
|
||||
// MsgID: 1,
|
||||
// Timestamp: uint64(10 + 1000),
|
||||
// SourceID: 1,
|
||||
// },
|
||||
// ResultChannelID: "0",
|
||||
// Query: &blob,
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// msgPackSearch := msgstream.MsgPack{}
|
||||
// msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg)
|
||||
//
|
||||
// factory := pulsarms.NewFactory(pulsarAddress, receiveBufSize, 1024)
|
||||
// searchStream, _ := factory.NewMsgStream(ctx)
|
||||
// searchStream.AsProducer(searchChannels)
|
||||
// searchStream.Start()
|
||||
// err = searchStream.Produce(&msgPackSearch)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
//}
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proxy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type MockQueryCoordClient struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *MockQueryCoordClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) {
|
||||
return &internalpb.ComponentStates{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockQueryCoordClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) {
|
||||
return &milvuspb.StringResponse{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockQueryCoordClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) {
|
||||
return &milvuspb.StringResponse{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockQueryCoordClient) ShowCollections(ctx context.Context, in *querypb.ShowCollectionsRequest, opts ...grpc.CallOption) (*querypb.ShowCollectionsResponse, error) {
|
||||
return &querypb.ShowCollectionsResponse{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockQueryCoordClient) ShowPartitions(ctx context.Context, in *querypb.ShowPartitionsRequest, opts ...grpc.CallOption) (*querypb.ShowPartitionsResponse, error) {
|
||||
return &querypb.ShowPartitionsResponse{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockQueryCoordClient) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockQueryCoordClient) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockQueryCoordClient) LoadCollection(ctx context.Context, in *querypb.LoadCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockQueryCoordClient) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
|
||||
return &commonpb.Status{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockQueryCoordClient) CreateQueryChannel(ctx context.Context, in *querypb.CreateQueryChannelRequest, opts ...grpc.CallOption) (*querypb.CreateQueryChannelResponse, error) {
|
||||
return &querypb.CreateQueryChannelResponse{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockQueryCoordClient) GetPartitionStates(ctx context.Context, in *querypb.GetPartitionStatesRequest, opts ...grpc.CallOption) (*querypb.GetPartitionStatesResponse, error) {
|
||||
return &querypb.GetPartitionStatesResponse{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockQueryCoordClient) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) {
|
||||
return &querypb.GetSegmentInfoResponse{}, m.err
|
||||
}
|
||||
|
||||
func (m *MockQueryCoordClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) {
|
||||
return &milvuspb.GetMetricsResponse{}, m.err
|
||||
}
|
||||
|
||||
func Test_NewClient(t *testing.T) {
|
||||
proxy.Params.InitOnce()
|
||||
|
||||
ctx := context.Background()
|
||||
client, err := NewClient(ctx, proxy.Params.MetaRootPath, proxy.Params.EtcdEndpoints)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, client)
|
||||
|
||||
err = client.Init()
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = client.Start()
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = client.Register()
|
||||
assert.Nil(t, err)
|
||||
|
||||
checkFunc := func(retNotNil bool) {
|
||||
retCheck := func(notNil bool, ret interface{}, err error) {
|
||||
if notNil {
|
||||
assert.NotNil(t, ret)
|
||||
assert.Nil(t, err)
|
||||
} else {
|
||||
assert.Nil(t, ret)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
r1, err := client.GetComponentStates(ctx)
|
||||
retCheck(retNotNil, r1, err)
|
||||
|
||||
r2, err := client.GetTimeTickChannel(ctx)
|
||||
retCheck(retNotNil, r2, err)
|
||||
|
||||
r3, err := client.GetStatisticsChannel(ctx)
|
||||
retCheck(retNotNil, r3, err)
|
||||
|
||||
r4, err := client.ShowCollections(ctx, nil)
|
||||
retCheck(retNotNil, r4, err)
|
||||
|
||||
r5, err := client.ShowPartitions(ctx, nil)
|
||||
retCheck(retNotNil, r5, err)
|
||||
|
||||
r6, err := client.LoadPartitions(ctx, nil)
|
||||
retCheck(retNotNil, r6, err)
|
||||
|
||||
r7, err := client.ReleasePartitions(ctx, nil)
|
||||
retCheck(retNotNil, r7, err)
|
||||
|
||||
r8, err := client.ShowCollections(ctx, nil)
|
||||
retCheck(retNotNil, r8, err)
|
||||
|
||||
r9, err := client.LoadCollection(ctx, nil)
|
||||
retCheck(retNotNil, r9, err)
|
||||
|
||||
r10, err := client.ReleaseCollection(ctx, nil)
|
||||
retCheck(retNotNil, r10, err)
|
||||
|
||||
r11, err := client.CreateQueryChannel(ctx, nil)
|
||||
retCheck(retNotNil, r11, err)
|
||||
|
||||
r12, err := client.ShowPartitions(ctx, nil)
|
||||
retCheck(retNotNil, r12, err)
|
||||
|
||||
r13, err := client.GetPartitionStates(ctx, nil)
|
||||
retCheck(retNotNil, r13, err)
|
||||
|
||||
r14, err := client.GetSegmentInfo(ctx, nil)
|
||||
retCheck(retNotNil, r14, err)
|
||||
|
||||
r15, err := client.GetMetrics(ctx, nil)
|
||||
retCheck(retNotNil, r15, err)
|
||||
}
|
||||
|
||||
client.getGrpcClient = func() (querypb.QueryCoordClient, error) {
|
||||
return &MockQueryCoordClient{err: nil}, errors.New("dummy")
|
||||
}
|
||||
checkFunc(false)
|
||||
|
||||
client.getGrpcClient = func() (querypb.QueryCoordClient, error) {
|
||||
return &MockQueryCoordClient{err: errors.New("dummy")}, nil
|
||||
}
|
||||
checkFunc(false)
|
||||
|
||||
client.getGrpcClient = func() (querypb.QueryCoordClient, error) {
|
||||
return &MockQueryCoordClient{err: nil}, nil
|
||||
}
|
||||
checkFunc(true)
|
||||
|
||||
err = client.Stop()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user