From 8475d8b19384c8e52190428fd3bccbca33de8cf2 Mon Sep 17 00:00:00 2001 From: Bingyi Sun Date: Tue, 21 Jan 2025 15:25:05 +0800 Subject: [PATCH] fix: cherry pick warmup async (#39402) (#39474) related pr: https://github.com/milvus-io/milvus/pull/38690 issue: https://github.com/milvus-io/milvus/issues/38692 Signed-off-by: sunby --- .../querynodev2/delegator/delegator_data.go | 1 + internal/querynodev2/segments/manager_test.go | 1 + internal/querynodev2/segments/pool.go | 2 +- .../querynodev2/segments/retrieve_test.go | 2 + internal/querynodev2/segments/search_test.go | 2 + internal/querynodev2/segments/segment.go | 67 +++++++++++++++++-- .../querynodev2/segments/segment_loader.go | 7 +- internal/querynodev2/segments/segment_test.go | 24 +++++++ internal/querynodev2/server_test.go | 1 + 9 files changed, 100 insertions(+), 7 deletions(-) diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index fc2d7003a1..d82fd9f3a5 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -116,6 +116,7 @@ func (sd *shardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) { DeltaPosition: insertData.StartPosition, Level: datapb.SegmentLevel_L1, }, + nil, ) if err != nil { log.Error("failed to create new segment", diff --git a/internal/querynodev2/segments/manager_test.go b/internal/querynodev2/segments/manager_test.go index 5904f10bd8..474a520ab2 100644 --- a/internal/querynodev2/segments/manager_test.go +++ b/internal/querynodev2/segments/manager_test.go @@ -68,6 +68,7 @@ func (s *ManagerSuite) SetupTest() { InsertChannel: s.channels[i], Level: s.levels[i], }, + nil, ) s.Require().NoError(err) s.segments = append(s.segments, segment) diff --git a/internal/querynodev2/segments/pool.go b/internal/querynodev2/segments/pool.go index 578a9c37ef..ca9971caa9 100644 --- a/internal/querynodev2/segments/pool.go +++ b/internal/querynodev2/segments/pool.go @@ -142,7 +142,7 @@ func initWarmupPool() { runtime.LockOSThread() C.SetThreadName(cgoTagWarmup) }), // lock os thread for cgo thread disposal - conc.WithNonBlocking(true), // make warming up non blocking + conc.WithNonBlocking(false), ) warmupPool.Store(pool) diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go index f294fd2b1e..b377df3ee6 100644 --- a/internal/querynodev2/segments/retrieve_test.go +++ b/internal/querynodev2/segments/retrieve_test.go @@ -98,6 +98,7 @@ func (suite *RetrieveSuite) SetupTest() { InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), Level: datapb.SegmentLevel_Legacy, }, + nil, ) suite.Require().NoError(err) @@ -126,6 +127,7 @@ func (suite *RetrieveSuite) SetupTest() { InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), Level: datapb.SegmentLevel_Legacy, }, + nil, ) suite.Require().NoError(err) diff --git a/internal/querynodev2/segments/search_test.go b/internal/querynodev2/segments/search_test.go index 14771485ec..eac94af7ec 100644 --- a/internal/querynodev2/segments/search_test.go +++ b/internal/querynodev2/segments/search_test.go @@ -88,6 +88,7 @@ func (suite *SearchSuite) SetupTest() { InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), Level: datapb.SegmentLevel_Legacy, }, + nil, ) suite.Require().NoError(err) @@ -116,6 +117,7 @@ func (suite *SearchSuite) SetupTest() { InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), Level: datapb.SegmentLevel_Legacy, }, + nil, ) suite.Require().NoError(err) diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index fc996ec35a..1f647f9f40 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -30,6 +30,7 @@ import ( "context" "fmt" "strings" + "sync" "time" "unsafe" @@ -295,6 +296,7 @@ type LocalSegment struct { lastDeltaTimestamp *atomic.Uint64 fields *typeutil.ConcurrentMap[int64, *FieldInfo] fieldIndexes *typeutil.ConcurrentMap[int64, *IndexedFieldInfo] + warmupDispatcher *AsyncWarmupDispatcher } func NewSegment(ctx context.Context, @@ -302,6 +304,7 @@ func NewSegment(ctx context.Context, segmentType SegmentType, version int64, loadInfo *querypb.SegmentLoadInfo, + warmupDispatcher *AsyncWarmupDispatcher, ) (Segment, error) { log := log.Ctx(ctx) /* @@ -361,9 +364,10 @@ func NewSegment(ctx context.Context, fields: typeutil.NewConcurrentMap[int64, *FieldInfo](), fieldIndexes: typeutil.NewConcurrentMap[int64, *IndexedFieldInfo](), - memSize: atomic.NewInt64(-1), - rowNum: atomic.NewInt64(-1), - insertCount: atomic.NewInt64(0), + memSize: atomic.NewInt64(-1), + rowNum: atomic.NewInt64(-1), + insertCount: atomic.NewInt64(0), + warmupDispatcher: warmupDispatcher, } if err := segment.initializeSegment(); err != nil { @@ -1200,7 +1204,7 @@ func (s *LocalSegment) WarmupChunkCache(ctx context.Context, fieldID int64, mmap return nil, nil }).Await() case "async": - GetWarmupPool().Submit(func() (any, error) { + task := func() (any, error) { // failed to wait for state update, return directly if !s.ptrLock.BlockUntilDataLoadedOrReleased() { return nil, nil @@ -1220,7 +1224,8 @@ func (s *LocalSegment) WarmupChunkCache(ctx context.Context, fieldID int64, mmap } log.Info("warming up chunk cache asynchronously done") return nil, nil - }) + } + s.warmupDispatcher.AddTask(task) default: // no warming up } @@ -1386,3 +1391,55 @@ func (s *LocalSegment) indexNeedLoadRawData(schema *schemapb.CollectionSchema, i } return !typeutil.IsVectorType(fieldSchema.DataType) && s.HasRawData(indexInfo.IndexInfo.FieldID), nil } + +type ( + WarmupTask = func() (any, error) + AsyncWarmupDispatcher struct { + mu sync.RWMutex + tasks []WarmupTask + notify chan struct{} + } +) + +func NewWarmupDispatcher() *AsyncWarmupDispatcher { + return &AsyncWarmupDispatcher{ + notify: make(chan struct{}, 1), + } +} + +func (d *AsyncWarmupDispatcher) AddTask(task func() (any, error)) { + d.mu.Lock() + d.tasks = append(d.tasks, task) + d.mu.Unlock() + select { + case d.notify <- struct{}{}: + default: + } +} + +func (d *AsyncWarmupDispatcher) Run(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-d.notify: + d.mu.RLock() + tasks := make([]WarmupTask, len(d.tasks)) + copy(tasks, d.tasks) + d.mu.RUnlock() + + for _, task := range tasks { + select { + case <-ctx.Done(): + return + default: + GetWarmupPool().Submit(task) + } + } + + d.mu.Lock() + d.tasks = d.tasks[len(tasks):] + d.mu.Unlock() + } + } +} diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go index 404d1e5587..d4932d65be 100644 --- a/internal/querynodev2/segments/segment_loader.go +++ b/internal/querynodev2/segments/segment_loader.go @@ -170,12 +170,15 @@ func NewLoader( duf := NewDiskUsageFetcher(ctx) go duf.Start() + warmupDispatcher := NewWarmupDispatcher() + go warmupDispatcher.Run(ctx) loader := &segmentLoader{ manager: manager, cm: cm, loadingSegments: typeutil.NewConcurrentMap[int64, *loadResult](), committedResourceNotifier: syncutil.NewVersionedNotifier(), duf: duf, + warmupDispatcher: warmupDispatcher, } return loader @@ -218,7 +221,8 @@ type segmentLoader struct { committedResource LoadResource committedResourceNotifier *syncutil.VersionedNotifier - duf *diskUsageFetcher + duf *diskUsageFetcher + warmupDispatcher *AsyncWarmupDispatcher } var _ Loader = (*segmentLoader)(nil) @@ -301,6 +305,7 @@ func (loader *segmentLoader) Load(ctx context.Context, segmentType, version, loadInfo, + loader.warmupDispatcher, ) if err != nil { log.Warn("load segment failed when create new segment", diff --git a/internal/querynodev2/segments/segment_test.go b/internal/querynodev2/segments/segment_test.go index 0662928a76..41442abf16 100644 --- a/internal/querynodev2/segments/segment_test.go +++ b/internal/querynodev2/segments/segment_test.go @@ -5,8 +5,11 @@ import ( "fmt" "path/filepath" "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "go.uber.org/atomic" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" @@ -91,6 +94,7 @@ func (suite *SegmentSuite) SetupTest() { }, }, }, + nil, ) suite.Require().NoError(err) @@ -122,6 +126,7 @@ func (suite *SegmentSuite) SetupTest() { InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), Level: datapb.SegmentLevel_Legacy, }, + nil, ) suite.Require().NoError(err) @@ -222,3 +227,22 @@ func (suite *SegmentSuite) TestSegmentReleased() { func TestSegment(t *testing.T) { suite.Run(t, new(SegmentSuite)) } + +func TestWarmupDispatcher(t *testing.T) { + d := NewWarmupDispatcher() + ctx := context.Background() + go d.Run(ctx) + + completed := atomic.NewInt64(0) + taskCnt := 10000 + for i := 0; i < taskCnt; i++ { + d.AddTask(func() (any, error) { + completed.Inc() + return nil, nil + }) + } + + assert.Eventually(t, func() bool { + return completed.Load() == int64(taskCnt) + }, 10*time.Second, time.Second) +} diff --git a/internal/querynodev2/server_test.go b/internal/querynodev2/server_test.go index 1448a337cd..3a48501cf8 100644 --- a/internal/querynodev2/server_test.go +++ b/internal/querynodev2/server_test.go @@ -237,6 +237,7 @@ func (suite *QueryNodeSuite) TestStop() { Level: datapb.SegmentLevel_Legacy, InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", 1), }, + nil, ) suite.NoError(err) suite.node.manager.Segment.Put(context.Background(), segments.SegmentTypeSealed, segment)