diff --git a/internal/datanode/compaction/clustering_compactor.go b/internal/datanode/compaction/clustering_compactor.go index 93bc26ec8a..90deb75846 100644 --- a/internal/datanode/compaction/clustering_compactor.go +++ b/internal/datanode/compaction/clustering_compactor.go @@ -450,9 +450,7 @@ func (t *clusteringCompactionTask) mapping(ctx context.Context, func (t *clusteringCompactionTask) getBufferTotalUsedMemorySize() int64 { var totalBufferSize int64 = 0 for _, buffer := range t.clusterBuffers { - t.clusterBufferLocks.RLock(buffer.id) totalBufferSize = totalBufferSize + int64(buffer.writer.WrittenMemorySize()) + buffer.bufferMemorySize.Load() - t.clusterBufferLocks.RUnlock(buffer.id) } return totalBufferSize } @@ -585,30 +583,35 @@ func (t *clusteringCompactionTask) mappingSegment( remained++ if (remained+1)%100 == 0 { - t.clusterBufferLocks.RLock(clusterBuffer.id) - currentBufferWriterFull := clusterBuffer.writer.IsFull() - t.clusterBufferLocks.RUnlock(clusterBuffer.id) - currentBufferTotalMemorySize := t.getBufferTotalUsedMemorySize() - - currentSegmentNumRows := clusterBuffer.currentSegmentRowNum.Load() - if currentSegmentNumRows > t.plan.GetMaxSegmentRows() || currentBufferWriterFull { + if clusterBuffer.currentSegmentRowNum.Load() > t.plan.GetMaxSegmentRows() || clusterBuffer.writer.IsFull() { // reach segment/binlog max size - t.clusterBufferLocks.Lock(clusterBuffer.id) - writer := clusterBuffer.writer - pack, _ := t.refreshBufferWriterWithPack(clusterBuffer) - log.Debug("buffer need to flush", zap.Int("bufferID", clusterBuffer.id), - zap.Bool("pack", pack), - zap.Int64("current segment", writer.GetSegmentID()), - zap.Int64("current segment num rows", currentSegmentNumRows), - zap.Int64("writer num", writer.GetRowNum())) - t.clusterBufferLocks.Unlock(clusterBuffer.id) + flushWriterFunc := func() { + t.clusterBufferLocks.Lock(clusterBuffer.id) + currentSegmentNumRows := clusterBuffer.currentSegmentRowNum.Load() + // double-check the condition is still met + if currentSegmentNumRows > t.plan.GetMaxSegmentRows() || clusterBuffer.writer.IsFull() { + writer := clusterBuffer.writer + pack, _ := t.refreshBufferWriterWithPack(clusterBuffer) + log.Debug("buffer need to flush", zap.Int("bufferID", clusterBuffer.id), + zap.Bool("pack", pack), + zap.Int64("current segment", writer.GetSegmentID()), + zap.Int64("current segment num rows", currentSegmentNumRows), + zap.Int64("writer num", writer.GetRowNum())) - t.flushChan <- FlushSignal{ - writer: writer, - pack: pack, - id: clusterBuffer.id, + t.clusterBufferLocks.Unlock(clusterBuffer.id) + // release the lock before sending the signal, avoid long wait caused by a full channel. + t.flushChan <- FlushSignal{ + writer: writer, + pack: pack, + id: clusterBuffer.id, + } + return + } + // release the lock even if the conditions are no longer met. + t.clusterBufferLocks.Unlock(clusterBuffer.id) } + flushWriterFunc() } else if currentBufferTotalMemorySize > t.getMemoryBufferHighWatermark() && !t.hasSignal.Load() { // reach flushBinlog trigger threshold log.Debug("largest buffer need to flush", @@ -618,7 +621,7 @@ func (t *clusteringCompactionTask) mappingSegment( } // if the total buffer size is too large, block here, wait for memory release by flushBinlog - if currentBufferTotalMemorySize > t.getMemoryBufferBlockFlushThreshold() { + if t.getBufferTotalUsedMemorySize() > t.getMemoryBufferBlockFlushThreshold() { log.Debug("memory is already above the block watermark, pause writing", zap.Int64("currentBufferTotalMemorySize", currentBufferTotalMemorySize)) loop: diff --git a/internal/storage/serde.go b/internal/storage/serde.go index 499e32314b..d7e8ca2808 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -34,6 +34,7 @@ import ( "github.com/apache/arrow/go/v12/parquet/pqarrow" "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" + "go.uber.org/atomic" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" @@ -797,7 +798,7 @@ type SerializeWriter[T any] struct { buffer []T pos int - writtenMemorySize uint64 + writtenMemorySize atomic.Uint64 } func (sw *SerializeWriter[T]) Flush() error { @@ -816,7 +817,7 @@ func (sw *SerializeWriter[T]) Flush() error { return err } sw.pos = 0 - sw.writtenMemorySize += size + sw.writtenMemorySize.Add(size) return nil } @@ -835,7 +836,7 @@ func (sw *SerializeWriter[T]) Write(value T) error { } func (sw *SerializeWriter[T]) WrittenMemorySize() uint64 { - return sw.writtenMemorySize + return sw.writtenMemorySize.Load() } func (sw *SerializeWriter[T]) Close() error {