From 693994c6dbfd8a67b8f9eb9d502753908fd939ae Mon Sep 17 00:00:00 2001 From: groot Date: Mon, 27 Sep 2021 18:41:58 +0800 Subject: [PATCH] Add unittest for distributed/querycoord/indexnode (#8680) Signed-off-by: yhmo --- .../distributed/indexnode/client/client.go | 17 +- .../indexnode/client/client_test.go | 94 ++++ .../querycoord/client/binlog_test.go | 356 --------------- .../distributed/querycoord/client/client.go | 17 +- .../querycoord/client/client_test.go | 421 +++++++----------- 5 files changed, 282 insertions(+), 623 deletions(-) delete mode 100644 internal/distributed/querycoord/client/binlog_test.go diff --git a/internal/distributed/indexnode/client/client.go b/internal/distributed/indexnode/client/client.go index 48ab0e8892..dd4072b3a1 100644 --- a/internal/distributed/indexnode/client/client.go +++ b/internal/distributed/indexnode/client/client.go @@ -42,9 +42,15 @@ type Client struct { grpcClientMtx sync.RWMutex addr string + + getGrpcClient func() (indexpb.IndexNodeClient, error) } -func (c *Client) getGrpcClient() (indexpb.IndexNodeClient, error) { +func (c *Client) setGetGrpcClientFunc() { + c.getGrpcClient = c.getGrpcClientFunc +} + +func (c *Client) getGrpcClientFunc() (indexpb.IndexNodeClient, error) { c.grpcClientMtx.RLock() if c.grpcClient != nil { defer c.grpcClientMtx.RUnlock() @@ -86,16 +92,19 @@ func NewClient(ctx context.Context, addr string) (*Client, error) { } ctx, cancel := context.WithCancel(ctx) - return &Client{ + client := &Client{ ctx: ctx, cancel: cancel, addr: addr, - }, nil + } + + client.setGetGrpcClientFunc() + return client, nil } func (c *Client) Init() error { Params.Init() - return c.connect(retry.Attempts(20)) + return nil } func (c *Client) connect(retryOptions ...retry.Option) error { diff --git a/internal/distributed/indexnode/client/client_test.go b/internal/distributed/indexnode/client/client_test.go index 28a1b088fd..c2f9b1b4a7 100644 --- a/internal/distributed/indexnode/client/client_test.go +++ b/internal/distributed/indexnode/client/client_test.go @@ -13,6 +13,7 @@ package grpcindexnodeclient import ( "context" + "errors" "testing" grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode" @@ -21,9 +22,102 @@ import ( "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proxy" "github.com/stretchr/testify/assert" + "google.golang.org/grpc" ) +type MockIndexNodeClient struct { + err error +} + +func (m *MockIndexNodeClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{}, m.err +} + +func (m *MockIndexNodeClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.err +} + +func (m *MockIndexNodeClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.err +} + +func (m *MockIndexNodeClient) CreateIndex(ctx context.Context, in *indexpb.CreateIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.err +} + +func (m *MockIndexNodeClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) 
(*milvuspb.GetMetricsResponse, error) { + return &milvuspb.GetMetricsResponse{}, m.err +} + +func Test_NewClient(t *testing.T) { + proxy.Params.InitOnce() + + ctx := context.Background() + client, err := NewClient(ctx, "") + assert.Nil(t, client) + assert.NotNil(t, err) + + client, err = NewClient(ctx, "test") + assert.Nil(t, err) + assert.NotNil(t, client) + + err = client.Init() + assert.Nil(t, err) + + err = client.Start() + assert.Nil(t, err) + + err = client.Register() + assert.Nil(t, err) + + checkFunc := func(retNotNil bool) { + retCheck := func(notNil bool, ret interface{}, err error) { + if notNil { + assert.NotNil(t, ret) + assert.Nil(t, err) + } else { + assert.Nil(t, ret) + assert.NotNil(t, err) + } + } + + r1, err := client.GetComponentStates(ctx) + retCheck(retNotNil, r1, err) + + r2, err := client.GetTimeTickChannel(ctx) + retCheck(retNotNil, r2, err) + + r3, err := client.GetStatisticsChannel(ctx) + retCheck(retNotNil, r3, err) + + r4, err := client.CreateIndex(ctx, nil) + retCheck(retNotNil, r4, err) + + r5, err := client.GetMetrics(ctx, nil) + retCheck(retNotNil, r5, err) + } + + client.getGrpcClient = func() (indexpb.IndexNodeClient, error) { + return &MockIndexNodeClient{err: nil}, errors.New("dummy") + } + checkFunc(false) + + client.getGrpcClient = func() (indexpb.IndexNodeClient, error) { + return &MockIndexNodeClient{err: errors.New("dummy")}, nil + } + checkFunc(false) + + client.getGrpcClient = func() (indexpb.IndexNodeClient, error) { + return &MockIndexNodeClient{err: nil}, nil + } + checkFunc(true) + + err = client.Stop() + assert.Nil(t, err) +} + func TestIndexNodeClient(t *testing.T) { ctx := context.Background() diff --git a/internal/distributed/querycoord/client/binlog_test.go b/internal/distributed/querycoord/client/binlog_test.go deleted file mode 100644 index 0c4edb4480..0000000000 --- a/internal/distributed/querycoord/client/binlog_test.go +++ /dev/null @@ -1,356 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. 
- -package grpcquerycoordclient - -//import ( -// "context" -// "encoding/binary" -// "math" -// "path" -// "strconv" -// "testing" -// "time" -// -// "github.com/stretchr/testify/assert" -// -// "github.com/milvus-io/milvus/internal/indexnode" -// minioKV "github.com/milvus-io/milvus/internal/kv/minio" -// "github.com/milvus-io/milvus/internal/msgstream" -// "github.com/milvus-io/milvus/internal/msgstream/pulsarms" -// "github.com/milvus-io/milvus/internal/proto/commonpb" -// "github.com/milvus-io/milvus/internal/proto/etcdpb" -// "github.com/milvus-io/milvus/internal/proto/internalpb" -// "github.com/milvus-io/milvus/internal/proto/schemapb" -// "github.com/milvus-io/milvus/internal/storage" -// "github.com/milvus-io/milvus/internal/util/typeutil" -//) -// -////generate insert data -//const msgLength = 100 -//const receiveBufSize = 1024 -//const pulsarBufSize = 1024 -//const DIM = 16 -// -//type UniqueID = typeutil.UniqueID -// -//func genInsert(collectionID int64, -// partitionID int64, -// timeStart int, -// numDmChannels int, -// binlog bool) (*msgstream.MsgPack, *msgstream.MsgPack) { -// msgs := make([]msgstream.TsMsg, 0) -// for n := timeStart; n < timeStart+msgLength; n++ { -// rowData := make([]byte, 0) -// if binlog { -// id := make([]byte, 8) -// binary.BigEndian.PutUint64(id, uint64(n)) -// rowData = append(rowData, id...) -// time := make([]byte, 8) -// binary.BigEndian.PutUint64(time, uint64(n)) -// rowData = append(rowData, time...) -// } -// for i := 0; i < DIM; i++ { -// vec := make([]byte, 4) -// binary.BigEndian.PutUint32(vec, math.Float32bits(float32(n*i))) -// rowData = append(rowData, vec...) -// } -// age := make([]byte, 4) -// binary.BigEndian.PutUint32(age, 1) -// rowData = append(rowData, age...) -// blob := &commonpb.Blob{ -// Value: rowData, -// } -// -// var insertMsg msgstream.TsMsg = &msgstream.InsertMsg{ -// BaseMsg: msgstream.BaseMsg{ -// HashValues: []uint32{uint32((n - 1) % numDmChannels)}, -// }, -// InsertRequest: internalpb.InsertRequest{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kInsert, -// MsgID: 0, -// Timestamp: uint64(n), -// SourceID: 0, -// }, -// CollectionID: collectionID, -// PartitionID: partitionID, -// SegmentID: UniqueID(((n - 1) % numDmChannels) + ((n-1)/(numDmChannels*msgLength))*numDmChannels), -// ChannelID: "0", -// Timestamps: []uint64{uint64(n)}, -// RowIDs: []int64{int64(n)}, -// RowData: []*commonpb.Blob{blob}, -// }, -// } -// //fmt.Println("hash value = ", insertMsg.(*msgstream.InsertMsg).HashValues, "segmentID = ", insertMsg.(*msgstream.InsertMsg).SegmentID) -// msgs = append(msgs, insertMsg) -// } -// -// insertMsgPack := &msgstream.MsgPack{ -// BeginTs: uint64(timeStart), -// EndTs: uint64(timeStart + msgLength), -// Msgs: msgs, -// } -// -// // generate timeTick -// timeTickMsg := &msgstream.TimeTickMsg{ -// BaseMsg: msgstream.BaseMsg{ -// BeginTimestamp: 0, -// EndTimestamp: 0, -// HashValues: []uint32{0}, -// }, -// TimeTickMsg: internalpb.TimeTickMsg{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kTimeTick, -// MsgID: 0, -// Timestamp: uint64(timeStart + msgLength), -// SourceID: 0, -// }, -// }, -// } -// timeTickMsgPack := &msgstream.MsgPack{ -// Msgs: []msgstream.TsMsg{timeTickMsg}, -// } -// return insertMsgPack, timeTickMsgPack -//} -// -//func genSchema(collectionID int64) *schemapb.CollectionSchema { -// fieldID := schemapb.FieldSchema{ -// FieldID: UniqueID(0), -// Name: "RowID", -// IsPrimaryKey: false, -// DataType: schemapb.DataType_INT64, -// } -// -// fieldTime := 
schemapb.FieldSchema{ -// FieldID: UniqueID(1), -// Name: "Timestamp", -// IsPrimaryKey: false, -// DataType: schemapb.DataType_INT64, -// } -// -// fieldVec := schemapb.FieldSchema{ -// FieldID: UniqueID(100), -// Name: "vec", -// IsPrimaryKey: false, -// DataType: schemapb.DataType_VECTOR_FLOAT, -// TypeParams: []*commonpb.KeyValuePair{ -// { -// Key: "dim", -// Value: "16", -// }, -// }, -// IndexParams: []*commonpb.KeyValuePair{ -// { -// Key: "metric_type", -// Value: "L2", -// }, -// }, -// } -// -// fieldInt := schemapb.FieldSchema{ -// FieldID: UniqueID(101), -// Name: "age", -// IsPrimaryKey: false, -// DataType: schemapb.DataType_INT32, -// } -// -// return &schemapb.CollectionSchema{ -// Name: "collection-" + strconv.FormatInt(collectionID, 10), -// AutoID: true, -// Fields: []*schemapb.FieldSchema{ -// &fieldID, &fieldTime, &fieldVec, &fieldInt, -// }, -// } -//} -// -//func getMinioKV(ctx context.Context) (*minioKV.MinIOKV, error) { -// minioAddress := "localhost:9000" -// accessKeyID := "minioadmin" -// secretAccessKey := "minioadmin" -// useSSL := false -// bucketName := "a-bucket" -// -// option := &minioKV.Option{ -// Address: minioAddress, -// AccessKeyID: accessKeyID, -// SecretAccessKeyID: secretAccessKey, -// UseSSL: useSSL, -// BucketName: bucketName, -// CreateBucket: true, -// } -// -// return minioKV.NewMinIOKV(ctx, option) -//} -// -//func TestWriteBinLog(t *testing.T) { -// const ( -// debug = true -// consumeSubName = "test-load-collection-sub-name" -// ) -// var ctx context.Context -// if debug { -// ctx = context.Background() -// } else { -// var cancel context.CancelFunc -// ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) -// defer cancel() -// } -// -// // produce msg -// insertChannels := []string{"insert-0", "insert-1", "insert-2", "insert-3"} -// pulsarAddress := "pulsar://127.0.0.1:6650" -// -// factory := pulsarms.NewFactory(pulsarAddress, receiveBufSize, pulsarBufSize) -// -// insertStream, _ := factory.NewTtMsgStream(ctx) -// insertStream.AsProducer(insertChannels) -// insertStream.AsConsumer(insertChannels, consumeSubName) -// insertStream.Start() -// -// for i := 0; i < 12; i++ { -// insertMsgPack, timeTickMsgPack := genInsert(1, 1, i*msgLength+1, 4, true) -// err := insertStream.Produce(insertMsgPack) -// assert.NoError(t, err) -// err = insertStream.Broadcast(timeTickMsgPack) -// assert.NoError(t, err) -// } -// -// //consume msg -// segmentData := make([]*storage.InsertData, 12) -// idData := make([][]int64, 12) -// timestamps := make([][]int64, 12) -// fieldAgeData := make([][]int32, 12) -// fieldVecData := make([][]float32, 12) -// for i := 0; i < 12; i++ { -// idData[i] = make([]int64, 0) -// timestamps[i] = make([]int64, 0) -// fieldAgeData[i] = make([]int32, 0) -// fieldVecData[i] = make([]float32, 0) -// } -// for i := 0; i < 12; i++ { -// msgPack := insertStream.Consume() -// -// for n := 0; n < msgLength; n++ { -// segmentID := msgPack.Msgs[n].(*msgstream.InsertMsg).SegmentID -// blob := msgPack.Msgs[n].(*msgstream.InsertMsg).RowData[0].Value -// id := binary.BigEndian.Uint64(blob[0:8]) -// idData[segmentID] = append(idData[segmentID], int64(id)) -// t := binary.BigEndian.Uint64(blob[8:16]) -// timestamps[segmentID] = append(timestamps[segmentID], int64(t)) -// for i := 0; i < DIM; i++ { -// bits := binary.BigEndian.Uint32(blob[16+4*i : 16+4*(i+1)]) -// floatVec := math.Float32frombits(bits) -// fieldVecData[segmentID] = append(fieldVecData[segmentID], floatVec) -// } -// ageValue := 
binary.BigEndian.Uint32(blob[80:84]) -// fieldAgeData[segmentID] = append(fieldAgeData[segmentID], int32(ageValue)) -// } -// } -// for i := 0; i < 12; i++ { -// insertData := &storage.InsertData{ -// Data: map[int64]storage.FieldData{ -// 0: &storage.Int64FieldData{ -// NumRows: msgLength, -// Data: idData[i], -// }, -// 1: &storage.Int64FieldData{ -// NumRows: msgLength, -// Data: timestamps[i], -// }, -// 100: &storage.FloatVectorFieldData{ -// NumRows: msgLength, -// Data: fieldVecData[i], -// Dim: DIM, -// }, -// 101: &storage.Int32FieldData{ -// NumRows: msgLength, -// Data: fieldAgeData[i], -// }, -// }, -// } -// segmentData[i] = insertData -// } -// -// //gen inCodec -// collectionMeta := &etcdpb.CollectionMeta{ -// ID: 1, -// Schema: genSchema(1), -// CreateTime: 0, -// PartitionIDs: []int64{1}, -// SegmentIDs: []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, -// } -// inCodec := storage.NewInsertCodec(collectionMeta) -// indexCodec := storage.NewIndexCodec() -// -// // get minio client -// kv, err := getMinioKV(context.Background()) -// assert.Nil(t, err) -// -// // write binlog minio -// collectionStr := strconv.FormatInt(1, 10) -// for i := 0; i < 12; i++ { -// binLogs, err := inCodec.Serialize(1, storage.UniqueID(i), segmentData[i]) -// assert.Nil(t, err) -// assert.Equal(t, len(binLogs), 4) -// keyPrefix := "distributed-query-test-binlog" -// segmentStr := strconv.FormatInt(int64(i), 10) -// -// for _, blob := range binLogs { -// key := path.Join(keyPrefix, collectionStr, segmentStr, blob.Key) -// err = kv.Save(key, string(blob.Value[:])) -// assert.Nil(t, err) -// } -// } -// -// // gen index build's indexParams -// indexParams := make(map[string]string) -// indexParams["index_type"] = "IVF_PQ" -// indexParams["index_mode"] = "cpu" -// indexParams["dim"] = "16" -// indexParams["k"] = "10" -// indexParams["nlist"] = "100" -// indexParams["nprobe"] = "10" -// indexParams["m"] = "4" -// indexParams["nbits"] = "8" -// indexParams["metric_type"] = "L2" -// indexParams["SLICE_SIZE"] = "400" -// -// var indexParamsKV []*commonpb.KeyValuePair -// for key, value := range indexParams { -// indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{ -// Key: key, -// Value: value, -// }) -// } -// -// // generator index and write index to minio -// for i := 0; i < 12; i++ { -// typeParams := make(map[string]string) -// typeParams["dim"] = "16" -// index, err := indexnode.NewCIndex(typeParams, indexParams) -// assert.Nil(t, err) -// err = index.BuildFloatVecIndexWithoutIds(fieldVecData[i]) -// assert.Equal(t, err, nil) -// binarySet, err := index.Serialize() -// assert.Equal(t, len(binarySet), 1) -// assert.Nil(t, err) -// codecIndex, err := indexCodec.Serialize(binarySet, indexParams, "test_index", UniqueID(i)) -// assert.Equal(t, len(codecIndex), 2) -// assert.Nil(t, err) -// keyPrefix := "distributed-query-test-index" -// segmentStr := strconv.FormatInt(int64(i), 10) -// key1 := path.Join(keyPrefix, collectionStr, segmentStr, "IVF") -// key2 := path.Join(keyPrefix, collectionStr, segmentStr, "indexParams") -// kv.Save(key1, string(codecIndex[0].Value)) -// kv.Save(key2, string(codecIndex[1].Value)) -// } -//} diff --git a/internal/distributed/querycoord/client/client.go b/internal/distributed/querycoord/client/client.go index dbc3b2f192..0756b0a1b7 100644 --- a/internal/distributed/querycoord/client/client.go +++ b/internal/distributed/querycoord/client/client.go @@ -45,9 +45,15 @@ type Client struct { sess *sessionutil.Session addr string + + getGrpcClient func() 
(querypb.QueryCoordClient, error) } -func (c *Client) getGrpcClient() (querypb.QueryCoordClient, error) { +func (c *Client) setGetGrpcClientFunc() { + c.getGrpcClient = c.getGrpcClientFunc +} + +func (c *Client) getGrpcClientFunc() (querypb.QueryCoordClient, error) { c.grpcClientMtx.RLock() if c.grpcClient != nil { defer c.grpcClientMtx.RUnlock() @@ -107,16 +113,19 @@ func NewClient(ctx context.Context, metaRoot string, etcdEndpoints []string) (*C return nil, err } ctx, cancel := context.WithCancel(ctx) - return &Client{ + client := &Client{ ctx: ctx, cancel: cancel, sess: sess, - }, nil + } + + client.setGetGrpcClientFunc() + return client, nil } func (c *Client) Init() error { Params.Init() - return c.connect(retry.Attempts(20)) + return nil } func (c *Client) connect(retryOptions ...retry.Option) error { diff --git a/internal/distributed/querycoord/client/client_test.go b/internal/distributed/querycoord/client/client_test.go index 8a5fb523fb..e94251d856 100644 --- a/internal/distributed/querycoord/client/client_test.go +++ b/internal/distributed/querycoord/client/client_test.go @@ -11,262 +11,165 @@ package grpcquerycoordclient -//import ( -// "context" -// "encoding/binary" -// "fmt" -// "log" -// "math" -// "testing" -// "time" -// -// "github.com/golang/protobuf/proto" -// "github.com/stretchr/testify/assert" -// -// "github.com/milvus-io/milvus/internal/msgstream" -// "github.com/milvus-io/milvus/internal/msgstream/pulsarms" -// "github.com/milvus-io/milvus/internal/proto/commonpb" -// "github.com/milvus-io/milvus/internal/proto/internalpb" -// "github.com/milvus-io/milvus/internal/proto/milvuspb" -// "github.com/milvus-io/milvus/internal/proto/querypb" -// qs "github.com/milvus-io/milvus/internal/querycoord" -//) -// -//const ( -// debug = false -// pulsarAddress = "pulsar://127.0.0.1:6650" -//) -// -//func TestClient_LoadCollection(t *testing.T) { -// var ctx context.Context -// if debug { -// ctx = context.Background() -// } else { -// var cancel context.CancelFunc -// ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) -// defer cancel() -// } -// -// //create queryCoord client -// qs.Params.Init() -// log.Println("QueryCoord address:", qs.Params.Address) -// log.Println("Init Query service client ...") -// client, err := NewClient(qs.Params.Address, 20*time.Second) -// assert.Nil(t, err) -// err = client.Init() -// assert.Nil(t, err) -// err = client.Start() -// assert.Nil(t, err) -// -// insertChannels := []string{"insert-0", "insert-1", "insert-2", "insert-3"} -// ddChannels := []string{"data-definition"} -// -// factory := pulsarms.NewFactory(pulsarAddress, receiveBufSize, pulsarBufSize) -// insertStream, _ := factory.NewTtMsgStream(ctx) -// insertStream.AsProducer(insertChannels) -// insertStream.Start() -// -// ddStream, err := factory.NewTtMsgStream(ctx) -// assert.NoError(t, err) -// ddStream.AsProducer(ddChannels) -// ddStream.Start() -// -// // showCollection -// showCollectionRequest := &querypb.ShowCollectionsRequest{ -// DbID: 0, -// } -// showCollectionRes, err := client.ShowCollections(showCollectionRequest) -// fmt.Println("showCollectionRes: ", showCollectionRes) -// assert.Nil(t, err) -// -// //load collection -// loadCollectionRequest := &querypb.LoadCollectionRequest{ -// CollectionID: 1, -// Schema: genSchema(1), -// } -// loadCollectionRes, err := client.LoadCollection(loadCollectionRequest) -// fmt.Println("loadCollectionRes: ", loadCollectionRes) -// assert.Nil(t, err) -// -// // showCollection -// showCollectionRes, 
err = client.ShowCollections(showCollectionRequest) -// fmt.Println("showCollectionRes: ", showCollectionRes) -// assert.Nil(t, err) -// -// //showSegmentInfo -// getSegmentInfoRequest := &querypb.SegmentInfoRequest{ -// SegmentIDs: []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, -// } -// getSegmentInfoRes, err := client.GetSegmentInfo(getSegmentInfoRequest) -// fmt.Println("segment info : ", getSegmentInfoRes) -// assert.Nil(t, err) -// -// // insert msg -// for i := 0; i < 12; i++ { -// insertMsgPack, timeTickMsgPack := genInsert(1, 1, i*msgLength+1, 4, false) -// err := insertStream.Produce(insertMsgPack) -// assert.NoError(t, err) -// err = insertStream.Broadcast(timeTickMsgPack) -// assert.NoError(t, err) -// err = ddStream.Broadcast(timeTickMsgPack) -// assert.NoError(t, err) -// } -// -// getSegmentInfoRes, err = client.GetSegmentInfo(getSegmentInfoRequest) -// assert.Nil(t, err) -// fmt.Println("segment info : ", getSegmentInfoRes) -// -//} -// -//func TestClient_GetSegmentInfo(t *testing.T) { -// if !debug { -// _, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) -// defer cancel() -// } -// -// //create queryCoord client -// qs.Params.Init() -// log.Println("QueryCoord address:", qs.Params.Address) -// log.Println("Init Query Coord client ...") -// client, err := NewClient(qs.Params.Address, 20*time.Second) -// assert.Nil(t, err) -// err = client.Init() -// assert.Nil(t, err) -// err = client.Start() -// assert.Nil(t, err) -// -// //showSegmentInfo -// getSegmentInfoRequest := &querypb.SegmentInfoRequest{ -// SegmentIDs: []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, -// } -// getSegmentInfoRes, err := client.GetSegmentInfo(getSegmentInfoRequest) -// assert.Nil(t, err) -// fmt.Println("segment info : ", getSegmentInfoRes) -//} -// -//func TestClient_LoadPartitions(t *testing.T) { -// if !debug { -// _, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) -// defer cancel() -// } -// -// //create queryCoord client -// qs.Params.Init() -// log.Println("QueryCoord address:", qs.Params.Address) -// log.Println("Init Query service client ...") -// client, err := NewClient(qs.Params.Address, 20*time.Second) -// assert.Nil(t, err) -// err = client.Init() -// assert.Nil(t, err) -// err = client.Start() -// assert.Nil(t, err) -// -// loadPartitionRequest := &querypb.LoadPartitionsRequest{ -// CollectionID: 1, -// Schema: genSchema(1), -// } -// loadPartitionRes, err := client.LoadPartitions(loadPartitionRequest) -// fmt.Println("loadCollectionRes: ", loadPartitionRes) -// assert.Nil(t, err) -//} -// -//func TestClient_GetChannels(t *testing.T) { -// if !debug { -// _, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) -// defer cancel() -// } -// -// //create queryCoord client -// qs.Params.Init() -// log.Println("QueryCoord address:", qs.Params.Address) -// log.Println("Init Query service client ...") -// client, err := NewClient(qs.Params.Address, 20*time.Second) -// assert.Nil(t, err) -// err = client.Init() -// assert.Nil(t, err) -// err = client.Start() -// assert.Nil(t, err) -// -// getTimeTickChannelRes, err := client.GetTimeTickChannel() -// fmt.Println("loadCollectionRes: ", getTimeTickChannelRes) -// assert.Nil(t, err) -//} -// -//func sendSearchRequest(ctx context.Context, searchChannels []string) { -// // test data generate -// const msgLength = 10 -// const receiveBufSize = 1024 -// const DIM = 16 -// -// var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 
16} -// // start search service -// dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" -// var searchRawData1 []byte -// var searchRawData2 []byte -// for i, ele := range vec { -// buf := make([]byte, 4) -// binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2))) -// searchRawData1 = append(searchRawData1, buf...) -// } -// for i, ele := range vec { -// buf := make([]byte, 4) -// binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*4))) -// searchRawData2 = append(searchRawData2, buf...) -// } -// placeholderValue := milvuspb.PlaceholderValue{ -// Tag: "$0", -// Type: milvuspb.PlaceholderType_VECTOR_FLOAT, -// Values: [][]byte{searchRawData1, searchRawData2}, -// } -// -// placeholderGroup := milvuspb.PlaceholderGroup{ -// Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, -// } -// -// placeGroupByte, err := proto.Marshal(&placeholderGroup) -// if err != nil { -// log.Print("marshal placeholderGroup failed") -// } -// -// query := milvuspb.SearchRequest{ -// Dsl: dslString, -// PlaceholderGroup: placeGroupByte, -// } -// -// queryByte, err := proto.Marshal(&query) -// if err != nil { -// log.Print("marshal query failed") -// } -// -// blob := commonpb.Blob{ -// Value: queryByte, -// } -// -// searchMsg := &msgstream.SearchMsg{ -// BaseMsg: msgstream.BaseMsg{ -// HashValues: []uint32{0}, -// }, -// SearchRequest: internalpb.SearchRequest{ -// Base: &commonpb.MsgBase{ -// MsgType: commonpb.MsgType_kSearch, -// MsgID: 1, -// Timestamp: uint64(10 + 1000), -// SourceID: 1, -// }, -// ResultChannelID: "0", -// Query: &blob, -// }, -// } -// -// msgPackSearch := msgstream.MsgPack{} -// msgPackSearch.Msgs = append(msgPackSearch.Msgs, searchMsg) -// -// factory := pulsarms.NewFactory(pulsarAddress, receiveBufSize, 1024) -// searchStream, _ := factory.NewMsgStream(ctx) -// searchStream.AsProducer(searchChannels) -// searchStream.Start() -// err = searchStream.Produce(&msgPackSearch) -// if err != nil { -// panic(err) -// } -//} +import ( + "context" + "errors" + "testing" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proxy" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" +) + +type MockQueryCoordClient struct { + err error +} + +func (m *MockQueryCoordClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{}, m.err +} + +func (m *MockQueryCoordClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.err +} + +func (m *MockQueryCoordClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.err +} + +func (m *MockQueryCoordClient) ShowCollections(ctx context.Context, in *querypb.ShowCollectionsRequest, opts ...grpc.CallOption) (*querypb.ShowCollectionsResponse, error) { + return &querypb.ShowCollectionsResponse{}, m.err +} + +func (m *MockQueryCoordClient) ShowPartitions(ctx context.Context, in 
*querypb.ShowPartitionsRequest, opts ...grpc.CallOption) (*querypb.ShowPartitionsResponse, error) { + return &querypb.ShowPartitionsResponse{}, m.err +} + +func (m *MockQueryCoordClient) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.err +} + +func (m *MockQueryCoordClient) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.err +} + +func (m *MockQueryCoordClient) LoadCollection(ctx context.Context, in *querypb.LoadCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.err +} + +func (m *MockQueryCoordClient) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.err +} + +func (m *MockQueryCoordClient) CreateQueryChannel(ctx context.Context, in *querypb.CreateQueryChannelRequest, opts ...grpc.CallOption) (*querypb.CreateQueryChannelResponse, error) { + return &querypb.CreateQueryChannelResponse{}, m.err +} + +func (m *MockQueryCoordClient) GetPartitionStates(ctx context.Context, in *querypb.GetPartitionStatesRequest, opts ...grpc.CallOption) (*querypb.GetPartitionStatesResponse, error) { + return &querypb.GetPartitionStatesResponse{}, m.err +} + +func (m *MockQueryCoordClient) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) { + return &querypb.GetSegmentInfoResponse{}, m.err +} + +func (m *MockQueryCoordClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + return &milvuspb.GetMetricsResponse{}, m.err +} + +func Test_NewClient(t *testing.T) { + proxy.Params.InitOnce() + + ctx := context.Background() + client, err := NewClient(ctx, proxy.Params.MetaRootPath, proxy.Params.EtcdEndpoints) + assert.Nil(t, err) + assert.NotNil(t, client) + + err = client.Init() + assert.Nil(t, err) + + err = client.Start() + assert.Nil(t, err) + + err = client.Register() + assert.Nil(t, err) + + checkFunc := func(retNotNil bool) { + retCheck := func(notNil bool, ret interface{}, err error) { + if notNil { + assert.NotNil(t, ret) + assert.Nil(t, err) + } else { + assert.Nil(t, ret) + assert.NotNil(t, err) + } + } + + r1, err := client.GetComponentStates(ctx) + retCheck(retNotNil, r1, err) + + r2, err := client.GetTimeTickChannel(ctx) + retCheck(retNotNil, r2, err) + + r3, err := client.GetStatisticsChannel(ctx) + retCheck(retNotNil, r3, err) + + r4, err := client.ShowCollections(ctx, nil) + retCheck(retNotNil, r4, err) + + r5, err := client.ShowPartitions(ctx, nil) + retCheck(retNotNil, r5, err) + + r6, err := client.LoadPartitions(ctx, nil) + retCheck(retNotNil, r6, err) + + r7, err := client.ReleasePartitions(ctx, nil) + retCheck(retNotNil, r7, err) + + r8, err := client.ShowCollections(ctx, nil) + retCheck(retNotNil, r8, err) + + r9, err := client.LoadCollection(ctx, nil) + retCheck(retNotNil, r9, err) + + r10, err := client.ReleaseCollection(ctx, nil) + retCheck(retNotNil, r10, err) + + r11, err := client.CreateQueryChannel(ctx, nil) + retCheck(retNotNil, r11, err) + + r12, err := client.ShowPartitions(ctx, nil) + retCheck(retNotNil, r12, err) + + r13, err := client.GetPartitionStates(ctx, nil) + retCheck(retNotNil, r13, err) + + r14, err := client.GetSegmentInfo(ctx, nil) 
+ retCheck(retNotNil, r14, err) + + r15, err := client.GetMetrics(ctx, nil) + retCheck(retNotNil, r15, err) + } + + client.getGrpcClient = func() (querypb.QueryCoordClient, error) { + return &MockQueryCoordClient{err: nil}, errors.New("dummy") + } + checkFunc(false) + + client.getGrpcClient = func() (querypb.QueryCoordClient, error) { + return &MockQueryCoordClient{err: errors.New("dummy")}, nil + } + checkFunc(false) + + client.getGrpcClient = func() (querypb.QueryCoordClient, error) { + return &MockQueryCoordClient{err: nil}, nil + } + checkFunc(true) + + err = client.Stop() + assert.Nil(t, err) +}
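
Editor's note: the technique exercised by both new tests is a small dependency-injection seam. In client.go, getGrpcClient changes from a concrete method into a function field; NewClient wires it to the real dialing implementation via setGetGrpcClientFunc, and Init no longer forces a connection, so a unit test can overwrite the field and drive every RPC wrapper against an in-memory mock without a live gRPC server. The sketch below is a minimal, self-contained illustration of that pattern under assumed names (NodeClient, mockNode, GetState, and the addresses are illustrative, not Milvus APIs).

// Hypothetical sketch of the test seam introduced by this patch: the gRPC client
// getter is a swappable function field, so unit tests can inject a mock instead
// of dialing a real server. Names below are illustrative, not Milvus APIs.
package main

import (
	"context"
	"errors"
	"fmt"
)

// NodeClient stands in for a generated gRPC client interface
// (e.g. indexpb.IndexNodeClient or querypb.QueryCoordClient).
type NodeClient interface {
	GetState(ctx context.Context) (string, error)
}

// Client wraps the gRPC connection; getGrpcClient is the injection point.
type Client struct {
	addr          string
	getGrpcClient func() (NodeClient, error)
}

// setGetGrpcClientFunc wires the default (real) getter, mirroring the pattern in the patch.
func (c *Client) setGetGrpcClientFunc() {
	c.getGrpcClient = c.getGrpcClientFunc
}

// getGrpcClientFunc is where a real implementation would lazily dial the server.
func (c *Client) getGrpcClientFunc() (NodeClient, error) {
	return nil, errors.New("would dial " + c.addr + " here")
}

// GetState is a typical RPC wrapper: fetch the inner client, then forward the call.
func (c *Client) GetState(ctx context.Context) (string, error) {
	client, err := c.getGrpcClient()
	if err != nil {
		return "", err
	}
	return client.GetState(ctx)
}

// mockNode is what a unit test substitutes; err controls the simulated RPC outcome.
type mockNode struct {
	err error
}

func (m *mockNode) GetState(ctx context.Context) (string, error) {
	return "Healthy", m.err
}

func main() {
	c := &Client{addr: "localhost:21121"}
	c.setGetGrpcClientFunc()

	ctx := context.Background()

	// Case 1: the getter itself fails (no server reachable) -> every wrapper returns an error.
	if _, err := c.GetState(ctx); err != nil {
		fmt.Println("real getter:", err)
	}

	// Case 2: inject a mock whose RPCs fail -> the wrapper's error path is exercised.
	c.getGrpcClient = func() (NodeClient, error) { return &mockNode{err: errors.New("dummy")}, nil }
	if _, err := c.GetState(ctx); err != nil {
		fmt.Println("mock with error:", err)
	}

	// Case 3: inject a healthy mock -> the success path is exercised without any network.
	c.getGrpcClient = func() (NodeClient, error) { return &mockNode{}, nil }
	state, _ := c.GetState(ctx)
	fmt.Println("mock success:", state)
}

The new tests cover the same three cases as this sketch: a getter that fails, a client whose RPCs return errors, and a healthy client, which together exercise both the error and success paths of every wrapper method.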