diff --git a/internal/proxy/mock_test.go b/internal/proxy/mock_test.go index b8c9fb5074..2df773c26b 100644 --- a/internal/proxy/mock_test.go +++ b/internal/proxy/mock_test.go @@ -184,17 +184,6 @@ func (m *mockDmlTask) getChannels() ([]vChan, error) { return m.vchans, nil } -func (m *mockDmlTask) getPChanStats() (map[pChan]pChanStatistics, error) { - ret := make(map[pChan]pChanStatistics) - for _, pchan := range m.pchans { - ret[pchan] = pChanStatistics{ - minTs: m.ts, - maxTs: m.ts, - } - } - return ret, nil -} - func newMockDmlTask(ctx context.Context) *mockDmlTask { shardNum := 2 diff --git a/internal/proxy/task.go b/internal/proxy/task.go index ec49b57f8f..a03330fb53 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -103,7 +103,6 @@ type task interface { type dmlTask interface { task getChannels() ([]pChan, error) - getPChanStats() (map[pChan]pChanStatistics, error) } type BaseInsertTask = msgstream.InsertMsg diff --git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go index ca5dc92192..276a954926 100644 --- a/internal/proxy/task_delete.go +++ b/internal/proxy/task_delete.go @@ -77,26 +77,6 @@ func (dt *deleteTask) OnEnqueue() error { return nil } -func (dt *deleteTask) getPChanStats() (map[pChan]pChanStatistics, error) { - ret := make(map[pChan]pChanStatistics) - - channels, err := dt.getChannels() - if err != nil { - return ret, err - } - - beginTs := dt.BeginTs() - endTs := dt.EndTs() - - for _, channel := range channels { - ret[channel] = pChanStatistics{ - minTs: beginTs, - maxTs: endTs, - } - } - return ret, nil -} - func (dt *deleteTask) getChannels() ([]pChan, error) { collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.CollectionName) if err != nil { diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index fa41664f65..3b977ab2e3 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -71,26 +71,6 @@ func (it *insertTask) EndTs() Timestamp { return it.EndTimestamp } -func (it *insertTask) getPChanStats() (map[pChan]pChanStatistics, error) { - ret := make(map[pChan]pChanStatistics) - - channels, err := it.getChannels() - if err != nil { - return ret, err - } - - beginTs := it.BeginTs() - endTs := it.EndTs() - - for _, channel := range channels { - ret[channel] = pChanStatistics{ - minTs: beginTs, - maxTs: endTs, - } - } - return ret, nil -} - func (it *insertTask) getChannels() ([]pChan, error) { collID, err := globalMetaCache.GetCollectionID(it.ctx, it.CollectionName) if err != nil { diff --git a/internal/proxy/task_scheduler.go b/internal/proxy/task_scheduler.go index b599095654..03372a95f9 100644 --- a/internal/proxy/task_scheduler.go +++ b/internal/proxy/task_scheduler.go @@ -228,16 +228,22 @@ type dmTaskQueue struct { func (queue *dmTaskQueue) Enqueue(t task) error { queue.statsLock.Lock() defer queue.statsLock.Unlock() - err := queue.addPChanStats(t) + //1. preAdd will check whether provided task is valid or addable + //and get the current pChannels for this dmTask + pChannels, dmt, err := queue.preAddPChanStats(t) if err != nil { return err } + //2. enqueue dml task err = queue.baseTaskQueue.Enqueue(t) if err != nil { - queue.popPChanStats(t) return err } - + //3. if preAdd succeed, commit will use pChannels got previously when preAdding and will definitely succeed + queue.commitPChanStats(dmt, pChannels) + //there's indeed a possibility that the collection info cache was expired after preAddPChanStats + //but considering root coord knows everything about meta modification, invalid stats appended after the meta changed + //will be discarded by root coord and will not lead to inconsistent state return nil } @@ -258,38 +264,51 @@ func (queue *dmTaskQueue) PopActiveTask(taskID UniqueID) task { return t } -func (queue *dmTaskQueue) addPChanStats(t task) error { +func (queue *dmTaskQueue) preAddPChanStats(t task) ([]pChan, dmlTask, error) { if dmT, ok := t.(dmlTask); ok { - stats, err := dmT.getPChanStats() + channels, err := dmT.getChannels() if err != nil { - log.Warn("Proxy dmTaskQueue addPChanStats", zap.Any("tID", t.ID()), - zap.Any("stats", stats), zap.Error(err)) - return err + log.Warn("Proxy dmTaskQueue preAddPChanStats getChannels failed", zap.Any("tID", t.ID()), + zap.Error(err)) + return nil, nil, err + } + return channels, dmT, nil + } + return nil, nil, fmt.Errorf("proxy preAddPChanStats reflect to dmlTask failed, tID:%v", t.ID()) +} + +func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) { + //1. prepare new stat for all pChannels + newStats := make(map[pChan]pChanStatistics) + beginTs := dmt.BeginTs() + endTs := dmt.EndTs() + for _, channel := range pChannels { + newStats[channel] = pChanStatistics{ + minTs: beginTs, + maxTs: endTs, + } + } + //2. update stats for all pChannels + for cName, newStat := range newStats { + currentStat, ok := queue.pChanStatisticsInfos[cName] + if !ok { + currentStat = &pChanStatInfo{ + pChanStatistics: newStat, + tsSet: map[Timestamp]struct{}{ + newStat.minTs: {}, + }, + } + queue.pChanStatisticsInfos[cName] = currentStat + } else { + if currentStat.minTs > newStat.minTs { + queue.pChanStatisticsInfos[cName].minTs = newStat.minTs + } + if currentStat.maxTs < newStat.maxTs { + queue.pChanStatisticsInfos[cName].maxTs = newStat.maxTs + } + queue.pChanStatisticsInfos[cName].tsSet[currentStat.minTs] = struct{}{} } - for cName, stat := range stats { - info, ok := queue.pChanStatisticsInfos[cName] - if !ok { - info = &pChanStatInfo{ - pChanStatistics: stat, - tsSet: map[Timestamp]struct{}{ - stat.minTs: {}, - }, - } - queue.pChanStatisticsInfos[cName] = info - } else { - if info.minTs > stat.minTs { - queue.pChanStatisticsInfos[cName].minTs = stat.minTs - } - if info.maxTs < stat.maxTs { - queue.pChanStatisticsInfos[cName].maxTs = stat.maxTs - } - queue.pChanStatisticsInfos[cName].tsSet[info.minTs] = struct{}{} - } - } - } else { - return fmt.Errorf("proxy addUnissuedTask reflect to dmlTask failed, tID:%v", t.ID()) } - return nil } func (queue *dmTaskQueue) popPChanStats(t task) error { diff --git a/internal/proxy/task_scheduler_test.go b/internal/proxy/task_scheduler_test.go index 66d17be78e..3255b8bcff 100644 --- a/internal/proxy/task_scheduler_test.go +++ b/internal/proxy/task_scheduler_test.go @@ -196,11 +196,15 @@ func TestDmTaskQueue_Basic(t *testing.T) { assert.True(t, queue.utEmpty()) assert.False(t, queue.utFull()) + //test wrong task type + dqlTask := newDefaultMockDqlTask() + err = queue.Enqueue(dqlTask) + assert.NotNil(t, err) + st := newDefaultMockDmlTask() stID := st.ID() // no task in queue - unissuedTask = queue.FrontUnissuedTask() assert.Nil(t, unissuedTask)