diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go
index 9043f1c788..1d2a322459 100644
--- a/internal/querynodev2/services_test.go
+++ b/internal/querynodev2/services_test.go
@@ -39,8 +39,10 @@ import (
 	"github.com/milvus-io/milvus/internal/storage"
 	"github.com/milvus-io/milvus/internal/util/dependency"
 	"github.com/milvus-io/milvus/pkg/mq/msgstream"
+	"github.com/milvus-io/milvus/pkg/util/conc"
 	"github.com/milvus-io/milvus/pkg/util/etcd"
 	"github.com/milvus-io/milvus/pkg/util/funcutil"
+	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/metricsinfo"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
 )
@@ -891,6 +893,39 @@ func (suite *ServiceSuite) TestSearch_Normal() {
 	suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode())
 }
 
+func (suite *ServiceSuite) TestSearch_Concurrent() {
+	ctx := context.Background()
+	// pre
+	suite.TestWatchDmChannelsInt64()
+	suite.TestLoadSegments_Int64()
+
+	// data
+	schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
+
+	concurrency := 8
+	futures := make([]*conc.Future[*internalpb.SearchResults], 0, concurrency)
+	for i := 0; i < concurrency; i++ {
+		future := conc.Go(func() (*internalpb.SearchResults, error) {
+			creq, err := suite.genCSearchRequest(1, IndexFaissIDMap, schema)
+			req := &querypb.SearchRequest{
+				Req:             creq,
+				FromShardLeader: false,
+				DmlChannels:     []string{suite.vchannel},
+			}
+			suite.NoError(err)
+			return suite.node.Search(ctx, req)
+		})
+		futures = append(futures, future)
+	}
+
+	err := conc.AwaitAll(futures...)
+	suite.NoError(err)
+
+	for i := range futures {
+		suite.True(merr.Ok(futures[i].Value().GetStatus()))
+	}
+}
+
 func (suite *ServiceSuite) TestSearch_Failed() {
 	ctx := context.Background()
diff --git a/internal/querynodev2/tasks/scheduler.go b/internal/querynodev2/tasks/scheduler.go
index 55e4a80799..e9fbe21f47 100644
--- a/internal/querynodev2/tasks/scheduler.go
+++ b/internal/querynodev2/tasks/scheduler.go
@@ -3,15 +3,13 @@ package tasks
 import (
 	"context"
 	"fmt"
-	"runtime"
 
-	ants "github.com/panjf2000/ants/v2"
 	"go.uber.org/atomic"
 
 	"github.com/milvus-io/milvus/pkg/metrics"
 	"github.com/milvus-io/milvus/pkg/util/conc"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
-	"github.com/milvus-io/milvus/pkg/util/typeutil"
+	"github.com/panjf2000/ants/v2"
 )
 
 const (
@@ -19,9 +17,10 @@
 )
 
 type Scheduler struct {
-	searchProcessNum  *atomic.Int32
-	searchWaitQueue   chan *SearchTask
-	mergedSearchTasks typeutil.Set[*SearchTask]
+	searchProcessNum   *atomic.Int32
+	searchWaitQueue    chan *SearchTask
+	mergingSearchTasks []*SearchTask
+	mergedSearchTasks  chan *SearchTask
 
 	queryProcessQueue chan *QueryTask
 	queryWaitQueue    chan *QueryTask
@@ -31,14 +30,15 @@ type Scheduler struct {
 
 func NewScheduler() *Scheduler {
 	maxWaitTaskNum := paramtable.Get().QueryNodeCfg.MaxReceiveChanSize.GetAsInt()
-	pool := conc.NewPool(runtime.GOMAXPROCS(0)*2, ants.WithPreAlloc(true))
+	maxReadConcurrency := paramtable.Get().QueryNodeCfg.MaxReadConcurrency.GetAsInt()
 	return &Scheduler{
-		searchProcessNum:  atomic.NewInt32(0),
-		searchWaitQueue:   make(chan *SearchTask, maxWaitTaskNum),
-		mergedSearchTasks: typeutil.NewSet[*SearchTask](),
+		searchProcessNum:   atomic.NewInt32(0),
+		searchWaitQueue:    make(chan *SearchTask, maxWaitTaskNum),
+		mergingSearchTasks: make([]*SearchTask, 0),
+		mergedSearchTasks:  make(chan *SearchTask, maxReadConcurrency),
 		// queryProcessQueue: make(chan),
 
-		pool: pool,
+		pool: conc.NewPool(maxReadConcurrency, ants.WithPreAlloc(true)),
 	}
 }
 
@@ -59,25 +59,11 @@ func (s *Scheduler) Add(task Task) bool {
 
 // schedule all tasks in the order:
 // try execute merged tasks
-// try execute waitting tasks
+// try execute waiting tasks
 func (s *Scheduler) Schedule(ctx context.Context) {
+	go s.processAll(ctx)
+
 	for {
-		if len(s.mergedSearchTasks) > 0 {
-			for task := range s.mergedSearchTasks {
-				if !s.tryPromote(task) {
-					break
-				}
-
-				inQueueDuration := task.tr.RecordSpan()
-				metrics.QueryNodeSQLatencyInQueue.WithLabelValues(
-					fmt.Sprint(paramtable.GetNodeID()),
-					metrics.SearchLabel).
-					Observe(float64(inQueueDuration.Milliseconds()))
-				s.process(task)
-				s.mergedSearchTasks.Remove(task)
-			}
-		}
-
 		select {
 		case <-ctx.Done():
 			return
@@ -88,56 +74,74 @@ func (s *Scheduler) Schedule(ctx context.Context) {
 				continue
 			}
 
-			// Now we have no enough resource to execute this task,
-			// just wait and try to merge it with another tasks
-			if !s.tryPromote(t) {
+			mergeCount := 0
+			mergeLimit := paramtable.Get().QueryNodeCfg.MaxGroupNQ.GetAsInt()
+		outer:
+			for i := 0; i < mergeLimit; i++ {
 				s.mergeTasks(t)
-			} else {
-				s.process(t)
+				mergeCount++
+				metrics.QueryNodeReadTaskUnsolveLen.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
+
+				select {
+				case t = <-s.searchWaitQueue:
+					// Continue the loop to merge task
+				default:
+					break outer
+				}
 			}
-			metrics.QueryNodeReadTaskUnsolveLen.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
+
+			for i := range s.mergingSearchTasks {
+				s.mergedSearchTasks <- s.mergingSearchTasks[i]
+			}
+			s.mergingSearchTasks = s.mergingSearchTasks[:0]
+			metrics.QueryNodeReadTaskReadyLen.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(mergeCount))
 		}
-
-		metrics.QueryNodeReadTaskReadyLen.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(s.mergedSearchTasks.Len()))
 	}
 }
 
-func (s *Scheduler) tryPromote(t Task) bool {
-	current := s.searchProcessNum.Load()
-	if current >= MaxProcessTaskNum ||
-		!s.searchProcessNum.CAS(current, current+1) {
-		return false
-	}
+func (s *Scheduler) processAll(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			return
 
-	return true
+		case task := <-s.mergedSearchTasks:
+			inQueueDuration := task.tr.RecordSpan()
+			metrics.QueryNodeSQLatencyInQueue.WithLabelValues(
+				fmt.Sprint(paramtable.GetNodeID()),
+				metrics.SearchLabel).
+				Observe(float64(inQueueDuration.Milliseconds()))
+
+			s.process(task)
+		}
+	}
 }
 
 func (s *Scheduler) process(t Task) {
-	s.pool.Submit(func() (interface{}, error) {
+	s.pool.Submit(func() (any, error) {
 		metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()
 
 		err := t.Execute()
 		t.Done(err)
 
-		s.searchProcessNum.Dec()
 		metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
 		return nil, err
 	})
 }
 
+// mergeTasks merge the given task with one of merged tasks,
 func (s *Scheduler) mergeTasks(t Task) {
 	switch t := t.(type) {
 	case *SearchTask:
 		merged := false
-		for task := range s.mergedSearchTasks {
+		for _, task := range s.mergingSearchTasks {
 			if task.Merge(t) {
 				merged = true
 				break
 			}
 		}
 		if !merged {
-			s.mergedSearchTasks.Insert(t)
+			s.mergingSearchTasks = append(s.mergingSearchTasks, t)
 		}
 	}
 }
diff --git a/internal/querynodev2/tasks/task.go b/internal/querynodev2/tasks/task.go
index eecfb647c1..1849f83f7c 100644
--- a/internal/querynodev2/tasks/task.go
+++ b/internal/querynodev2/tasks/task.go
@@ -5,6 +5,7 @@ import (
 	"context"
 	"fmt"
 
+	"github.com/golang/protobuf/proto"
 	"go.uber.org/zap"
 
 	"github.com/milvus-io/milvus-proto/go-api/commonpb"
@@ -27,15 +28,18 @@
 }
 
 type SearchTask struct {
-	ctx            context.Context
-	collection     *segments.Collection
-	segmentManager *segments.Manager
-	req            *querypb.SearchRequest
-	result         *internalpb.SearchResults
-	originTopks    []int64
-	originNqs      []int64
-	others         []*SearchTask
-	notifier       chan error
+	ctx              context.Context
+	collection       *segments.Collection
+	segmentManager   *segments.Manager
+	req              *querypb.SearchRequest
+	result           *internalpb.SearchResults
+	topk             int64
+	nq               int64
+	placeholderGroup []byte
+	originTopks      []int64
+	originNqs        []int64
+	others           []*SearchTask
+	notifier         chan error
 
 	tr *timerecord.TimeRecorder
 }
@@ -46,13 +50,16 @@ func NewSearchTask(ctx context.Context,
 	req *querypb.SearchRequest,
 ) *SearchTask {
 	return &SearchTask{
-		ctx:            ctx,
-		collection:     collection,
-		segmentManager: manager,
-		req:            req,
-		originTopks:    []int64{req.GetReq().GetTopk()},
-		originNqs:      []int64{req.GetReq().GetNq()},
-		notifier:       make(chan error, 1),
+		ctx:              ctx,
+		collection:       collection,
+		segmentManager:   manager,
+		req:              req,
+		topk:             req.GetReq().GetTopk(),
+		nq:               req.GetReq().GetNq(),
+		placeholderGroup: req.GetReq().GetPlaceholderGroup(),
+		originTopks:      []int64{req.GetReq().GetTopk()},
+		originNqs:        []int64{req.GetReq().GetNq()},
+		notifier:         make(chan error, 1),
 
 		tr: timerecord.NewTimeRecorderWithTrace(ctx, "searchTask"),
 	}
@@ -63,8 +70,10 @@ func (t *SearchTask) Execute() error {
 		zap.Int64("collectionID", t.collection.ID()),
 		zap.String("shard", t.req.GetDmlChannels()[0]),
 	)
+
 	req := t.req
-	searchReq, err := segments.NewSearchRequest(t.collection, req, req.GetReq().GetPlaceholderGroup())
+	t.combinePlaceHolderGroups()
+	searchReq, err := segments.NewSearchRequest(t.collection, req, t.placeholderGroup)
 	if err != nil {
 		return err
 	}
@@ -96,14 +105,22 @@ func (t *SearchTask) Execute() error {
 	defer segments.DeleteSearchResults(results)
 
 	if len(results) == 0 {
-		t.result = &internalpb.SearchResults{
-			Status:         &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
-			MetricType:     req.GetReq().GetMetricType(),
-			NumQueries:     req.GetReq().GetNq(),
-			TopK:           req.GetReq().GetTopk(),
-			SlicedBlob:     nil,
-			SlicedOffset:   1,
-			SlicedNumCount: 1,
+		for i := range t.originNqs {
+			var task *SearchTask
+			if i == 0 {
+				task = t
+			} else {
+				task = t.others[i-1]
+			}
+
+			task.result = &internalpb.SearchResults{
+				Status:         &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
+				MetricType:     req.GetReq().GetMetricType(),
+				NumQueries:     t.originNqs[i],
+				TopK:           t.originTopks[i],
+				SlicedOffset:   1,
+				SlicedNumCount: 1,
+			}
 		}
 		return nil
 	}
@@ -113,8 +130,8 @@ func (t *SearchTask) Execute() error {
 		searchReq.Plan(),
 		results,
 		int64(len(results)),
-		[]int64{req.GetReq().GetNq()},
-		[]int64{req.GetReq().GetTopk()},
+		t.originNqs,
+		t.originTopks,
 	)
 	if err != nil {
 		log.Warn("failed to reduce search results", zap.Error(err))
@@ -122,36 +139,45 @@ func (t *SearchTask) Execute() error {
 	}
 	defer segments.DeleteSearchResultDataBlobs(blobs)
 
-	blob, err := segments.GetSearchResultDataBlob(blobs, 0)
-	if err != nil {
-		return err
-	}
+	for i := range t.originNqs {
+		blob, err := segments.GetSearchResultDataBlob(blobs, i)
+		if err != nil {
+			return err
+		}
 
-	// Note: blob is unsafe because get from C
-	bs := make([]byte, len(blob))
-	copy(bs, blob)
+		var task *SearchTask
+		if i == 0 {
+			task = t
+		} else {
+			task = t.others[i-1]
+		}
 
-	metrics.QueryNodeReduceLatency.WithLabelValues(
-		fmt.Sprint(paramtable.GetNodeID()),
-		metrics.SearchLabel).
-		Observe(float64(tr.ElapseSpan().Milliseconds()))
+		// Note: blob is unsafe because get from C
+		bs := make([]byte, len(blob))
+		copy(bs, blob)
 
-	t.result = &internalpb.SearchResults{
-		Status:         util.WrapStatus(commonpb.ErrorCode_Success, ""),
-		MetricType:     req.GetReq().GetMetricType(),
-		NumQueries:     req.GetReq().GetNq(),
-		TopK:           req.GetReq().GetTopk(),
-		SlicedBlob:     bs,
-		SlicedOffset:   1,
-		SlicedNumCount: 1,
+		metrics.QueryNodeReduceLatency.WithLabelValues(
+			fmt.Sprint(paramtable.GetNodeID()),
+			metrics.SearchLabel).
+			Observe(float64(tr.ElapseSpan().Milliseconds()))
+
+		task.result = &internalpb.SearchResults{
+			Status:         util.WrapStatus(commonpb.ErrorCode_Success, ""),
+			MetricType:     req.GetReq().GetMetricType(),
+			NumQueries:     t.originNqs[i],
+			TopK:           t.originTopks[i],
+			SlicedBlob:     bs,
+			SlicedOffset:   1,
+			SlicedNumCount: 1,
+		}
 	}
 	return nil
 }
 
 func (t *SearchTask) Merge(other *SearchTask) bool {
 	var (
-		nq        = t.req.GetReq().GetNq()
-		topk      = t.req.GetReq().GetTopk()
+		nq        = t.nq
+		topk      = t.topk
 		otherNq   = other.req.GetReq().GetNq()
 		otherTopk = other.req.GetReq().GetTopk()
 	)
@@ -176,8 +202,8 @@ func (t *SearchTask) Merge(other *SearchTask) bool {
 	}
 
 	// Merge
-	t.req.GetReq().Topk = maxTopk
-	t.req.GetReq().Nq += otherNq
+	t.topk = maxTopk
+	t.nq += otherNq
 	t.originTopks = append(t.originTopks, other.originTopks...)
 	t.originNqs = append(t.originNqs, other.originNqs...)
 	t.others = append(t.others, other)
@@ -210,5 +236,19 @@ func (t *SearchTask) Result() *internalpb.SearchResults {
 	return t.result
 }
 
+// combinePlaceHolderGroups combine all the placeholder groups.
+func (t *SearchTask) combinePlaceHolderGroups() {
+	if len(t.others) > 0 {
+		ret := &commonpb.PlaceholderGroup{}
+		_ = proto.Unmarshal(t.placeholderGroup, ret)
+		for _, t := range t.others {
+			x := &commonpb.PlaceholderGroup{}
+			_ = proto.Unmarshal(t.placeholderGroup, x)
+			ret.Placeholders[0].Values = append(ret.Placeholders[0].Values, x.Placeholders[0].Values...)
+		}
+		t.placeholderGroup, _ = proto.Marshal(ret)
+	}
+}
+
 type QueryTask struct {
 }
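For readers of the scheduler change above: `Schedule` now only merges, and a second goroutine (`processAll`) consumes the merged tasks and hands them to the pool. Below is a minimal, self-contained sketch of that producer/consumer shape, assuming stand-in types (an `int` task, a `[]int` batch) rather than the real `SearchTask` and `Scheduler` fields; the real loop additionally caps a round at `MaxGroupNQ` pulls and only folds a task into a compatible one via `SearchTask.Merge`.

```go
package main

import (
	"context"
	"fmt"
)

// scheduleLoop sketches the Schedule/processAll split: pull a task off the wait
// queue, keep merging without blocking while more tasks are already queued (up
// to mergeLimit), then flush the batch into a bounded channel for a consumer.
func scheduleLoop(ctx context.Context, waitQueue <-chan int, merged chan<- []int, mergeLimit int) {
	for {
		select {
		case <-ctx.Done():
			return
		case t := <-waitQueue:
			batch := []int{t}
		outer:
			for i := 1; i < mergeLimit; i++ {
				select {
				case t = <-waitQueue:
					batch = append(batch, t) // keep merging while tasks are already waiting
				default:
					break outer // nothing waiting: stop merging and flush
				}
			}
			merged <- batch
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	waitQueue := make(chan int, 16)
	merged := make(chan []int, 4)

	// Enqueue six tasks first so the batching below is deterministic.
	for i := 0; i < 6; i++ {
		waitQueue <- i
	}
	go scheduleLoop(ctx, waitQueue, merged, 4)

	fmt.Println(<-merged) // [0 1 2 3]
	fmt.Println(<-merged) // [4 5]
}
```

The non-blocking `default` case is what keeps latency bounded: merging continues only while tasks are already queued, so a lone request is flushed immediately rather than waiting for company.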
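The merge stays reversible because of the bookkeeping `SearchTask` carries: `Merge` keeps the summed `nq`, the maximum `topk`, and the per-request `originNqs`/`originTopks`, and `Execute` later asks the reduce step for one slice per entry, assigning slice `i` to `t` itself for `i == 0` and to `t.others[i-1]` otherwise. A small sketch of that accounting with hypothetical request sizes (this models only the counters, not the placeholder-group concatenation or the gRPC request):

```go
package main

import "fmt"

// mergedSearch models only the counters a merged SearchTask keeps.
type mergedSearch struct {
	nq, topk    int64
	originNqs   []int64
	originTopks []int64
}

func newMergedSearch(nq, topk int64) *mergedSearch {
	return &mergedSearch{nq: nq, topk: topk, originNqs: []int64{nq}, originTopks: []int64{topk}}
}

// merge folds another request in, mirroring SearchTask.Merge: nq values add up,
// topk takes the maximum, and per-request values are remembered for splitting.
func (m *mergedSearch) merge(nq, topk int64) {
	if topk > m.topk {
		m.topk = topk
	}
	m.nq += nq
	m.originNqs = append(m.originNqs, nq)
	m.originTopks = append(m.originTopks, topk)
}

func main() {
	// Hypothetical requests: (nq=2, topk=10) merged with (nq=3, topk=5).
	m := newMergedSearch(2, 10)
	m.merge(3, 5)
	fmt.Println("merged nq:", m.nq, "merged topk:", m.topk) // merged nq: 5 merged topk: 10

	// After reduction, slice i is handed back with its original nq/topk.
	for i := range m.originNqs {
		fmt.Printf("slice %d -> NumQueries=%d, TopK=%d\n", i, m.originNqs[i], m.originTopks[i])
	}
}
```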