diff --git a/internal/querynode/benchmark_test.go b/internal/querynode/benchmark_test.go new file mode 100644 index 0000000000..c41d1290a0 --- /dev/null +++ b/internal/querynode/benchmark_test.go @@ -0,0 +1,210 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package querynode + +import ( + "context" + "os" + "runtime/pprof" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/zap/zapcore" + + "github.com/milvus-io/milvus/internal/log" + msgstream2 "github.com/milvus-io/milvus/internal/mq/msgstream" + "github.com/milvus-io/milvus/internal/util/typeutil" +) + +const ( + maxNQ = 100 + nb = 10000 +) + +func benchmarkQueryCollectionSearch(nq int, b *testing.B) { + log.SetLevel(zapcore.ErrorLevel) + defer log.SetLevel(zapcore.DebugLevel) + + tx, cancel := context.WithCancel(context.Background()) + + queryCollection, err := genSimpleQueryCollection(tx, cancel) + assert.NoError(b, err) + + // search only one segment + err = queryCollection.streaming.replica.removeSegment(defaultSegmentID) + assert.NoError(b, err) + err = queryCollection.historical.replica.removeSegment(defaultSegmentID) + assert.NoError(b, err) + + assert.Equal(b, 0, queryCollection.historical.replica.getSegmentNum()) + assert.Equal(b, 0, queryCollection.streaming.replica.getSegmentNum()) + + segment, err := genSealedSegmentWithMsgLength(nb) + assert.NoError(b, err) + err = queryCollection.historical.replica.setSegment(segment) + assert.NoError(b, err) + + sessionManager := NewSessionManager(withSessionCreator(mockProxyCreator())) + sessionManager.AddSession(&NodeInfo{ + NodeID: 0, + Address: "", + }) + queryCollection.sessionManager = sessionManager + + // segment check + assert.Equal(b, 1, queryCollection.historical.replica.getSegmentNum()) + assert.Equal(b, 0, queryCollection.streaming.replica.getSegmentNum()) + seg, err := queryCollection.historical.replica.getSegmentByID(defaultSegmentID) + assert.NoError(b, err) + assert.Equal(b, int64(nb), seg.getRowCount()) + sizePerRecord, err := typeutil.EstimateSizePerRecord(genSimpleSegCoreSchema()) + assert.NoError(b, err) + expectSize := sizePerRecord * nb + assert.Equal(b, seg.getMemSize(), int64(expectSize)) + + // warming up + msgTmp, err := genSearchMsg(10) + assert.NoError(b, err) + for j := 0; j < 10000; j++ { + err = queryCollection.search(msgTmp) + assert.NoError(b, err) + } + + msgs := make([]*msgstream2.SearchMsg, maxNQ/nq) + for i := 0; i < maxNQ/nq; i++ { + msg, err := genSearchMsg(nq) + assert.NoError(b, err) + msgs[i] = msg + } + + f, err := os.Create("nq_" + strconv.Itoa(nq) + ".perf") + if err != nil { + panic(err) + } + if err = pprof.StartCPUProfile(f); err != nil { + panic(err) + } + defer pprof.StopCPUProfile() + + // start benchmark + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < maxNQ/nq; j++ { + err = queryCollection.search(msgs[j]) + assert.NoError(b, err) + } + } +} + +func benchmarkQueryCollectionSearchIndex(nq int, indexType string, b *testing.B) { + log.SetLevel(zapcore.ErrorLevel) + defer log.SetLevel(zapcore.DebugLevel) + + tx, cancel := context.WithCancel(context.Background()) + + queryCollection, err := genSimpleQueryCollection(tx, cancel) + assert.NoError(b, err) + + err = queryCollection.historical.replica.removeSegment(defaultSegmentID) + assert.NoError(b, err) + err = queryCollection.streaming.replica.removeSegment(defaultSegmentID) + assert.NoError(b, err) + + assert.Equal(b, 0, queryCollection.historical.replica.getSegmentNum()) + assert.Equal(b, 0, queryCollection.streaming.replica.getSegmentNum()) + + node, err := genSimpleQueryNode(tx) + assert.NoError(b, err) + node.loader.historicalReplica = queryCollection.historical.replica + + err = loadIndexForSegment(tx, node, defaultSegmentID, nb, indexType, L2) + assert.NoError(b, err) + + sessionManager := NewSessionManager(withSessionCreator(mockProxyCreator())) + sessionManager.AddSession(&NodeInfo{ + NodeID: 0, + Address: "", + }) + queryCollection.sessionManager = sessionManager + + // segment check + assert.Equal(b, 1, queryCollection.historical.replica.getSegmentNum()) + assert.Equal(b, 0, queryCollection.streaming.replica.getSegmentNum()) + seg, err := queryCollection.historical.replica.getSegmentByID(defaultSegmentID) + assert.NoError(b, err) + assert.Equal(b, int64(nb), seg.getRowCount()) + sizePerRecord, err := typeutil.EstimateSizePerRecord(genSimpleSegCoreSchema()) + assert.NoError(b, err) + expectSize := sizePerRecord * nb + assert.Equal(b, seg.getMemSize(), int64(expectSize)) + + // warming up + msgTmp, err := genSearchMsg(10) + assert.NoError(b, err) + for j := 0; j < 10000; j++ { + err = queryCollection.search(msgTmp) + assert.NoError(b, err) + } + + msgs := make([]*msgstream2.SearchMsg, maxNQ/nq) + for i := 0; i < maxNQ/nq; i++ { + msg, err := genSearchMsg(nq) + assert.NoError(b, err) + msgs[i] = msg + } + + f, err := os.Create(indexType + "_nq_" + strconv.Itoa(nq) + ".perf") + if err != nil { + panic(err) + } + if err = pprof.StartCPUProfile(f); err != nil { + panic(err) + } + defer pprof.StopCPUProfile() + + // start benchmark + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < maxNQ/nq; j++ { + err = queryCollection.search(msgs[j]) + assert.NoError(b, err) + } + } +} + +func BenchmarkSearch_NQ1(b *testing.B) { benchmarkQueryCollectionSearch(1, b) } +func BenchmarkSearch_NQ10(b *testing.B) { benchmarkQueryCollectionSearch(10, b) } +func BenchmarkSearch_NQ100(b *testing.B) { benchmarkQueryCollectionSearch(100, b) } +func BenchmarkSearch_NQ1000(b *testing.B) { benchmarkQueryCollectionSearch(1000, b) } +func BenchmarkSearch_NQ10000(b *testing.B) { benchmarkQueryCollectionSearch(10000, b) } + +func BenchmarkSearch_IVFFLAT_NQ1(b *testing.B) { + benchmarkQueryCollectionSearchIndex(1, IndexFaissIVFFlat, b) +} +func BenchmarkSearch_IVFFLAT_NQ10(b *testing.B) { + benchmarkQueryCollectionSearchIndex(10, IndexFaissIVFFlat, b) +} +func BenchmarkSearch_IVFFLAT_NQ100(b *testing.B) { + benchmarkQueryCollectionSearchIndex(100, IndexFaissIVFFlat, b) +} +func BenchmarkSearch_IVFFLAT_NQ1000(b *testing.B) { + benchmarkQueryCollectionSearchIndex(1000, IndexFaissIVFFlat, b) +} +func BenchmarkSearch_IVFFLAT_NQ10000(b *testing.B) { + benchmarkQueryCollectionSearchIndex(10000, IndexFaissIVFFlat, b) +} diff --git a/internal/querynode/mock_test.go b/internal/querynode/mock_test.go index 364a16c713..c7d4cbcfc6 100644 --- a/internal/querynode/mock_test.go +++ b/internal/querynode/mock_test.go @@ -19,6 +19,7 @@ package querynode import ( "context" "errors" + "fmt" "math" "math/rand" "strconv" @@ -43,6 +44,7 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util" "github.com/milvus-io/milvus/internal/util/etcd" + "github.com/milvus-io/milvus/internal/util/funcutil" ) // ---------- unittest util functions ---------- @@ -61,7 +63,8 @@ const ( defaultRoundDecimal = int64(6) defaultDim = 128 defaultNProb = 10 - defaultMetricType = "JACCARD" + defaultMetricType = L2 + defaultNQ = 10 defaultDMLChannel = "query-node-unittest-DML-0" defaultDeltaChannel = "query-node-unittest-delta-channel-0" @@ -89,6 +92,44 @@ const ( indexName = "query-node-index-0" ) +const ( + // index type + IndexFaissIDMap = "FLAT" + IndexFaissIVFFlat = "IVF_FLAT" + IndexFaissIVFPQ = "IVF_PQ" + IndexFaissIVFSQ8 = "IVF_SQ8" + IndexFaissIVFSQ8H = "IVF_SQ8_HYBRID" + IndexFaissBinIDMap = "BIN_FLAT" + IndexFaissBinIVFFlat = "BIN_IVF_FLAT" + IndexNsg = "NSG" + + IndexHNSW = "HNSW" + IndexRHNSWFlat = "RHNSW_FLAT" + IndexRHNSWPQ = "RHNSW_PQ" + IndexRHNSWSQ = "RHNSW_SQ" + IndexANNOY = "ANNOY" + IndexNGTPANNG = "NGT_PANNG" + IndexNGTONNG = "NGT_ONNG" + + // metric type + L2 = "L2" + IP = "IP" + hamming = "HAMMING" + Jaccard = "JACCARD" + tanimoto = "TANIMOTO" + + nlist = 100 + m = 4 + nbits = 8 + nprobe = 8 + sliceSize = 4 + efConstruction = 200 + ef = 200 + edgeSize = 10 + epsilon = 0.1 + maxSearchEdges = 50 +) + // ---------- unittest util functions ---------- // functions of init meta and generate meta type vecFieldParam struct { @@ -222,6 +263,69 @@ func genIndexBinarySet() ([][]byte, error) { return bytesSet, nil } +func loadIndexForSegment(ctx context.Context, node *QueryNode, segmentID UniqueID, msgLength int, indexType string, metricType string) error { + schema := genSimpleInsertDataSchema() + + // generate insert binlog + fieldBinlog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, msgLength, schema) + if err != nil { + return err + } + + // generate index file for segment + indexPaths, err := generateAndSaveIndex(segmentID, msgLength, indexType, metricType) + if err != nil { + return err + } + _, indexParams := genIndexParams(indexType, metricType) + indexInfo := &querypb.VecFieldIndexInfo{ + FieldID: simpleVecField.id, + EnableIndex: true, + IndexName: indexName, + IndexID: indexID, + BuildID: buildID, + IndexParams: funcutil.Map2KeyValuePair(indexParams), + IndexFilePaths: indexPaths, + } + + loader := node.loader + req := &querypb.LoadSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadSegments, + MsgID: rand.Int63(), + }, + DstNodeID: 0, + Schema: schema, + Infos: []*querypb.SegmentLoadInfo{ + { + SegmentID: segmentID, + PartitionID: defaultPartitionID, + CollectionID: defaultCollectionID, + BinlogPaths: fieldBinlog, + IndexInfos: []*querypb.VecFieldIndexInfo{indexInfo}, + }, + }, + } + + err = loader.loadSegment(req, segmentTypeSealed) + if err != nil { + return err + } + + segment, err := node.loader.historicalReplica.getSegmentByID(segmentID) + if err != nil { + return err + } + vecFieldInfo, err := segment.getVectorFieldInfo(simpleVecField.id) + if err != nil { + return err + } + if vecFieldInfo == nil { + return fmt.Errorf("nil vecFieldInfo, load index failed") + } + return nil +} + func generateIndex(segmentID UniqueID) ([]string, error) { indexParams := genSimpleIndexParams() @@ -303,6 +407,177 @@ func generateIndex(segmentID UniqueID) ([]string, error) { return indexPaths, nil } +func generateAndSaveIndex(segmentID UniqueID, msgLength int, indexType, metricType string) ([]string, error) { + typeParams, indexParams := genIndexParams(indexType, metricType) + + var indexParamsKV []*commonpb.KeyValuePair + for key, value := range indexParams { + indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{ + Key: key, + Value: value, + }) + } + + var indexRowData []float32 + for n := 0; n < msgLength; n++ { + for i := 0; i < defaultDim; i++ { + indexRowData = append(indexRowData, rand.Float32()) + } + } + + index, err := indexnode.NewCIndex(typeParams, indexParams) + if err != nil { + return nil, err + } + + err = index.BuildFloatVecIndexWithoutIds(indexRowData) + if err != nil { + return nil, err + } + + option := &minioKV.Option{ + Address: Params.MinioCfg.Address, + AccessKeyID: Params.MinioCfg.AccessKeyID, + SecretAccessKeyID: Params.MinioCfg.SecretAccessKey, + UseSSL: Params.MinioCfg.UseSSL, + BucketName: Params.MinioCfg.BucketName, + CreateBucket: true, + } + + kv, err := minioKV.NewMinIOKV(context.Background(), option) + if err != nil { + return nil, err + } + + // save index to minio + binarySet, err := index.Serialize() + if err != nil { + return nil, err + } + + // serialize index params + indexCodec := storage.NewIndexFileBinlogCodec() + serializedIndexBlobs, err := indexCodec.Serialize( + buildID, + 0, + defaultCollectionID, + defaultPartitionID, + defaultSegmentID, + simpleVecField.id, + indexParams, + indexName, + indexID, + binarySet, + ) + if err != nil { + return nil, err + } + + indexPaths := make([]string, 0) + for _, index := range serializedIndexBlobs { + p := strconv.Itoa(int(segmentID)) + "/" + index.Key + indexPaths = append(indexPaths, p) + err := kv.Save(p, string(index.Value)) + if err != nil { + return nil, err + } + } + + return indexPaths, nil +} + +func genIndexParams(indexType, metricType string) (map[string]string, map[string]string) { + typeParams := make(map[string]string) + indexParams := make(map[string]string) + indexParams["index_type"] = indexType + indexParams["metric_type"] = metricType + if indexType == IndexFaissIDMap { // float vector + indexParams["dim"] = strconv.Itoa(defaultDim) + indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize) + } else if indexType == IndexFaissIVFFlat { + indexParams["dim"] = strconv.Itoa(defaultDim) + indexParams["nlist"] = strconv.Itoa(nlist) + } else if indexType == IndexFaissIVFPQ { + indexParams["dim"] = strconv.Itoa(defaultDim) + indexParams["nlist"] = strconv.Itoa(nlist) + indexParams["m"] = strconv.Itoa(m) + indexParams["nbits"] = strconv.Itoa(nbits) + indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize) + } else if indexType == IndexFaissIVFSQ8 { + indexParams["dim"] = strconv.Itoa(defaultDim) + indexParams["nlist"] = strconv.Itoa(nlist) + indexParams["nbits"] = strconv.Itoa(nbits) + indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize) + } else if indexType == IndexFaissIVFSQ8H { + // TODO: enable gpu + } else if indexType == IndexNsg { + indexParams["dim"] = strconv.Itoa(defaultDim) + indexParams["nlist"] = strconv.Itoa(163) + indexParams["nprobe"] = strconv.Itoa(nprobe) + indexParams["knng"] = strconv.Itoa(20) + indexParams["search_length"] = strconv.Itoa(40) + indexParams["out_degree"] = strconv.Itoa(30) + indexParams["candidate_pool_size"] = strconv.Itoa(100) + } else if indexType == IndexHNSW { + indexParams["dim"] = strconv.Itoa(defaultDim) + indexParams["M"] = strconv.Itoa(16) + indexParams["efConstruction"] = strconv.Itoa(efConstruction) + //indexParams["ef"] = strconv.Itoa(ef) + } else if indexType == IndexRHNSWFlat { + indexParams["dim"] = strconv.Itoa(defaultDim) + indexParams["m"] = strconv.Itoa(16) + indexParams["efConstruction"] = strconv.Itoa(efConstruction) + indexParams["ef"] = strconv.Itoa(ef) + indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize) + } else if indexType == IndexRHNSWPQ { + indexParams["dim"] = strconv.Itoa(defaultDim) + indexParams["m"] = strconv.Itoa(16) + indexParams["efConstruction"] = strconv.Itoa(efConstruction) + indexParams["ef"] = strconv.Itoa(ef) + indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize) + indexParams["PQM"] = strconv.Itoa(8) + } else if indexType == IndexRHNSWSQ { + indexParams["dim"] = strconv.Itoa(defaultDim) + indexParams["m"] = strconv.Itoa(16) + indexParams["efConstruction"] = strconv.Itoa(efConstruction) + indexParams["ef"] = strconv.Itoa(ef) + indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize) + } else if indexType == IndexANNOY { + indexParams["dim"] = strconv.Itoa(defaultDim) + indexParams["n_trees"] = strconv.Itoa(4) + indexParams["search_k"] = strconv.Itoa(100) + indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize) + } else if indexType == IndexNGTPANNG { + indexParams["dim"] = strconv.Itoa(defaultDim) + indexParams["edge_size"] = strconv.Itoa(edgeSize) + indexParams["epsilon"] = fmt.Sprint(epsilon) + indexParams["max_search_edges"] = strconv.Itoa(maxSearchEdges) + indexParams["forcedly_pruned_edge_size"] = strconv.Itoa(60) + indexParams["selectively_pruned_edge_size"] = strconv.Itoa(30) + indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize) + } else if indexType == IndexNGTONNG { + indexParams["dim"] = strconv.Itoa(defaultDim) + indexParams["edge_size"] = strconv.Itoa(edgeSize) + indexParams["epsilon"] = fmt.Sprint(epsilon) + indexParams["max_search_edges"] = strconv.Itoa(maxSearchEdges) + indexParams["outgoing_edge_size"] = strconv.Itoa(5) + indexParams["incoming_edge_size"] = strconv.Itoa(40) + indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize) + } else if indexType == IndexFaissBinIVFFlat { // binary vector + indexParams["dim"] = strconv.Itoa(defaultDim) + indexParams["nlist"] = strconv.Itoa(nlist) + indexParams["m"] = strconv.Itoa(m) + indexParams["nbits"] = strconv.Itoa(nbits) + indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize) + } else if indexType == IndexFaissBinIDMap { + indexParams["dim"] = strconv.Itoa(defaultDim) + } else { + panic("") + } + + return typeParams, indexParams +} + func genSimpleSegCoreSchema() *schemapb.CollectionSchema { fieldVec := genFloatVectorField(simpleVecField) fieldInt := genConstantField(simpleConstField) @@ -866,6 +1141,18 @@ func genSimpleSealedSegment() (*Segment, error) { defaultMsgLength) } +func genSealedSegmentWithMsgLength(msgLength int) (*Segment, error) { + schema := genSimpleSegCoreSchema() + schema2 := genSimpleInsertDataSchema() + return genSealedSegment(schema, + schema2, + defaultCollectionID, + defaultPartitionID, + defaultSegmentID, + defaultDMLChannel, + msgLength) +} + func genSimpleReplica() (ReplicaInterface, error) { kv, err := genEtcdKV() if err != nil { @@ -990,13 +1277,13 @@ func genSimpleDSL() (string, error) { return genDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal) } -func genSimplePlaceHolderGroup() ([]byte, error) { +func genPlaceHolderGroup(nq int) ([]byte, error) { placeholderValue := &milvuspb.PlaceholderValue{ Tag: "$0", Type: milvuspb.PlaceholderType_FloatVector, Values: make([][]byte, 0), } - for i := 0; i < int(defaultTopK); i++ { + for i := 0; i < nq; i++ { var vec = make([]float32, defaultDim) for j := 0; j < defaultDim; j++ { vec[j] = rand.Float32() @@ -1021,6 +1308,10 @@ func genSimplePlaceHolderGroup() ([]byte, error) { return placeGroupByte, nil } +func genSimplePlaceHolderGroup() ([]byte, error) { + return genPlaceHolderGroup(defaultNQ) +} + func genSimpleSearchPlanAndRequests() (*SearchPlan, []*searchRequest, error) { schema := genSimpleSegCoreSchema() collection := newCollection(defaultCollectionID, schema) @@ -1111,8 +1402,8 @@ func genSimpleRetrievePlan() (*RetrievePlan, error) { return plan, err } -func genSimpleSearchRequest() (*internalpb.SearchRequest, error) { - placeHolder, err := genSimplePlaceHolderGroup() +func genSearchRequest(nq int) (*internalpb.SearchRequest, error) { + placeHolder, err := genPlaceHolderGroup(nq) if err != nil { return nil, err } @@ -1130,6 +1421,10 @@ func genSimpleSearchRequest() (*internalpb.SearchRequest, error) { }, nil } +func genSimpleSearchRequest() (*internalpb.SearchRequest, error) { + return genSearchRequest(defaultNQ) +} + func genSimpleRetrieveRequest() (*internalpb.RetrieveRequest, error) { expr, err := genSimpleRetrievePlanExpr() if err != nil { @@ -1149,6 +1444,19 @@ func genSimpleRetrieveRequest() (*internalpb.RetrieveRequest, error) { }, nil } +func genSearchMsg(nq int) (*msgstream.SearchMsg, error) { + req, err := genSearchRequest(nq) + if err != nil { + return nil, err + } + msg := &msgstream.SearchMsg{ + BaseMsg: genMsgStreamBaseMsg(), + SearchRequest: *req, + } + msg.SetTimeRecorder() + return msg, nil +} + func genSimpleSearchMsg() (*msgstream.SearchMsg, error) { req, err := genSimpleSearchRequest() if err != nil {