package proxy

import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/golang/protobuf/proto"
	"github.com/stretchr/testify/assert"

	"github.com/milvus-io/milvus/internal/common"
	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/mq/msgstream"
	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/milvuspb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/util/distance"
	"github.com/milvus-io/milvus/internal/util/funcutil"
	"github.com/milvus-io/milvus/internal/util/timerecord"
	"github.com/milvus-io/milvus/internal/util/typeutil"
	"github.com/milvus-io/milvus/internal/util/uniquegenerator"
)

func TestSearchTask(t *testing.T) {
	ctx := context.Background()
	ctxCancel, cancel := context.WithCancel(ctx)
	qt := &searchTask{
		ctx:       ctxCancel,
		Condition: NewTaskCondition(context.TODO()),
		SearchRequest: &internalpb.SearchRequest{
			Base: &commonpb.MsgBase{
				MsgType:  commonpb.MsgType_Search,
				SourceID: Params.ProxyCfg.ProxyID,
			},
			ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
		},
		resultBuf: make(chan []*internalpb.SearchResults),
		query:     nil,
		chMgr:     nil,
		qc:        nil,
		tr:        timerecord.NewTimeRecorder("search"),
	}

	// no result
	go func() {
		qt.resultBuf <- []*internalpb.SearchResults{}
	}()
	err := qt.PostExecute(context.TODO())
	assert.NotNil(t, err)

	// test trace context done
	cancel()
	err = qt.PostExecute(context.TODO())
	assert.NotNil(t, err)

	// error result
	ctx = context.Background()
	qt = &searchTask{
		ctx:       ctx,
		Condition: NewTaskCondition(context.TODO()),
		SearchRequest: &internalpb.SearchRequest{
			Base: &commonpb.MsgBase{
				MsgType:  commonpb.MsgType_Search,
				SourceID: Params.ProxyCfg.ProxyID,
			},
			ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
		},
		resultBuf: make(chan []*internalpb.SearchResults),
		query:     nil,
		chMgr:     nil,
		qc:        nil,
		tr:        timerecord.NewTimeRecorder("search"),
	}

	// send a result carrying an unexpected-error status
	go func() {
		result := internalpb.SearchResults{
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UnexpectedError,
				Reason:    "test",
			},
		}
		results := make([]*internalpb.SearchResults, 1)
		results[0] = &result
		qt.resultBuf <- results
	}()
	err = qt.PostExecute(context.TODO())
	assert.NotNil(t, err)
	log.Debug("PostExecute failed: " + err.Error())

	// check result SlicedBlob
	ctx = context.Background()
	qt = &searchTask{
		ctx:       ctx,
		Condition: NewTaskCondition(context.TODO()),
		SearchRequest: &internalpb.SearchRequest{
			Base: &commonpb.MsgBase{
				MsgType:  commonpb.MsgType_Search,
				SourceID: Params.ProxyCfg.ProxyID,
			},
			ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
		},
		resultBuf: make(chan []*internalpb.SearchResults),
		query:     nil,
		chMgr:     nil,
		qc:        nil,
		tr:        timerecord.NewTimeRecorder("search"),
	}

	// send a successful result whose SlicedBlob is nil
	go func() {
		result := internalpb.SearchResults{
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_Success,
				Reason:    "test",
			},
			SlicedBlob: nil,
		}
		results := make([]*internalpb.SearchResults, 1)
		results[0] = &result
		qt.resultBuf <- results
	}()
	err = qt.PostExecute(context.TODO())
	assert.Nil(t, err)
	assert.Equal(t, qt.result.Status.ErrorCode, commonpb.ErrorCode_Success)

	// TODO: add decode result, reduce result test
}

func TestSearchTask_Channels(t *testing.T) {
	var err error

	Params.Init()

	rc := NewRootCoordMock()
	rc.Start()
	defer rc.Stop()

	ctx := context.Background()

	err = InitMetaCache(rc)
	assert.NoError(t, err)

	dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
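	// channels manager: DML channels are resolved through the mocked root coordinator,
	// DQL channels through the mocked get-channels service, and streams are created by
	// a simple mock msgstream factory.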
	query := newMockGetChannelsService()
	factory := newSimpleMockMsgStreamFactory()
	chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
	defer chMgr.removeAllDMLStream()
	defer chMgr.removeAllDQLStream()

	prefix := "TestSearchTask_Channels"
	collectionName := prefix + funcutil.GenRandomStr()
	shardsNum := int32(2)
	dbName := ""
	int64Field := "int64"
	floatVecField := "fvec"
	dim := 128

	task := &searchTask{
		ctx: ctx,
		query: &milvuspb.SearchRequest{
			CollectionName: collectionName,
		},
		chMgr: chMgr,
		tr:    timerecord.NewTimeRecorder("search"),
	}

	// collection not exist
	_, err = task.getChannels()
	assert.Error(t, err)
	_, err = task.getVChannels()
	assert.Error(t, err)

	schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
	marshaledSchema, err := proto.Marshal(schema)
	assert.NoError(t, err)

	createColT := &createCollectionTask{
		Condition: NewTaskCondition(ctx),
		CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
			Base:           nil,
			DbName:         dbName,
			CollectionName: collectionName,
			Schema:         marshaledSchema,
			ShardsNum:      shardsNum,
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
		schema:    nil,
	}

	assert.NoError(t, createColT.OnEnqueue())
	assert.NoError(t, createColT.PreExecute(ctx))
	assert.NoError(t, createColT.Execute(ctx))
	assert.NoError(t, createColT.PostExecute(ctx))

	_, err = task.getChannels()
	assert.NoError(t, err)
	_, err = task.getVChannels()
	assert.NoError(t, err)

	_ = chMgr.removeAllDMLStream()
	chMgr.dmlChannelsMgr.getChannelsFunc = func(collectionID UniqueID) (map[vChan]pChan, error) {
		return nil, errors.New("mock")
	}
	_, err = task.getChannels()
	assert.Error(t, err)
	_, err = task.getVChannels()
	assert.Error(t, err)
}

func TestSearchTask_PreExecute(t *testing.T) {
	var err error

	Params.Init()
	Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()}

	rc := NewRootCoordMock()
	rc.Start()
	defer rc.Stop()

	qc := NewQueryCoordMock()
	qc.Start()
	defer qc.Stop()

	ctx := context.Background()

	err = InitMetaCache(rc)
	assert.NoError(t, err)

	dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
	query := newMockGetChannelsService()
	factory := newSimpleMockMsgStreamFactory()
	chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
	defer chMgr.removeAllDMLStream()
	defer chMgr.removeAllDQLStream()

	prefix := "TestSearchTask_PreExecute"
	collectionName := prefix + funcutil.GenRandomStr()
	shardsNum := int32(2)
	dbName := ""
	int64Field := "int64"
	floatVecField := "fvec"
	dim := 128

	task := &searchTask{
		ctx:           ctx,
		SearchRequest: &internalpb.SearchRequest{},
		query: &milvuspb.SearchRequest{
			CollectionName: collectionName,
		},
		chMgr: chMgr,
		qc:    qc,
		tr:    timerecord.NewTimeRecorder("search"),
	}
	assert.NoError(t, task.OnEnqueue())

	// collection not exist
	assert.Error(t, task.PreExecute(ctx))

	schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
	marshaledSchema, err := proto.Marshal(schema)
	assert.NoError(t, err)

	createColT := &createCollectionTask{
		Condition: NewTaskCondition(ctx),
		CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
			Base:           nil,
			DbName:         dbName,
			CollectionName: collectionName,
			Schema:         marshaledSchema,
			ShardsNum:      shardsNum,
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
		schema:    nil,
	}

	assert.NoError(t, createColT.OnEnqueue())
	assert.NoError(t, createColT.PreExecute(ctx))
	assert.NoError(t, createColT.Execute(ctx))
	assert.NoError(t, createColT.PostExecute(ctx))

	collectionID, _ := globalMetaCache.GetCollectionID(ctx, collectionName)

	// validateCollectionName
	task.query.CollectionName = "$"
	assert.Error(t,
task.PreExecute(ctx)) task.query.CollectionName = collectionName // Validate Partition task.query.PartitionNames = []string{"$"} assert.Error(t, task.PreExecute(ctx)) task.query.PartitionNames = nil // mock show collections of QueryCoord qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { return nil, errors.New("mock") }) assert.Error(t, task.PreExecute(ctx)) qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { return &querypb.ShowCollectionsResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "mock", }, }, nil }) assert.Error(t, task.PreExecute(ctx)) qc.ResetShowCollectionsFunc() // collection not loaded assert.Error(t, task.PreExecute(ctx)) _, _ = qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_LoadCollection, MsgID: 0, Timestamp: 0, SourceID: 0, }, DbID: 0, CollectionID: collectionID, Schema: nil, }) // no anns field task.query.DslType = commonpb.DslType_BoolExprV1 assert.Error(t, task.PreExecute(ctx)) task.query.SearchParams = []*commonpb.KeyValuePair{ { Key: AnnsFieldKey, Value: floatVecField, }, } // no topk assert.Error(t, task.PreExecute(ctx)) task.query.SearchParams = []*commonpb.KeyValuePair{ { Key: AnnsFieldKey, Value: floatVecField, }, { Key: TopKKey, Value: "invalid", }, } // invalid topk assert.Error(t, task.PreExecute(ctx)) task.query.SearchParams = []*commonpb.KeyValuePair{ { Key: AnnsFieldKey, Value: floatVecField, }, { Key: TopKKey, Value: "10", }, } // no metric type assert.Error(t, task.PreExecute(ctx)) task.query.SearchParams = []*commonpb.KeyValuePair{ { Key: AnnsFieldKey, Value: floatVecField, }, { Key: TopKKey, Value: "10", }, { Key: MetricTypeKey, Value: distance.L2, }, } // no search params assert.Error(t, task.PreExecute(ctx)) task.query.SearchParams = []*commonpb.KeyValuePair{ { Key: AnnsFieldKey, Value: int64Field, }, { Key: TopKKey, Value: "10", }, { Key: MetricTypeKey, Value: distance.L2, }, { Key: SearchParamsKey, Value: `{"nprobe": 10}`, }, } // invalid round_decimal assert.Error(t, task.PreExecute(ctx)) task.query.SearchParams = []*commonpb.KeyValuePair{ { Key: AnnsFieldKey, Value: int64Field, }, { Key: TopKKey, Value: "10", }, { Key: MetricTypeKey, Value: distance.L2, }, { Key: SearchParamsKey, Value: `{"nprobe": 10}`, }, { Key: RoundDecimalKey, Value: "invalid", }, } // invalid round_decimal assert.Error(t, task.PreExecute(ctx)) task.query.SearchParams = []*commonpb.KeyValuePair{ { Key: AnnsFieldKey, Value: floatVecField, }, { Key: TopKKey, Value: "10", }, { Key: MetricTypeKey, Value: distance.L2, }, { Key: RoundDecimalKey, Value: "-1", }, } // failed to create query plan assert.Error(t, task.PreExecute(ctx)) task.query.SearchParams = []*commonpb.KeyValuePair{ { Key: AnnsFieldKey, Value: floatVecField, }, { Key: TopKKey, Value: "10", }, { Key: MetricTypeKey, Value: distance.L2, }, { Key: SearchParamsKey, Value: `{"nprobe": 10}`, }, { Key: RoundDecimalKey, Value: "-1", }, } // search task with timeout ctx1, cancel := context.WithTimeout(ctx, time.Second) defer cancel() // before preExecute assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp) task.ctx = ctx1 assert.NoError(t, task.PreExecute(ctx)) // after preExecute assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp) // field not exist task.query.OutputFields = []string{int64Field + funcutil.GenRandomStr()} assert.Error(t, 
		task.PreExecute(ctx))

	// output fields contain a vector field
	task.query.OutputFields = []string{floatVecField}
	assert.Error(t, task.PreExecute(ctx))
	task.query.OutputFields = []string{int64Field}

	// mock showPartitions failure from RootCoord
	rc.showPartitionsFunc = func(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) {
		return nil, errors.New("mock")
	}
	assert.Error(t, task.PreExecute(ctx))
	rc.showPartitionsFunc = nil
	// TODO(dragondriver): test partition-related error
}

func TestSearchTask_Ts(t *testing.T) {
	Params.Init()

	task := &searchTask{
		SearchRequest: &internalpb.SearchRequest{
			Base: nil,
		},
		tr: timerecord.NewTimeRecorder("search"),
	}
	assert.NoError(t, task.OnEnqueue())

	ts := Timestamp(time.Now().Nanosecond())
	task.SetTs(ts)
	assert.Equal(t, ts, task.BeginTs())
	assert.Equal(t, ts, task.EndTs())
}

func TestSearchTask_Execute(t *testing.T) {
	var err error

	Params.Init()
	Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()}

	rc := NewRootCoordMock()
	rc.Start()
	defer rc.Stop()

	qc := NewQueryCoordMock()
	qc.Start()
	defer qc.Stop()

	ctx := context.Background()

	err = InitMetaCache(rc)
	assert.NoError(t, err)

	dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
	query := newMockGetChannelsService()
	factory := newSimpleMockMsgStreamFactory()
	chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
	defer chMgr.removeAllDMLStream()
	defer chMgr.removeAllDQLStream()

	prefix := "TestSearchTask_Execute"
	collectionName := prefix + funcutil.GenRandomStr()
	shardsNum := int32(2)
	dbName := ""
	int64Field := "int64"
	floatVecField := "fvec"
	dim := 128

	task := &searchTask{
		ctx: ctx,
		SearchRequest: &internalpb.SearchRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_Search,
				MsgID:     0,
				Timestamp: uint64(time.Now().UnixNano()),
				SourceID:  0,
			},
		},
		query: &milvuspb.SearchRequest{
			CollectionName: collectionName,
		},
		result: &milvuspb.SearchResults{
			Status:  &commonpb.Status{},
			Results: nil,
		},
		chMgr: chMgr,
		qc:    qc,
		tr:    timerecord.NewTimeRecorder("search"),
	}
	assert.NoError(t, task.OnEnqueue())

	// collection not exist
	assert.Error(t, task.PreExecute(ctx))

	schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
	marshaledSchema, err := proto.Marshal(schema)
	assert.NoError(t, err)

	createColT := &createCollectionTask{
		Condition: NewTaskCondition(ctx),
		CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
			Base:           nil,
			DbName:         dbName,
			CollectionName: collectionName,
			Schema:         marshaledSchema,
			ShardsNum:      shardsNum,
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
		schema:    nil,
	}

	assert.NoError(t, createColT.OnEnqueue())
	assert.NoError(t, createColT.PreExecute(ctx))
	assert.NoError(t, createColT.Execute(ctx))
	assert.NoError(t, createColT.PostExecute(ctx))

	assert.NoError(t, task.Execute(ctx))

	_ = chMgr.removeAllDQLStream()
	query.f = func(collectionID UniqueID) (map[vChan]pChan, error) {
		return nil, errors.New("mock")
	}
	assert.Error(t, task.Execute(ctx))

	// TODO(dragondriver): cover getDQLStream
}

// genSearchResultData builds a SearchResultData for nq queries with the given
// ids and scores; Topks is left zero-filled.
func genSearchResultData(nq int64, topk int64, ids []int64, scores []float32) *schemapb.SearchResultData {
	return &schemapb.SearchResultData{
		NumQueries: nq,
		TopK:       topk,
		FieldsData: nil,
		Scores:     scores,
		Ids: &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{
				IntId: &schemapb.LongArray{
					Data: ids,
				},
			},
		},
		Topks: make([]int64, nq),
	}
}

func TestSearchTask_Reduce(t *testing.T) {
	const (
		nq         = 1
		topk       = 4
		metricType = "L2"
	)
	t.Run("case1", func(t *testing.T) {
		ids := []int64{1, 2, 3, 4}
		scores := []float32{-1.0, -2.0, -3.0, -4.0}
		data1 := genSearchResultData(nq, topk, ids,
scores) data2 := genSearchResultData(nq, topk, ids, scores) dataArray := make([]*schemapb.SearchResultData, 0) dataArray = append(dataArray, data1) dataArray = append(dataArray, data2) res, err := reduceSearchResultData(dataArray, nq, topk, metricType) assert.Nil(t, err) assert.Equal(t, ids, res.Results.Ids.GetIntId().Data) assert.Equal(t, []float32{1.0, 2.0, 3.0, 4.0}, res.Results.Scores) }) t.Run("case2", func(t *testing.T) { ids1 := []int64{1, 2, 3, 4} scores1 := []float32{-1.0, -2.0, -3.0, -4.0} ids2 := []int64{5, 1, 3, 4} scores2 := []float32{-1.0, -1.0, -3.0, -4.0} data1 := genSearchResultData(nq, topk, ids1, scores1) data2 := genSearchResultData(nq, topk, ids2, scores2) dataArray := make([]*schemapb.SearchResultData, 0) dataArray = append(dataArray, data1) dataArray = append(dataArray, data2) res, err := reduceSearchResultData(dataArray, nq, topk, metricType) assert.Nil(t, err) assert.ElementsMatch(t, []int64{1, 5, 2, 3}, res.Results.Ids.GetIntId().Data) }) } func TestSearchTaskWithInvalidRoundDecimal(t *testing.T) { var err error Params.Init() Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()} rc := NewRootCoordMock() rc.Start() defer rc.Stop() ctx := context.Background() err = InitMetaCache(rc) assert.NoError(t, err) shardsNum := int32(2) prefix := "TestSearchTask_all" dbName := "" collectionName := prefix + funcutil.GenRandomStr() dim := 128 expr := fmt.Sprintf("%s > 0", testInt64Field) nq := 10 topk := 10 roundDecimal := 7 nprobe := 10 fieldName2Types := map[string]schemapb.DataType{ testBoolField: schemapb.DataType_Bool, testInt32Field: schemapb.DataType_Int32, testInt64Field: schemapb.DataType_Int64, testFloatField: schemapb.DataType_Float, testDoubleField: schemapb.DataType_Double, testFloatVecField: schemapb.DataType_FloatVector, } if enableMultipleVectorFields { fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector } schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false) marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) createColT := &createCollectionTask{ Condition: NewTaskCondition(ctx), CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ Base: nil, DbName: dbName, CollectionName: collectionName, Schema: marshaledSchema, ShardsNum: shardsNum, }, ctx: ctx, rootCoord: rc, result: nil, schema: nil, } assert.NoError(t, createColT.OnEnqueue()) assert.NoError(t, createColT.PreExecute(ctx)) assert.NoError(t, createColT.Execute(ctx)) assert.NoError(t, createColT.PostExecute(ctx)) dmlChannelsFunc := getDmlChannelsFunc(ctx, rc) query := newMockGetChannelsService() factory := newSimpleMockMsgStreamFactory() chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory) defer chMgr.removeAllDMLStream() defer chMgr.removeAllDQLStream() collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) assert.NoError(t, err) qc := NewQueryCoordMock() qc.Start() defer qc.Stop() status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_LoadCollection, MsgID: 0, Timestamp: 0, SourceID: Params.ProxyCfg.ProxyID, }, DbID: 0, CollectionID: collectionID, Schema: nil, }) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode) req := constructSearchRequest(dbName, collectionName, expr, testFloatVecField, nq, dim, nprobe, topk, roundDecimal) task := &searchTask{ Condition: NewTaskCondition(ctx), SearchRequest: &internalpb.SearchRequest{ Base: &commonpb.MsgBase{ 
MsgType: commonpb.MsgType_Search, MsgID: 0, Timestamp: 0, SourceID: Params.ProxyCfg.ProxyID, }, ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10), DbID: 0, CollectionID: 0, PartitionIDs: nil, Dsl: "", PlaceholderGroup: nil, DslType: 0, SerializedExprPlan: nil, OutputFieldsId: nil, TravelTimestamp: 0, GuaranteeTimestamp: 0, }, ctx: ctx, resultBuf: make(chan []*internalpb.SearchResults), result: nil, query: req, chMgr: chMgr, qc: qc, tr: timerecord.NewTimeRecorder("search"), } // simple mock for query node // TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream? err = chMgr.createDQLStream(collectionID) assert.NoError(t, err) stream, err := chMgr.getDQLStream(collectionID) assert.NoError(t, err) var wg sync.WaitGroup wg.Add(1) consumeCtx, cancel := context.WithCancel(ctx) go func() { defer wg.Done() for { select { case <-consumeCtx.Done(): return case pack, ok := <-stream.Chan(): assert.True(t, ok) if pack == nil { continue } for _, msg := range pack.Msgs { _, ok := msg.(*msgstream.SearchMsg) assert.True(t, ok) // TODO(dragondriver): construct result according to the request constructSearchResulstData := func() *schemapb.SearchResultData { resultData := &schemapb.SearchResultData{ NumQueries: int64(nq), TopK: int64(topk), Scores: make([]float32, nq*topk), Ids: &schemapb.IDs{ IdField: &schemapb.IDs_IntId{ IntId: &schemapb.LongArray{ Data: make([]int64, nq*topk), }, }, }, Topks: make([]int64, nq), } fieldID := common.StartOfUserFieldID for fieldName, dataType := range fieldName2Types { resultData.FieldsData = append(resultData.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), nq*topk)) fieldID++ } for i := 0; i < nq; i++ { for j := 0; j < topk; j++ { offset := i*topk + j score := float32(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) // increasingly id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) resultData.Scores[offset] = score resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id } resultData.Topks[i] = int64(topk) } return resultData } result1 := &internalpb.SearchResults{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_SearchResult, MsgID: 0, Timestamp: 0, SourceID: 0, }, Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, Reason: "", }, ResultChannelID: "", MetricType: distance.L2, NumQueries: int64(nq), TopK: int64(topk), SealedSegmentIDsSearched: nil, ChannelIDsSearched: nil, GlobalSealedSegmentIDs: nil, SlicedBlob: nil, SlicedNumCount: 1, SlicedOffset: 0, } resultData := constructSearchResulstData() sliceBlob, err := proto.Marshal(resultData) assert.NoError(t, err) result1.SlicedBlob = sliceBlob // result2.SliceBlob = nil, will be skipped in decode stage result2 := &internalpb.SearchResults{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_SearchResult, MsgID: 0, Timestamp: 0, SourceID: 0, }, Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, Reason: "", }, ResultChannelID: "", MetricType: distance.L2, NumQueries: int64(nq), TopK: int64(topk), SealedSegmentIDsSearched: nil, ChannelIDsSearched: nil, GlobalSealedSegmentIDs: nil, SlicedBlob: nil, SlicedNumCount: 1, SlicedOffset: 0, } // send search result task.resultBuf <- []*internalpb.SearchResults{result1, result2} } } } }() assert.NoError(t, task.OnEnqueue()) assert.Error(t, task.PreExecute(ctx)) cancel() wg.Wait() } func TestSearchTask_all(t *testing.T) { var err error Params.Init() Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()} rc := NewRootCoordMock() rc.Start() defer rc.Stop() 
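	// TestSearchTask_all drives the whole task lifecycle (OnEnqueue, PreExecute,
	// Execute, PostExecute) against mocked coordinators and a mocked query node.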
ctx := context.Background() err = InitMetaCache(rc) assert.NoError(t, err) shardsNum := int32(2) prefix := "TestSearchTask_all" dbName := "" collectionName := prefix + funcutil.GenRandomStr() dim := 128 expr := fmt.Sprintf("%s > 0", testInt64Field) nq := 10 topk := 10 roundDecimal := 3 nprobe := 10 fieldName2Types := map[string]schemapb.DataType{ testBoolField: schemapb.DataType_Bool, testInt32Field: schemapb.DataType_Int32, testInt64Field: schemapb.DataType_Int64, testFloatField: schemapb.DataType_Float, testDoubleField: schemapb.DataType_Double, testFloatVecField: schemapb.DataType_FloatVector, } if enableMultipleVectorFields { fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector } schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false) marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) createColT := &createCollectionTask{ Condition: NewTaskCondition(ctx), CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ Base: nil, DbName: dbName, CollectionName: collectionName, Schema: marshaledSchema, ShardsNum: shardsNum, }, ctx: ctx, rootCoord: rc, result: nil, schema: nil, } assert.NoError(t, createColT.OnEnqueue()) assert.NoError(t, createColT.PreExecute(ctx)) assert.NoError(t, createColT.Execute(ctx)) assert.NoError(t, createColT.PostExecute(ctx)) dmlChannelsFunc := getDmlChannelsFunc(ctx, rc) query := newMockGetChannelsService() factory := newSimpleMockMsgStreamFactory() chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory) defer chMgr.removeAllDMLStream() defer chMgr.removeAllDQLStream() collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) assert.NoError(t, err) qc := NewQueryCoordMock() qc.Start() defer qc.Stop() status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_LoadCollection, MsgID: 0, Timestamp: 0, SourceID: Params.ProxyCfg.ProxyID, }, DbID: 0, CollectionID: collectionID, Schema: nil, }) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode) req := constructSearchRequest(dbName, collectionName, expr, testFloatVecField, nq, dim, nprobe, topk, roundDecimal) task := &searchTask{ Condition: NewTaskCondition(ctx), SearchRequest: &internalpb.SearchRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_Search, MsgID: 0, Timestamp: 0, SourceID: Params.ProxyCfg.ProxyID, }, ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10), DbID: 0, CollectionID: 0, PartitionIDs: nil, Dsl: "", PlaceholderGroup: nil, DslType: 0, SerializedExprPlan: nil, OutputFieldsId: nil, TravelTimestamp: 0, GuaranteeTimestamp: 0, }, ctx: ctx, resultBuf: make(chan []*internalpb.SearchResults), result: nil, query: req, chMgr: chMgr, qc: qc, tr: timerecord.NewTimeRecorder("search"), } // simple mock for query node // TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream? 
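	// The goroutine below plays the query node: it consumes SearchMsg from the DQL
	// stream and answers with two SearchResults, one carrying a marshaled
	// SearchResultData blob and one with a nil SlicedBlob that the decode stage skips.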
err = chMgr.createDQLStream(collectionID) assert.NoError(t, err) stream, err := chMgr.getDQLStream(collectionID) assert.NoError(t, err) var wg sync.WaitGroup wg.Add(1) consumeCtx, cancel := context.WithCancel(ctx) go func() { defer wg.Done() for { select { case <-consumeCtx.Done(): return case pack, ok := <-stream.Chan(): assert.True(t, ok) if pack == nil { continue } for _, msg := range pack.Msgs { _, ok := msg.(*msgstream.SearchMsg) assert.True(t, ok) // TODO(dragondriver): construct result according to the request constructSearchResulstData := func() *schemapb.SearchResultData { resultData := &schemapb.SearchResultData{ NumQueries: int64(nq), TopK: int64(topk), Scores: make([]float32, nq*topk), Ids: &schemapb.IDs{ IdField: &schemapb.IDs_IntId{ IntId: &schemapb.LongArray{ Data: make([]int64, nq*topk), }, }, }, Topks: make([]int64, nq), } fieldID := common.StartOfUserFieldID for fieldName, dataType := range fieldName2Types { resultData.FieldsData = append(resultData.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), nq*topk)) fieldID++ } for i := 0; i < nq; i++ { for j := 0; j < topk; j++ { offset := i*topk + j score := float32(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) // increasingly id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) resultData.Scores[offset] = score resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id } resultData.Topks[i] = int64(topk) } return resultData } result1 := &internalpb.SearchResults{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_SearchResult, MsgID: 0, Timestamp: 0, SourceID: 0, }, Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, Reason: "", }, ResultChannelID: "", MetricType: distance.L2, NumQueries: int64(nq), TopK: int64(topk), SealedSegmentIDsSearched: nil, ChannelIDsSearched: nil, GlobalSealedSegmentIDs: nil, SlicedBlob: nil, SlicedNumCount: 1, SlicedOffset: 0, } resultData := constructSearchResulstData() sliceBlob, err := proto.Marshal(resultData) assert.NoError(t, err) result1.SlicedBlob = sliceBlob // result2.SliceBlob = nil, will be skipped in decode stage result2 := &internalpb.SearchResults{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_SearchResult, MsgID: 0, Timestamp: 0, SourceID: 0, }, Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, Reason: "", }, ResultChannelID: "", MetricType: distance.L2, NumQueries: int64(nq), TopK: int64(topk), SealedSegmentIDsSearched: nil, ChannelIDsSearched: nil, GlobalSealedSegmentIDs: nil, SlicedBlob: nil, SlicedNumCount: 1, SlicedOffset: 0, } // send search result task.resultBuf <- []*internalpb.SearchResults{result1, result2} } } } }() assert.NoError(t, task.OnEnqueue()) assert.NoError(t, task.PreExecute(ctx)) assert.NoError(t, task.Execute(ctx)) assert.NoError(t, task.PostExecute(ctx)) cancel() wg.Wait() } func TestSearchTask_7803_reduce(t *testing.T) { var err error Params.Init() Params.ProxyCfg.SearchResultChannelNames = []string{funcutil.GenRandomStr()} rc := NewRootCoordMock() rc.Start() defer rc.Stop() ctx := context.Background() err = InitMetaCache(rc) assert.NoError(t, err) shardsNum := int32(2) prefix := "TestSearchTask_7803_reduce" dbName := "" collectionName := prefix + funcutil.GenRandomStr() int64Field := "int64" floatVecField := "fvec" dim := 128 expr := fmt.Sprintf("%s > 0", int64Field) nq := 10 topk := 10 roundDecimal := 3 nprobe := 10 schema := constructCollectionSchema( int64Field, floatVecField, dim, collectionName) marshaledSchema, err := proto.Marshal(schema) assert.NoError(t, err) createColT := 
	&createCollectionTask{
		Condition: NewTaskCondition(ctx),
		CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
			Base:           nil,
			DbName:         dbName,
			CollectionName: collectionName,
			Schema:         marshaledSchema,
			ShardsNum:      shardsNum,
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
		schema:    nil,
	}

	assert.NoError(t, createColT.OnEnqueue())
	assert.NoError(t, createColT.PreExecute(ctx))
	assert.NoError(t, createColT.Execute(ctx))
	assert.NoError(t, createColT.PostExecute(ctx))

	dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
	query := newMockGetChannelsService()
	factory := newSimpleMockMsgStreamFactory()
	chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
	defer chMgr.removeAllDMLStream()
	defer chMgr.removeAllDQLStream()

	collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
	assert.NoError(t, err)

	qc := NewQueryCoordMock()
	qc.Start()
	defer qc.Stop()

	status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_LoadCollection,
			MsgID:     0,
			Timestamp: 0,
			SourceID:  Params.ProxyCfg.ProxyID,
		},
		DbID:         0,
		CollectionID: collectionID,
		Schema:       nil,
	})
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)

	req := constructSearchRequest(dbName, collectionName, expr, floatVecField, nq, dim, nprobe, topk, roundDecimal)

	task := &searchTask{
		Condition: NewTaskCondition(ctx),
		SearchRequest: &internalpb.SearchRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_Search,
				MsgID:     0,
				Timestamp: 0,
				SourceID:  Params.ProxyCfg.ProxyID,
			},
			ResultChannelID:    strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
			DbID:               0,
			CollectionID:       0,
			PartitionIDs:       nil,
			Dsl:                "",
			PlaceholderGroup:   nil,
			DslType:            0,
			SerializedExprPlan: nil,
			OutputFieldsId:     nil,
			TravelTimestamp:    0,
			GuaranteeTimestamp: 0,
		},
		ctx:       ctx,
		resultBuf: make(chan []*internalpb.SearchResults),
		result:    nil,
		query:     req,
		chMgr:     chMgr,
		qc:        qc,
		tr:        timerecord.NewTimeRecorder("search"),
	}

	// simple mock for query node
	// TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream?
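	// The mocked query node below returns two partial results: each pads the tail of
	// every per-query topk list with minFloat32 scores and -1 IDs (topk/2 valid entries
	// in one result, the remainder in the other), exercising the reduce path in PostExecute.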
err = chMgr.createDQLStream(collectionID) assert.NoError(t, err) stream, err := chMgr.getDQLStream(collectionID) assert.NoError(t, err) var wg sync.WaitGroup wg.Add(1) consumeCtx, cancel := context.WithCancel(ctx) go func() { defer wg.Done() for { select { case <-consumeCtx.Done(): return case pack, ok := <-stream.Chan(): assert.True(t, ok) if pack == nil { continue } for _, msg := range pack.Msgs { _, ok := msg.(*msgstream.SearchMsg) assert.True(t, ok) // TODO(dragondriver): construct result according to the request constructSearchResulstData := func(invalidNum int) *schemapb.SearchResultData { resultData := &schemapb.SearchResultData{ NumQueries: int64(nq), TopK: int64(topk), FieldsData: nil, Scores: make([]float32, nq*topk), Ids: &schemapb.IDs{ IdField: &schemapb.IDs_IntId{ IntId: &schemapb.LongArray{ Data: make([]int64, nq*topk), }, }, }, Topks: make([]int64, nq), } for i := 0; i < nq; i++ { for j := 0; j < topk; j++ { offset := i*topk + j if j >= invalidNum { resultData.Scores[offset] = minFloat32 resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = -1 } else { score := float32(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) // increasingly id := int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) resultData.Scores[offset] = score resultData.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data[offset] = id } } resultData.Topks[i] = int64(topk) } return resultData } result1 := &internalpb.SearchResults{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_SearchResult, MsgID: 0, Timestamp: 0, SourceID: 0, }, Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, Reason: "", }, ResultChannelID: "", MetricType: distance.L2, NumQueries: int64(nq), TopK: int64(topk), SealedSegmentIDsSearched: nil, ChannelIDsSearched: nil, GlobalSealedSegmentIDs: nil, SlicedBlob: nil, SlicedNumCount: 1, SlicedOffset: 0, } resultData := constructSearchResulstData(topk / 2) sliceBlob, err := proto.Marshal(resultData) assert.NoError(t, err) result1.SlicedBlob = sliceBlob result2 := &internalpb.SearchResults{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_SearchResult, MsgID: 0, Timestamp: 0, SourceID: 0, }, Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, Reason: "", }, ResultChannelID: "", MetricType: distance.L2, NumQueries: int64(nq), TopK: int64(topk), SealedSegmentIDsSearched: nil, ChannelIDsSearched: nil, GlobalSealedSegmentIDs: nil, SlicedBlob: nil, SlicedNumCount: 1, SlicedOffset: 0, } resultData2 := constructSearchResulstData(topk - topk/2) sliceBlob2, err := proto.Marshal(resultData2) assert.NoError(t, err) result2.SlicedBlob = sliceBlob2 // send search result task.resultBuf <- []*internalpb.SearchResults{result1, result2} } } } }() assert.NoError(t, task.OnEnqueue()) assert.NoError(t, task.PreExecute(ctx)) assert.NoError(t, task.Execute(ctx)) assert.NoError(t, task.PostExecute(ctx)) cancel() wg.Wait() }