diff --git a/internal/proxy/channels_time_ticker.go b/internal/proxy/channels_time_ticker.go
index 0ebc39803b..9744f9e1fe 100644
--- a/internal/proxy/channels_time_ticker.go
+++ b/internal/proxy/channels_time_ticker.go
@@ -24,6 +24,7 @@ import (
 	"go.uber.org/zap"
 
 	"github.com/milvus-io/milvus/internal/log"
+	"github.com/milvus-io/milvus/internal/util/typeutil"
 )
 
 // ticker can update ts only when the minTs are greater than the ts of ticker, we can use maxTs to update current later
@@ -132,6 +133,11 @@ func (ticker *channelsTimeTickerImpl) tick() error {
 	}
 
 	for pchan, value := range stats {
+		if value.minTs == typeutil.ZeroTimestamp {
+			log.Warn("channelsTimeTickerImpl.tick, stats contain a physical channel whose minTs is zero",
+				zap.String("pchan", pchan))
+			continue
+		}
 		_, ok := ticker.currents[pchan]
 		if !ok {
 			ticker.minTsStatistics[pchan] = value.minTs - 1
diff --git a/internal/proxy/mock_test.go b/internal/proxy/mock_test.go
index 2258607b4f..bf85bedf45 100644
--- a/internal/proxy/mock_test.go
+++ b/internal/proxy/mock_test.go
@@ -65,10 +65,16 @@ func newMockTimestampAllocatorInterface() timestampAllocatorInterface {
 }
 
 type mockTsoAllocator struct {
+	mu        sync.Mutex
+	logicPart uint32
 }
 
 func (tso *mockTsoAllocator) AllocOne(ctx context.Context) (Timestamp, error) {
-	return Timestamp(time.Now().UnixNano()), nil
+	tso.mu.Lock()
+	defer tso.mu.Unlock()
+	tso.logicPart++
+	physical := uint64(time.Now().UnixMilli())
+	return (physical << 18) + uint64(tso.logicPart), nil
 }
 
 func newMockTsoAllocator() tsoAllocator {
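Note on the mock_test.go hunk above: AllocOne now hands out hybrid TSO-style timestamps, packing the millisecond wall clock into the high bits and a mutex-protected 18-bit logical counter into the low bits, so concurrent allocations are strictly increasing and carry the physical/logical layout the proxy code works with. The sketch below only illustrates that bit layout under those assumptions; composeTS and parseTS are made-up helper names for this example, not the proxy's real tsoutil functions.

```go
package main

import (
	"fmt"
	"time"
)

// logicalBits mirrors the 18-bit logical segment used by the mock allocator above.
const (
	logicalBits = 18
	logicalMask = (1 << logicalBits) - 1
)

// composeTS packs a millisecond physical time and a logical counter the same way
// mockTsoAllocator.AllocOne does: physical time in the high bits, counter in the low bits.
func composeTS(physicalMs int64, logical uint32) uint64 {
	return (uint64(physicalMs) << logicalBits) + uint64(logical)
}

// parseTS splits a timestamp back into its physical and logical parts.
func parseTS(ts uint64) (time.Time, uint32) {
	logical := uint32(ts & logicalMask)
	physicalMs := int64(ts >> logicalBits)
	return time.UnixMilli(physicalMs), logical
}

func main() {
	ts := composeTS(time.Now().UnixMilli(), 42)
	physical, logical := parseTS(ts)
	fmt.Println(physical, logical) // same wall-clock millisecond, logical part 42
}
```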
diff --git a/internal/proxy/task_scheduler.go b/internal/proxy/task_scheduler.go
index ba69215d8b..6ffe376e4b 100644
--- a/internal/proxy/task_scheduler.go
+++ b/internal/proxy/task_scheduler.go
@@ -219,13 +219,15 @@ type pChanStatInfo struct {
 
 type dmTaskQueue struct {
 	*baseTaskQueue
-	lock sync.Mutex
 
 	statsLock            sync.RWMutex
 	pChanStatisticsInfos map[pChan]*pChanStatInfo
 }
 
 func (queue *dmTaskQueue) Enqueue(t task) error {
+	// This statsLock has two functions:
+	// 1) Protect member pChanStatisticsInfos
+	// 2) Serialize the timestamp allocation for dml tasks
 	queue.statsLock.Lock()
 	defer queue.statsLock.Unlock()
 	//1. preAdd will check whether provided task is valid or addable
@@ -301,12 +303,12 @@ func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) {
 			queue.pChanStatisticsInfos[cName] = currentStat
 		} else {
 			if currentStat.minTs > newStat.minTs {
-				queue.pChanStatisticsInfos[cName].minTs = newStat.minTs
+				currentStat.minTs = newStat.minTs
 			}
 			if currentStat.maxTs < newStat.maxTs {
-				queue.pChanStatisticsInfos[cName].maxTs = newStat.maxTs
+				currentStat.maxTs = newStat.maxTs
 			}
-			queue.pChanStatisticsInfos[cName].tsSet[currentStat.minTs] = struct{}{}
+			currentStat.tsSet[newStat.minTs] = struct{}{}
 		}
 	}
 }
@@ -317,20 +319,21 @@ func (queue *dmTaskQueue) popPChanStats(t task) error {
 	if err != nil {
 		return err
 	}
+	taskTs := t.BeginTs()
 	for _, cName := range channels {
 		info, ok := queue.pChanStatisticsInfos[cName]
 		if ok {
-			delete(queue.pChanStatisticsInfos[cName].tsSet, info.minTs)
-			if len(queue.pChanStatisticsInfos[cName].tsSet) <= 0 {
+			delete(info.tsSet, taskTs)
+			if len(info.tsSet) <= 0 {
 				delete(queue.pChanStatisticsInfos, cName)
-			} else if queue.pChanStatisticsInfos[cName].minTs == info.minTs {
-				minTs := info.maxTs
-				for ts := range queue.pChanStatisticsInfos[cName].tsSet {
-					if ts < minTs {
-						minTs = ts
+			} else {
+				newMinTs := info.maxTs
+				for ts := range info.tsSet {
+					if newMinTs > ts {
+						newMinTs = ts
 					}
 				}
-				queue.pChanStatisticsInfos[cName].minTs = minTs
+				info.minTs = newMinTs
 			}
 		}
 	}
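The two task_scheduler.go hunks above rework the per-pChannel statistics bookkeeping: commitPChanStats now mutates the cached *pChanStatInfo directly and records the new task's begin timestamp in tsSet, while popPChanStats removes exactly the finished task's BeginTs (instead of whatever minTs happened to be) and recomputes the minimum from the timestamps still in flight, so the reported minTs does not move backwards. Below is a minimal standalone model of that pattern; chanStat, statsTable, commit, and pop are invented names for illustration, not the proxy's actual types.

```go
package main

import "fmt"

// chanStat is a simplified stand-in for pChanStatInfo: a [minTs, maxTs] window
// plus the set of begin timestamps of tasks that are still in flight.
type chanStat struct {
	minTs, maxTs uint64
	tsSet        map[uint64]struct{}
}

type statsTable map[string]*chanStat

// commit mirrors commitPChanStats: record the task's begin timestamp and
// widen the channel's [minTs, maxTs] window if needed.
func (s statsTable) commit(ch string, ts uint64) {
	st, ok := s[ch]
	if !ok {
		st = &chanStat{minTs: ts, maxTs: ts, tsSet: make(map[uint64]struct{})}
		s[ch] = st
	}
	if ts < st.minTs {
		st.minTs = ts
	}
	if ts > st.maxTs {
		st.maxTs = ts
	}
	st.tsSet[ts] = struct{}{}
}

// pop mirrors popPChanStats: drop exactly the finished task's timestamp and,
// if anything is still in flight, recompute minTs from the remaining set.
func (s statsTable) pop(ch string, ts uint64) {
	st, ok := s[ch]
	if !ok {
		return
	}
	delete(st.tsSet, ts)
	if len(st.tsSet) == 0 {
		delete(s, ch)
		return
	}
	newMin := st.maxTs
	for t := range st.tsSet {
		if t < newMin {
			newMin = t
		}
	}
	st.minTs = newMin
}

func main() {
	s := statsTable{}
	s.commit("by-dev-dml_0", 100)
	s.commit("by-dev-dml_0", 105)
	s.pop("by-dev-dml_0", 100)
	fmt.Println(s["by-dev-dml_0"].minTs) // 105: minTs advanced, never went backwards
}
```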
"curInfo.minTs:", curInfo.minTs) + fmt.Println("stat.maxTs", stat.maxTs, " ", "curInfo.minTs:", curInfo.maxTs) + assert.True(t, stat.minTs >= curInfo.minTs) + curInfo.minTs = stat.minTs + assert.True(t, stat.maxTs >= curInfo.maxTs) + curInfo.maxTs = stat.maxTs + } + } + } + } + } + } + wgSchedule.Add(1) + go schedule() + + var wg sync.WaitGroup + wg.Add(insertNum) + for i := 0; i < insertNum; i++ { + go func() { + defer wg.Done() + time.Sleep(time.Millisecond) + st := newDefaultMockDmlTask() + vChannels := make([]string, 2) + vChannels[0] = prefix + "_1" + vChannels[1] = prefix + "_2" + st.vchans = vChannels + st.pchans = vChannels + err := queue.Enqueue(st) + assert.NoError(t, err) + }() + } + wg.Wait() + //time.Sleep(time.Millisecond*100) + needLoop := true + for needLoop { + processCountMut.RLock() + needLoop = processCount != 0 + processCountMut.RUnlock() + } + processCancel() + processWg.Wait() + + scheduleCancel() + wgSchedule.Wait() + + stats, err := queue.getPChanStatsInfo() + assert.Nil(t, err) + assert.Zero(t, len(stats)) +} + func TestDqTaskQueue(t *testing.T) { var err error