Fix task merge doesn't work (#23405)
Signed-off-by: yah01 <yang.cen@zilliz.com>
parent 288582b2d9
commit 7c4cafc83c
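Editor's note: the diff below reworks the QueryNode task scheduler so that merging actually feeds execution. Schedule now drains the wait queue, greedily merges tasks into a mergingSearchTasks slice (up to a configured merge limit), flushes the batch into a buffered mergedSearchTasks channel, and a new processAll goroutine consumes that channel and submits each task to the pool. The following is a minimal, hypothetical sketch of that merge-then-dispatch shape, using simplified stand-in types; it is not the actual Milvus Scheduler code.

// Minimal, hypothetical sketch of the merge-then-dispatch pattern
// (simplified stand-ins, not the actual Milvus Scheduler types).
package main

import (
	"fmt"
	"sync"
)

// searchTask stands in for *SearchTask: nq is the number of queries it
// carries, others records the tasks that were merged into it.
type searchTask struct {
	nq     int64
	others []*searchTask
}

// merge absorbs other into t while the combined nq stays under limit.
func (t *searchTask) merge(other *searchTask, limit int64) bool {
	if t.nq+other.nq > limit {
		return false
	}
	t.nq += other.nq
	t.others = append(t.others, other)
	return true
}

func main() {
	waitQueue := make(chan *searchTask, 16) // plays the role of searchWaitQueue
	merged := make(chan *searchTask, 4)     // plays the role of mergedSearchTasks

	// Consumer: plays the role of processAll, executing every merged task.
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		for t := range merged {
			fmt.Printf("execute: nq=%d, absorbed %d other tasks\n", t.nq, len(t.others))
		}
	}()

	// Producer: plays the role of Schedule + mergeTasks. Drain whatever is
	// queued right now, merge greedily into a local batch, then flush the
	// batch into the bounded merged channel.
	go func() {
		merging := make([]*searchTask, 0)
		for t := range waitQueue {
		inner:
			for {
				placed := false
				for _, m := range merging {
					if m.merge(t, 64) {
						placed = true
						break
					}
				}
				if !placed {
					merging = append(merging, t)
				}
				select {
				case next, ok := <-waitQueue:
					if !ok {
						break inner
					}
					t = next
				default:
					break inner
				}
			}
			for _, m := range merging {
				merged <- m
			}
			merging = merging[:0]
		}
		close(merged)
	}()

	for i := 0; i < 10; i++ {
		waitQueue <- &searchTask{nq: 8}
	}
	close(waitQueue)
	wg.Wait()
}

In the real patch the buffered channel and the pool are both sized from MaxReadConcurrency, so the channel provides backpressure while the pool caps how many merged tasks execute at once.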
@@ -39,8 +39,10 @@ import (
     "github.com/milvus-io/milvus/internal/storage"
     "github.com/milvus-io/milvus/internal/util/dependency"
     "github.com/milvus-io/milvus/pkg/mq/msgstream"
+    "github.com/milvus-io/milvus/pkg/util/conc"
     "github.com/milvus-io/milvus/pkg/util/etcd"
     "github.com/milvus-io/milvus/pkg/util/funcutil"
+    "github.com/milvus-io/milvus/pkg/util/merr"
     "github.com/milvus-io/milvus/pkg/util/metricsinfo"
     "github.com/milvus-io/milvus/pkg/util/paramtable"
 )
@@ -891,6 +893,39 @@ func (suite *ServiceSuite) TestSearch_Normal() {
     suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode())
 }

+func (suite *ServiceSuite) TestSearch_Concurrent() {
+    ctx := context.Background()
+    // pre
+    suite.TestWatchDmChannelsInt64()
+    suite.TestLoadSegments_Int64()
+
+    // data
+    schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
+
+    concurrency := 8
+    futures := make([]*conc.Future[*internalpb.SearchResults], 0, concurrency)
+    for i := 0; i < concurrency; i++ {
+        future := conc.Go(func() (*internalpb.SearchResults, error) {
+            creq, err := suite.genCSearchRequest(1, IndexFaissIDMap, schema)
+            req := &querypb.SearchRequest{
+                Req:             creq,
+                FromShardLeader: false,
+                DmlChannels:     []string{suite.vchannel},
+            }
+            suite.NoError(err)
+            return suite.node.Search(ctx, req)
+        })
+        futures = append(futures, future)
+    }
+
+    err := conc.AwaitAll(futures...)
+    suite.NoError(err)
+
+    for i := range futures {
+        suite.True(merr.Ok(futures[i].Value().GetStatus()))
+    }
+}
+
 func (suite *ServiceSuite) TestSearch_Failed() {
     ctx := context.Background()

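For readers unfamiliar with the conc helpers used in the test above: conc.Go returns a typed future and conc.AwaitAll waits on all of them, which is what lets the test assert that every one of the eight concurrent Search calls succeeded. Below is a rough, stdlib-only sketch of the same fan-out-and-join pattern; search here is a hypothetical stand-in, not part of the Milvus API.

// Rough stdlib-only equivalent of the concurrent-search pattern in the test:
// fire N requests in parallel, collect any errors, then check that none failed.
package main

import (
	"errors"
	"fmt"
	"sync"
)

// search is a stand-in for node.Search; it would normally issue a request.
func search(id int) error {
	if id < 0 {
		return errors.New("invalid request")
	}
	return nil
}

func main() {
	const concurrency = 8

	var (
		wg   sync.WaitGroup
		mu   sync.Mutex
		errs []error
	)
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			if err := search(id); err != nil {
				mu.Lock()
				errs = append(errs, err)
				mu.Unlock()
			}
		}(i)
	}
	wg.Wait()

	// The test's conc.AwaitAll + merr.Ok checks collapse to: no goroutine failed.
	fmt.Printf("all %d searches done, errors: %d\n", concurrency, len(errs))
}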
@@ -3,15 +3,13 @@ package tasks
 import (
     "context"
     "fmt"
-    "runtime"

-    ants "github.com/panjf2000/ants/v2"
     "go.uber.org/atomic"

     "github.com/milvus-io/milvus/pkg/metrics"
     "github.com/milvus-io/milvus/pkg/util/conc"
     "github.com/milvus-io/milvus/pkg/util/paramtable"
-    "github.com/milvus-io/milvus/pkg/util/typeutil"
+    "github.com/panjf2000/ants/v2"
 )

 const (
@@ -19,9 +17,10 @@ const (
 )

 type Scheduler struct {
-    searchProcessNum  *atomic.Int32
-    searchWaitQueue   chan *SearchTask
-    mergedSearchTasks typeutil.Set[*SearchTask]
+    searchProcessNum   *atomic.Int32
+    searchWaitQueue    chan *SearchTask
+    mergingSearchTasks []*SearchTask
+    mergedSearchTasks  chan *SearchTask

     queryProcessQueue chan *QueryTask
     queryWaitQueue    chan *QueryTask
@@ -31,14 +30,15 @@ type Scheduler struct {

 func NewScheduler() *Scheduler {
     maxWaitTaskNum := paramtable.Get().QueryNodeCfg.MaxReceiveChanSize.GetAsInt()
-    pool := conc.NewPool(runtime.GOMAXPROCS(0)*2, ants.WithPreAlloc(true))
+    maxReadConcurrency := paramtable.Get().QueryNodeCfg.MaxReadConcurrency.GetAsInt()
     return &Scheduler{
-        searchProcessNum:  atomic.NewInt32(0),
-        searchWaitQueue:   make(chan *SearchTask, maxWaitTaskNum),
-        mergedSearchTasks: typeutil.NewSet[*SearchTask](),
+        searchProcessNum:   atomic.NewInt32(0),
+        searchWaitQueue:    make(chan *SearchTask, maxWaitTaskNum),
+        mergingSearchTasks: make([]*SearchTask, 0),
+        mergedSearchTasks:  make(chan *SearchTask, maxReadConcurrency),
         // queryProcessQueue: make(chan),

-        pool: pool,
+        pool: conc.NewPool(maxReadConcurrency, ants.WithPreAlloc(true)),
     }
 }

@@ -59,25 +59,11 @@ func (s *Scheduler) Add(task Task) bool {

 // schedule all tasks in the order:
 // try execute merged tasks
-// try execute waitting tasks
+// try execute waiting tasks
 func (s *Scheduler) Schedule(ctx context.Context) {
+    go s.processAll(ctx)
+
     for {
-        if len(s.mergedSearchTasks) > 0 {
-            for task := range s.mergedSearchTasks {
-                if !s.tryPromote(task) {
-                    break
-                }
-
-                inQueueDuration := task.tr.RecordSpan()
-                metrics.QueryNodeSQLatencyInQueue.WithLabelValues(
-                    fmt.Sprint(paramtable.GetNodeID()),
-                    metrics.SearchLabel).
-                    Observe(float64(inQueueDuration.Milliseconds()))
-                s.process(task)
-                s.mergedSearchTasks.Remove(task)
-            }
-        }
-
         select {
         case <-ctx.Done():
             return
@@ -88,56 +74,74 @@ func (s *Scheduler) Schedule(ctx context.Context) {
                 continue
             }

             // Now we have no enough resource to execute this task,
             // just wait and try to merge it with another tasks
-            if !s.tryPromote(t) {
+            mergeCount := 0
+            mergeLimit := paramtable.Get().QueryNodeCfg.MaxGroupNQ.GetAsInt()
+        outer:
+            for i := 0; i < mergeLimit; i++ {
                 s.mergeTasks(t)
-            } else {
-                s.process(t)
+                mergeCount++
+                metrics.QueryNodeReadTaskUnsolveLen.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
+
+                select {
+                case t = <-s.searchWaitQueue:
+                    // Continue the loop to merge task
+                default:
+                    break outer
+                }
             }

-            metrics.QueryNodeReadTaskUnsolveLen.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
+            for i := range s.mergingSearchTasks {
+                s.mergedSearchTasks <- s.mergingSearchTasks[i]
+            }
+            s.mergingSearchTasks = s.mergingSearchTasks[:0]
+            metrics.QueryNodeReadTaskReadyLen.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(mergeCount))
         }
-
-        metrics.QueryNodeReadTaskReadyLen.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(s.mergedSearchTasks.Len()))
     }
 }

-func (s *Scheduler) tryPromote(t Task) bool {
-    current := s.searchProcessNum.Load()
-    if current >= MaxProcessTaskNum ||
-        !s.searchProcessNum.CAS(current, current+1) {
-        return false
-    }
+func (s *Scheduler) processAll(ctx context.Context) {
+    for {
+        select {
+        case <-ctx.Done():
+            return

-    return true
+        case task := <-s.mergedSearchTasks:
+            inQueueDuration := task.tr.RecordSpan()
+            metrics.QueryNodeSQLatencyInQueue.WithLabelValues(
+                fmt.Sprint(paramtable.GetNodeID()),
+                metrics.SearchLabel).
+                Observe(float64(inQueueDuration.Milliseconds()))
+
+            s.process(task)
+        }
+    }
 }

 func (s *Scheduler) process(t Task) {
-    s.pool.Submit(func() (interface{}, error) {
+    s.pool.Submit(func() (any, error) {
         metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc()

         err := t.Execute()
         t.Done(err)
         s.searchProcessNum.Dec()

         metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec()
         return nil, err
     })
 }

 // mergeTasks merge the given task with one of merged tasks,
 func (s *Scheduler) mergeTasks(t Task) {
     switch t := t.(type) {
     case *SearchTask:
         merged := false
-        for task := range s.mergedSearchTasks {
+        for _, task := range s.mergingSearchTasks {
             if task.Merge(t) {
                 merged = true
                 break
             }
         }
         if !merged {
-            s.mergedSearchTasks.Insert(t)
+            s.mergingSearchTasks = append(s.mergingSearchTasks, t)
         }
     }
 }
@@ -5,6 +5,7 @@ import (
     "context"
     "fmt"

+    "github.com/golang/protobuf/proto"
     "go.uber.org/zap"

     "github.com/milvus-io/milvus-proto/go-api/commonpb"
@@ -27,15 +28,18 @@ type Task interface {
 }

 type SearchTask struct {
-    ctx            context.Context
-    collection     *segments.Collection
-    segmentManager *segments.Manager
-    req            *querypb.SearchRequest
-    result         *internalpb.SearchResults
-    originTopks    []int64
-    originNqs      []int64
-    others         []*SearchTask
-    notifier       chan error
+    ctx              context.Context
+    collection       *segments.Collection
+    segmentManager   *segments.Manager
+    req              *querypb.SearchRequest
+    result           *internalpb.SearchResults
+    topk             int64
+    nq               int64
+    placeholderGroup []byte
+    originTopks      []int64
+    originNqs        []int64
+    others           []*SearchTask
+    notifier         chan error

     tr *timerecord.TimeRecorder
 }
@@ -46,13 +50,16 @@ func NewSearchTask(ctx context.Context,
     req *querypb.SearchRequest,
 ) *SearchTask {
     return &SearchTask{
-        ctx:            ctx,
-        collection:     collection,
-        segmentManager: manager,
-        req:            req,
-        originTopks:    []int64{req.GetReq().GetTopk()},
-        originNqs:      []int64{req.GetReq().GetNq()},
-        notifier:       make(chan error, 1),
+        ctx:              ctx,
+        collection:       collection,
+        segmentManager:   manager,
+        req:              req,
+        topk:             req.GetReq().GetTopk(),
+        nq:               req.GetReq().GetNq(),
+        placeholderGroup: req.GetReq().GetPlaceholderGroup(),
+        originTopks:      []int64{req.GetReq().GetTopk()},
+        originNqs:        []int64{req.GetReq().GetNq()},
+        notifier:         make(chan error, 1),

         tr: timerecord.NewTimeRecorderWithTrace(ctx, "searchTask"),
     }
@@ -63,8 +70,10 @@ func (t *SearchTask) Execute() error {
         zap.Int64("collectionID", t.collection.ID()),
         zap.String("shard", t.req.GetDmlChannels()[0]),
     )
+
     req := t.req
-    searchReq, err := segments.NewSearchRequest(t.collection, req, req.GetReq().GetPlaceholderGroup())
+    t.combinePlaceHolderGroups()
+    searchReq, err := segments.NewSearchRequest(t.collection, req, t.placeholderGroup)
     if err != nil {
         return err
     }
@@ -96,14 +105,22 @@ func (t *SearchTask) Execute() error {
     defer segments.DeleteSearchResults(results)

     if len(results) == 0 {
-        t.result = &internalpb.SearchResults{
-            Status:         &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
-            MetricType:     req.GetReq().GetMetricType(),
-            NumQueries:     req.GetReq().GetNq(),
-            TopK:           req.GetReq().GetTopk(),
-            SlicedBlob:     nil,
-            SlicedOffset:   1,
-            SlicedNumCount: 1,
+        for i := range t.originNqs {
+            var task *SearchTask
+            if i == 0 {
+                task = t
+            } else {
+                task = t.others[i-1]
+            }
+
+            task.result = &internalpb.SearchResults{
+                Status:         &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
+                MetricType:     req.GetReq().GetMetricType(),
+                NumQueries:     t.originNqs[i],
+                TopK:           t.originTopks[i],
+                SlicedOffset:   1,
+                SlicedNumCount: 1,
+            }
         }
         return nil
     }
@@ -113,8 +130,8 @@ func (t *SearchTask) Execute() error {
         searchReq.Plan(),
         results,
         int64(len(results)),
-        []int64{req.GetReq().GetNq()},
-        []int64{req.GetReq().GetTopk()},
+        t.originNqs,
+        t.originTopks,
     )
     if err != nil {
         log.Warn("failed to reduce search results", zap.Error(err))
@@ -122,36 +139,45 @@ func (t *SearchTask) Execute() error {
     }
     defer segments.DeleteSearchResultDataBlobs(blobs)

-    blob, err := segments.GetSearchResultDataBlob(blobs, 0)
-    if err != nil {
-        return err
-    }
+    for i := range t.originNqs {
+        blob, err := segments.GetSearchResultDataBlob(blobs, i)
+        if err != nil {
+            return err
+        }

-    // Note: blob is unsafe because get from C
-    bs := make([]byte, len(blob))
-    copy(bs, blob)
+        var task *SearchTask
+        if i == 0 {
+            task = t
+        } else {
+            task = t.others[i-1]
+        }

-    metrics.QueryNodeReduceLatency.WithLabelValues(
-        fmt.Sprint(paramtable.GetNodeID()),
-        metrics.SearchLabel).
-        Observe(float64(tr.ElapseSpan().Milliseconds()))
+        // Note: blob is unsafe because get from C
+        bs := make([]byte, len(blob))
+        copy(bs, blob)

-    t.result = &internalpb.SearchResults{
-        Status:         util.WrapStatus(commonpb.ErrorCode_Success, ""),
-        MetricType:     req.GetReq().GetMetricType(),
-        NumQueries:     req.GetReq().GetNq(),
-        TopK:           req.GetReq().GetTopk(),
-        SlicedBlob:     bs,
-        SlicedOffset:   1,
-        SlicedNumCount: 1,
+        metrics.QueryNodeReduceLatency.WithLabelValues(
+            fmt.Sprint(paramtable.GetNodeID()),
+            metrics.SearchLabel).
+            Observe(float64(tr.ElapseSpan().Milliseconds()))
+
+        task.result = &internalpb.SearchResults{
+            Status:         util.WrapStatus(commonpb.ErrorCode_Success, ""),
+            MetricType:     req.GetReq().GetMetricType(),
+            NumQueries:     t.originNqs[i],
+            TopK:           t.originTopks[i],
+            SlicedBlob:     bs,
+            SlicedOffset:   1,
+            SlicedNumCount: 1,
+        }
     }
     return nil
 }

 func (t *SearchTask) Merge(other *SearchTask) bool {
     var (
-        nq        = t.req.GetReq().GetNq()
-        topk      = t.req.GetReq().GetTopk()
+        nq        = t.nq
+        topk      = t.topk
         otherNq   = other.req.GetReq().GetNq()
         otherTopk = other.req.GetReq().GetTopk()
     )
@@ -176,8 +202,8 @@ func (t *SearchTask) Merge(other *SearchTask) bool {
     }

     // Merge
-    t.req.GetReq().Topk = maxTopk
-    t.req.GetReq().Nq += otherNq
+    t.topk = maxTopk
+    t.nq += otherNq
     t.originTopks = append(t.originTopks, other.originTopks...)
     t.originNqs = append(t.originNqs, other.originNqs...)
     t.others = append(t.others, other)
@@ -210,5 +236,19 @@ func (t *SearchTask) Result() *internalpb.SearchResults {
     return t.result
 }

+// combinePlaceHolderGroups combine all the placeholder groups.
+func (t *SearchTask) combinePlaceHolderGroups() {
+    if len(t.others) > 0 {
+        ret := &commonpb.PlaceholderGroup{}
+        _ = proto.Unmarshal(t.placeholderGroup, ret)
+        for _, t := range t.others {
+            x := &commonpb.PlaceholderGroup{}
+            _ = proto.Unmarshal(t.placeholderGroup, x)
+            ret.Placeholders[0].Values = append(ret.Placeholders[0].Values, x.Placeholders[0].Values...)
+        }
+        t.placeholderGroup, _ = proto.Marshal(ret)
+    }
+}
+
 type QueryTask struct {
 }
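To tie the task.go changes together: a merged request executes once, and each original request then needs its own slice of the reduced output, which is what the originNqs/originTopks bookkeeping above drives (slot 0 belongs to the merged task itself, slot i to others[i-1]). The following is a small, hypothetical sketch of that fan-out step with simplified stand-in types; it is not the Milvus SearchTask implementation.

// Hypothetical sketch: fan a merged result back out to the original requests.
package main

import "fmt"

// result stands in for *internalpb.SearchResults.
type result struct {
	nq   int64
	topk int64
	blob []byte
}

// mergedTask stands in for a *SearchTask that absorbed other requests.
type mergedTask struct {
	originNqs   []int64
	originTopks []int64
	results     []*result
}

// distribute assigns the i-th reduced blob to the i-th original request,
// restoring that request's own nq and topk.
func (t *mergedTask) distribute(blobs [][]byte) {
	t.results = make([]*result, len(t.originNqs))
	for i := range t.originNqs {
		t.results[i] = &result{
			nq:   t.originNqs[i],
			topk: t.originTopks[i],
			blob: blobs[i],
		}
	}
}

func main() {
	// Three requests were merged: nq 4+2+2, each keeping its own topk.
	t := &mergedTask{
		originNqs:   []int64{4, 2, 2},
		originTopks: []int64{10, 10, 5},
	}
	t.distribute([][]byte{{0x01}, {0x02}, {0x03}})
	for i, r := range t.results {
		fmt.Printf("request %d: nq=%d topk=%d blob=%v\n", i, r.nq, r.topk, r.blob)
	}
}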