diff --git a/internal/datanode/writebuffer/manager.go b/internal/datanode/writebuffer/manager.go index 6392a4bec9..b0191f7a9a 100644 --- a/internal/datanode/writebuffer/manager.go +++ b/internal/datanode/writebuffer/manager.go @@ -94,38 +94,40 @@ func (m *bufferManager) memoryCheck() { m.mut.Lock() defer m.mut.Unlock() + for { + var total int64 + var candidate WriteBuffer + var candiSize int64 + var candiChan string - var total int64 - var candidate WriteBuffer - var candiSize int64 - var candiChan string - for chanName, buf := range m.buffers { - size := buf.MemorySize() - total += size - if size > candiSize { - candiSize = size - candidate = buf - candiChan = chanName + toMB := func(mem float64) float64 { + return mem / 1024 / 1024 } - } - toMB := func(mem float64) float64 { - return mem / 1024 / 1024 - } + for chanName, buf := range m.buffers { + size := buf.MemorySize() + total += size + if size > candiSize { + candiSize = size + candidate = buf + candiChan = chanName + } + } - totalMemory := hardware.GetMemoryCount() - memoryWatermark := float64(totalMemory) * paramtable.Get().DataNodeCfg.MemoryForceSyncWatermark.GetAsFloat() - if float64(total) < memoryWatermark { - log.RatedDebug(20, "skip force sync because memory level is not high enough", - zap.Float64("current_total_memory_usage", toMB(float64(total))), - zap.Float64("current_memory_watermark", toMB(memoryWatermark))) - return - } + totalMemory := hardware.GetMemoryCount() + memoryWatermark := float64(totalMemory) * paramtable.Get().DataNodeCfg.MemoryForceSyncWatermark.GetAsFloat() + if float64(total) < memoryWatermark { + log.RatedDebug(20, "skip force sync because memory level is not high enough", + zap.Float64("current_total_memory_usage", toMB(float64(total))), + zap.Float64("current_memory_watermark", toMB(memoryWatermark))) + return + } - if candidate != nil { - candidate.EvictBuffer(GetOldestBufferPolicy(paramtable.Get().DataNodeCfg.MemoryForceSyncSegmentNum.GetAsInt())) - log.Info("notify writebuffer to sync", - zap.String("channel", candiChan), zap.Float64("bufferSize(MB)", toMB(float64(candiSize)))) + if candidate != nil { + candidate.EvictBuffer(GetOldestBufferPolicy(paramtable.Get().DataNodeCfg.MemoryForceSyncSegmentNum.GetAsInt())) + log.Info("notify writebuffer to sync", + zap.String("channel", candiChan), zap.Float64("bufferSize(MB)", toMB(float64(candiSize)))) + } } } diff --git a/internal/datanode/writebuffer/manager_test.go b/internal/datanode/writebuffer/manager_test.go index 910ab5f492..f8c2d7d1ff 100644 --- a/internal/datanode/writebuffer/manager_test.go +++ b/internal/datanode/writebuffer/manager_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "go.uber.org/atomic" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" @@ -209,14 +210,22 @@ func (s *ManagerSuite) TestMemoryCheck() { wb := NewMockWriteBuffer(s.T()) + flag := atomic.NewBool(false) memoryLimit := hardware.GetMemoryCount() signal := make(chan struct{}, 1) - wb.EXPECT().MemorySize().Return(int64(float64(memoryLimit) * 0.6)) + wb.EXPECT().MemorySize().RunAndReturn(func() int64 { + if flag.Load() { + return int64(float64(memoryLimit) * 0.4) + } + return int64(float64(memoryLimit) * 0.6) + }) + //.Return(int64(float64(memoryLimit) * 0.6)) wb.EXPECT().EvictBuffer(mock.Anything).Run(func(polices ...SyncPolicy) { select { case signal <- struct{}{}: default: } + flag.Store(true) }).Return() manager.mut.Lock() manager.buffers[s.channelName] = wb diff --git a/internal/datanode/writebuffer/write_buffer.go b/internal/datanode/writebuffer/write_buffer.go index 930c11320e..e1410b5ecc 100644 --- a/internal/datanode/writebuffer/write_buffer.go +++ b/internal/datanode/writebuffer/write_buffer.go @@ -196,7 +196,7 @@ func (wb *writeBufferBase) EvictBuffer(policies ...SyncPolicy) { segmentIDs := wb.getSegmentsToSync(ts, policies...) if len(segmentIDs) > 0 { log.Info("evict buffer find segments to sync", zap.Int64s("segmentIDs", segmentIDs)) - wb.syncSegments(context.Background(), segmentIDs) + conc.AwaitAll(wb.syncSegments(context.Background(), segmentIDs)...) } } @@ -266,6 +266,7 @@ func (wb *writeBufferBase) triggerSync() (segmentIDs []int64) { segmentsToSync := wb.getSegmentsToSync(wb.checkpoint.GetTimestamp(), wb.syncPolicies...) if len(segmentsToSync) > 0 { log.Info("write buffer get segments to sync", zap.Int64s("segmentIDs", segmentsToSync)) + // ignore future here, use callback to handle error wb.syncSegments(context.Background(), segmentsToSync) } @@ -296,8 +297,9 @@ func (wb *writeBufferBase) sealSegments(ctx context.Context, segmentIDs []int64) return nil } -func (wb *writeBufferBase) syncSegments(ctx context.Context, segmentIDs []int64) { +func (wb *writeBufferBase) syncSegments(ctx context.Context, segmentIDs []int64) []*conc.Future[error] { log := log.Ctx(ctx) + result := make([]*conc.Future[error], 0, len(segmentIDs)) for _, segmentID := range segmentIDs { syncTask, err := wb.getSyncTask(ctx, segmentID) if err != nil { @@ -309,9 +311,9 @@ func (wb *writeBufferBase) syncSegments(ctx context.Context, segmentIDs []int64) } } - // discard Future here, handle error in callback - _ = wb.syncMgr.SyncData(ctx, syncTask) + result = append(result, wb.syncMgr.SyncData(ctx, syncTask)) } + return result } // getSegmentsToSync applies all policies to get segments list to sync. diff --git a/internal/datanode/writebuffer/write_buffer_test.go b/internal/datanode/writebuffer/write_buffer_test.go index 2659ea4cfd..95842f2aac 100644 --- a/internal/datanode/writebuffer/write_buffer_test.go +++ b/internal/datanode/writebuffer/write_buffer_test.go @@ -15,6 +15,7 @@ import ( "github.com/milvus-io/milvus/internal/datanode/syncmgr" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -356,7 +357,9 @@ func (s *WriteBufferSuite) TestEvictBuffer() { s.metacache.EXPECT().GetSegmentByID(int64(2)).Return(segment, true) s.metacache.EXPECT().UpdateSegments(mock.Anything, mock.Anything).Return() serializer.EXPECT().EncodeBuffer(mock.Anything, mock.Anything).Return(syncmgr.NewSyncTask(), nil) - s.syncMgr.EXPECT().SyncData(mock.Anything, mock.Anything).Return(nil) + s.syncMgr.EXPECT().SyncData(mock.Anything, mock.Anything).Return(conc.Go[error](func() (error, error) { + return nil, nil + })) defer func() { s.wb.mut.Lock() defer s.wb.mut.Unlock()