From d9d2f33a23b2a97c96f3abd62bd0c774345e86ec Mon Sep 17 00:00:00 2001 From: xige-16 Date: Thu, 7 Jan 2021 17:22:10 +0800 Subject: [PATCH] Add binary test for loadIndex and fix loadIndexService can't close correctly Signed-off-by: xige-16 --- internal/querynode/load_index_service.go | 8 + internal/querynode/load_index_service_test.go | 331 +++++++++++++++++- internal/querynode/query_node.go | 3 + internal/querynode/query_node_test.go | 10 +- 4 files changed, 340 insertions(+), 12 deletions(-) diff --git a/internal/querynode/load_index_service.go b/internal/querynode/load_index_service.go index 32b276bcf5..4a45e7fb8a 100644 --- a/internal/querynode/load_index_service.go +++ b/internal/querynode/load_index_service.go @@ -100,6 +100,7 @@ func (lis *loadIndexService) start() { continue } // 1. use msg's index paths to get index bytes + fmt.Println("start load index") var indexBuffer [][]byte var err error fn := func() error { @@ -138,6 +139,13 @@ func (lis *loadIndexService) start() { } } +func (lis *loadIndexService) close() { + if lis.loadIndexMsgStream != nil { + lis.loadIndexMsgStream.Close() + } + lis.cancel() +} + func (lis *loadIndexService) printIndexParams(index []*commonpb.KeyValuePair) { fmt.Println("=================================================") for i := 0; i < len(index); i++ { diff --git a/internal/querynode/load_index_service_test.go b/internal/querynode/load_index_service_test.go index b214b40824..852d976366 100644 --- a/internal/querynode/load_index_service_test.go +++ b/internal/querynode/load_index_service_test.go @@ -22,26 +22,29 @@ import ( "github.com/zilliztech/milvus-distributed/internal/querynode/client" ) -func TestLoadIndexService(t *testing.T) { +func TestLoadIndexService_FloatVector(t *testing.T) { node := newQueryNode() collectionID := rand.Int63n(1000000) segmentID := rand.Int63n(1000000) initTestMeta(t, node, "collection0", collectionID, segmentID) // loadIndexService and statsService + suffix := "-test-search" + strconv.FormatInt(rand.Int63n(1000000), 10) oldSearchChannelNames := Params.SearchChannelNames - var newSearchChannelNames []string - for _, channel := range oldSearchChannelNames { - newSearchChannelNames = append(newSearchChannelNames, channel+"new") - } + newSearchChannelNames := makeNewChannelNames(oldSearchChannelNames, suffix) Params.SearchChannelNames = newSearchChannelNames oldSearchResultChannelNames := Params.SearchChannelNames - var newSearchResultChannelNames []string - for _, channel := range oldSearchResultChannelNames { - newSearchResultChannelNames = append(newSearchResultChannelNames, channel+"new") - } + newSearchResultChannelNames := makeNewChannelNames(oldSearchResultChannelNames, suffix) Params.SearchResultChannelNames = newSearchResultChannelNames + + oldLoadIndexChannelNames := Params.LoadIndexChannelNames + newLoadIndexChannelNames := makeNewChannelNames(oldLoadIndexChannelNames, suffix) + Params.LoadIndexChannelNames = newLoadIndexChannelNames + + oldStatsChannelName := Params.StatsChannelName + newStatsChannelNames := makeNewChannelNames([]string{oldStatsChannelName}, suffix) + Params.StatsChannelName = newStatsChannelNames[0] go node.Start() //generate insert data @@ -328,9 +331,319 @@ func TestLoadIndexService(t *testing.T) { } Params.SearchChannelNames = oldSearchChannelNames Params.SearchResultChannelNames = oldSearchResultChannelNames + Params.LoadIndexChannelNames = oldLoadIndexChannelNames + Params.StatsChannelName = oldStatsChannelName fmt.Println("loadIndex floatVector test Done!") defer assert.Equal(t, findFiledStats, true) <-node.queryNodeLoopCtx.Done() node.Close() } + +func TestLoadIndexService_BinaryVector(t *testing.T) { + node := newQueryNode() + collectionID := rand.Int63n(1000000) + segmentID := rand.Int63n(1000000) + initTestMeta(t, node, "collection0", collectionID, segmentID, true) + + // loadIndexService and statsService + suffix := "-test-search-binary" + strconv.FormatInt(rand.Int63n(1000000), 10) + oldSearchChannelNames := Params.SearchChannelNames + newSearchChannelNames := makeNewChannelNames(oldSearchChannelNames, suffix) + Params.SearchChannelNames = newSearchChannelNames + + oldSearchResultChannelNames := Params.SearchChannelNames + newSearchResultChannelNames := makeNewChannelNames(oldSearchResultChannelNames, suffix) + Params.SearchResultChannelNames = newSearchResultChannelNames + + oldLoadIndexChannelNames := Params.LoadIndexChannelNames + newLoadIndexChannelNames := makeNewChannelNames(oldLoadIndexChannelNames, suffix) + Params.LoadIndexChannelNames = newLoadIndexChannelNames + + oldStatsChannelName := Params.StatsChannelName + newStatsChannelNames := makeNewChannelNames([]string{oldStatsChannelName}, suffix) + Params.StatsChannelName = newStatsChannelNames[0] + go node.Start() + + const msgLength = 1000 + const receiveBufSize = 1024 + const DIM = 128 + + // generator index data + var indexRowData []byte + for n := 0; n < msgLength; n++ { + for i := 0; i < DIM/8; i++ { + indexRowData = append(indexRowData, byte(rand.Intn(8))) + } + } + + //generator insert data + var insertRowBlob []*commonpb.Blob + var timestamps []uint64 + var rowIDs []int64 + var hashValues []uint32 + offset := 0 + for n := 0; n < msgLength; n++ { + rowData := make([]byte, 0) + rowData = append(rowData, indexRowData[offset:offset+(DIM/8)]...) + offset += DIM / 8 + age := make([]byte, 4) + binary.LittleEndian.PutUint32(age, 1) + rowData = append(rowData, age...) + blob := &commonpb.Blob{ + Value: rowData, + } + insertRowBlob = append(insertRowBlob, blob) + timestamps = append(timestamps, uint64(n)) + rowIDs = append(rowIDs, int64(n)) + hashValues = append(hashValues, uint32(n)) + } + + var insertMsg msgstream.TsMsg = &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: hashValues, + }, + InsertRequest: internalpb.InsertRequest{ + MsgType: internalpb.MsgType_kInsert, + ReqID: 0, + CollectionName: "collection0", + PartitionTag: "default", + SegmentID: segmentID, + ChannelID: int64(0), + ProxyID: int64(0), + Timestamps: timestamps, + RowIDs: rowIDs, + RowData: insertRowBlob, + }, + } + insertMsgPack := msgstream.MsgPack{ + BeginTs: 0, + EndTs: math.MaxUint64, + Msgs: []msgstream.TsMsg{insertMsg}, + } + + // generate timeTick + timeTickMsg := &msgstream.TimeTickMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 0, + EndTimestamp: 0, + HashValues: []uint32{0}, + }, + TimeTickMsg: internalpb.TimeTickMsg{ + MsgType: internalpb.MsgType_kTimeTick, + PeerID: UniqueID(0), + Timestamp: math.MaxUint64, + }, + } + timeTickMsgPack := &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{timeTickMsg}, + } + + // pulsar produce + insertChannels := Params.InsertChannelNames + ddChannels := Params.DDChannelNames + + insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + insertStream.SetPulsarClient(Params.PulsarAddress) + insertStream.CreatePulsarProducers(insertChannels) + ddStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + ddStream.SetPulsarClient(Params.PulsarAddress) + ddStream.CreatePulsarProducers(ddChannels) + + var insertMsgStream msgstream.MsgStream = insertStream + insertMsgStream.Start() + var ddMsgStream msgstream.MsgStream = ddStream + ddMsgStream.Start() + + err := insertMsgStream.Produce(&insertMsgPack) + assert.NoError(t, err) + err = insertMsgStream.Broadcast(timeTickMsgPack) + assert.NoError(t, err) + err = ddMsgStream.Broadcast(timeTickMsgPack) + assert.NoError(t, err) + + //generate search data and send search msg + searchRowData := indexRowData[42*(DIM/8) : 43*(DIM/8)] + dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"JACCARD\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" + placeholderValue := servicepb.PlaceholderValue{ + Tag: "$0", + Type: servicepb.PlaceholderType_VECTOR_BINARY, + Values: [][]byte{searchRowData}, + } + placeholderGroup := servicepb.PlaceholderGroup{ + Placeholders: []*servicepb.PlaceholderValue{&placeholderValue}, + } + placeGroupByte, err := proto.Marshal(&placeholderGroup) + if err != nil { + log.Print("marshal placeholderGroup failed") + } + query := servicepb.Query{ + CollectionName: "collection0", + PartitionTags: []string{"default"}, + Dsl: dslString, + PlaceholderGroup: placeGroupByte, + } + queryByte, err := proto.Marshal(&query) + if err != nil { + log.Print("marshal query failed") + } + blob := commonpb.Blob{ + Value: queryByte, + } + fn := func(n int64) *msgstream.MsgPack { + searchMsg := &msgstream.SearchMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: []uint32{0}, + }, + SearchRequest: internalpb.SearchRequest{ + MsgType: internalpb.MsgType_kSearch, + ReqID: n, + ProxyID: int64(1), + Timestamp: uint64(msgLength), + ResultChannelID: int64(0), + Query: &blob, + }, + } + return &msgstream.MsgPack{ + Msgs: []msgstream.TsMsg{searchMsg}, + } + } + searchStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + searchStream.SetPulsarClient(Params.PulsarAddress) + searchStream.CreatePulsarProducers(newSearchChannelNames) + searchStream.Start() + err = searchStream.Produce(fn(1)) + assert.NoError(t, err) + + //get search result + searchResultStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize) + searchResultStream.SetPulsarClient(Params.PulsarAddress) + unmarshalDispatcher := msgstream.NewUnmarshalDispatcher() + searchResultStream.CreatePulsarConsumers(newSearchResultChannelNames, "loadIndexTestSubSearchResult2", unmarshalDispatcher, receiveBufSize) + searchResultStream.Start() + searchResult := searchResultStream.Consume() + assert.NotNil(t, searchResult) + unMarshaledHit := servicepb.Hits{} + err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) + assert.Nil(t, err) + + // gen load index message pack + indexParams := make(map[string]string) + indexParams["index_type"] = "BIN_IVF_FLAT" + indexParams["index_mode"] = "cpu" + indexParams["dim"] = "128" + indexParams["k"] = "10" + indexParams["nlist"] = "100" + indexParams["nprobe"] = "10" + indexParams["m"] = "4" + indexParams["nbits"] = "8" + indexParams["metric_type"] = "JACCARD" + indexParams["SLICE_SIZE"] = "4" + + var indexParamsKV []*commonpb.KeyValuePair + for key, value := range indexParams { + indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{ + Key: key, + Value: value, + }) + } + + // generator index + typeParams := make(map[string]string) + typeParams["dim"] = "128" + index, err := indexbuilder.NewCIndex(typeParams, indexParams) + assert.Nil(t, err) + err = index.BuildBinaryVecIndexWithoutIds(indexRowData) + assert.Equal(t, err, nil) + + option := &minioKV.Option{ + Address: Params.MinioEndPoint, + AccessKeyID: Params.MinioAccessKeyID, + SecretAccessKeyID: Params.MinioSecretAccessKey, + UseSSL: Params.MinioUseSSLStr, + BucketName: Params.MinioBucketName, + CreateBucket: true, + } + + minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, option) + assert.Equal(t, err, nil) + //save index to minio + binarySet, err := index.Serialize() + assert.Equal(t, err, nil) + indexPaths := make([]string, 0) + for _, index := range binarySet { + path := strconv.Itoa(int(segmentID)) + "/" + index.Key + indexPaths = append(indexPaths, path) + minioKV.Save(path, string(index.Value)) + } + + //test index search result + indexResult, err := index.QueryOnBinaryVecIndexWithParam(searchRowData, indexParams) + assert.Equal(t, err, nil) + + // create loadIndexClient + fieldID := UniqueID(100) + loadIndexChannelNames := Params.LoadIndexChannelNames + client := client.NewLoadIndexClient(node.queryNodeLoopCtx, Params.PulsarAddress, loadIndexChannelNames) + client.LoadIndex(indexPaths, segmentID, fieldID, "vec", indexParams) + + // init message stream consumer and do checks + statsMs := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, Params.StatsReceiveBufSize) + statsMs.SetPulsarClient(Params.PulsarAddress) + statsMs.CreatePulsarConsumers([]string{Params.StatsChannelName}, Params.MsgChannelSubName, msgstream.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize) + statsMs.Start() + + findFiledStats := false + for { + receiveMsg := msgstream.MsgStream(statsMs).Consume() + assert.NotNil(t, receiveMsg) + assert.NotEqual(t, len(receiveMsg.Msgs), 0) + + for _, msg := range receiveMsg.Msgs { + statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg) + if statsMsg.FieldStats == nil || len(statsMsg.FieldStats) == 0 { + continue + } + findFiledStats = true + assert.Equal(t, ok, true) + assert.Equal(t, len(statsMsg.FieldStats), 1) + fieldStats0 := statsMsg.FieldStats[0] + assert.Equal(t, fieldStats0.FieldID, fieldID) + assert.Equal(t, fieldStats0.CollectionID, collectionID) + assert.Equal(t, len(fieldStats0.IndexStats), 1) + indexStats0 := fieldStats0.IndexStats[0] + params := indexStats0.IndexParams + // sort index params by key + sort.Slice(indexParamsKV, func(i, j int) bool { return indexParamsKV[i].Key < indexParamsKV[j].Key }) + indexEqual := node.loadIndexService.indexParamsEqual(params, indexParamsKV) + assert.Equal(t, indexEqual, true) + } + + if findFiledStats { + break + } + } + + err = searchStream.Produce(fn(2)) + assert.NoError(t, err) + searchResult = searchResultStream.Consume() + assert.NotNil(t, searchResult) + err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit) + assert.Nil(t, err) + + idsIndex := indexResult.IDs() + idsSegment := unMarshaledHit.IDs + assert.Equal(t, len(idsIndex), len(idsSegment)) + for i := 0; i < len(idsIndex); i++ { + assert.Equal(t, idsIndex[i], idsSegment[i]) + } + Params.SearchChannelNames = oldSearchChannelNames + Params.SearchResultChannelNames = oldSearchResultChannelNames + Params.LoadIndexChannelNames = oldLoadIndexChannelNames + Params.StatsChannelName = oldStatsChannelName + fmt.Println("loadIndex binaryVector test Done!") + + defer assert.Equal(t, findFiledStats, true) + <-node.queryNodeLoopCtx.Done() + node.Close() +} diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index 819d2b8554..41f9391e8b 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -97,6 +97,9 @@ func (node *QueryNode) Close() { if node.searchService != nil { node.searchService.close() } + if node.loadIndexService != nil { + node.loadIndexService.close() + } if node.statsService != nil { node.statsService.close() } diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index 34ec092f52..1217fa3da3 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -35,7 +35,7 @@ func genTestCollectionMeta(collectionName string, collectionID UniqueID, isBinar TypeParams: []*commonpb.KeyValuePair{ { Key: "dim", - Value: "16", + Value: "128", }, }, IndexParams: []*commonpb.KeyValuePair{ @@ -92,8 +92,12 @@ func genTestCollectionMeta(collectionName string, collectionID UniqueID, isBinar return &collectionMeta } -func initTestMeta(t *testing.T, node *QueryNode, collectionName string, collectionID UniqueID, segmentID UniqueID) { - collectionMeta := genTestCollectionMeta(collectionName, collectionID, false) +func initTestMeta(t *testing.T, node *QueryNode, collectionName string, collectionID UniqueID, segmentID UniqueID, optional ...bool) { + isBinary := false + if len(optional) > 0 { + isBinary = optional[0] + } + collectionMeta := genTestCollectionMeta(collectionName, collectionID, isBinary) schemaBlob := proto.MarshalTextString(collectionMeta.Schema) assert.NotEqual(t, "", schemaBlob)