fix: Record the nodeID before assigning tasks (#36371)

issue: #33744

---------

Signed-off-by: Cai Zhang <cai.zhang@zilliz.com>
This commit is contained in:
cai.zhang 2024-09-28 17:21:15 +08:00 committed by GitHub
parent 2adca8b754
commit 7bf40694fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 96 additions and 62 deletions

View File

@ -111,7 +111,7 @@ func (m *analyzeMeta) DropAnalyzeTask(taskID int64) error {
return nil return nil
} }
func (m *analyzeMeta) UpdateVersion(taskID int64) error { func (m *analyzeMeta) UpdateVersion(taskID int64, nodeID int64) error {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
@ -122,11 +122,13 @@ func (m *analyzeMeta) UpdateVersion(taskID int64) error {
cloneT := proto.Clone(t).(*indexpb.AnalyzeTask) cloneT := proto.Clone(t).(*indexpb.AnalyzeTask)
cloneT.Version++ cloneT.Version++
log.Info("update task version", zap.Int64("taskID", taskID), zap.Int64("newVersion", cloneT.Version)) cloneT.NodeID = nodeID
log.Info("update task version", zap.Int64("taskID", taskID), zap.Int64("newVersion", cloneT.Version),
zap.Int64("nodeID", nodeID))
return m.saveTask(cloneT) return m.saveTask(cloneT)
} }
func (m *analyzeMeta) BuildingTask(taskID, nodeID int64) error { func (m *analyzeMeta) BuildingTask(taskID int64) error {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
@ -136,9 +138,8 @@ func (m *analyzeMeta) BuildingTask(taskID, nodeID int64) error {
} }
cloneT := proto.Clone(t).(*indexpb.AnalyzeTask) cloneT := proto.Clone(t).(*indexpb.AnalyzeTask)
cloneT.NodeID = nodeID
cloneT.State = indexpb.JobState_JobStateInProgress cloneT.State = indexpb.JobState_JobStateInProgress
log.Info("task will be building", zap.Int64("taskID", taskID), zap.Int64("nodeID", nodeID)) log.Info("task will be building", zap.Int64("taskID", taskID))
return m.saveTask(cloneT) return m.saveTask(cloneT)
} }

View File

@ -142,13 +142,13 @@ func (s *AnalyzeMetaSuite) Test_AnalyzeMeta() {
}) })
s.Run("UpdateVersion", func() { s.Run("UpdateVersion", func() {
err := am.UpdateVersion(1) err := am.UpdateVersion(1, 1)
s.NoError(err) s.NoError(err)
s.Equal(int64(1), am.GetTask(1).Version) s.Equal(int64(1), am.GetTask(1).Version)
}) })
s.Run("BuildingTask", func() { s.Run("BuildingTask", func() {
err := am.BuildingTask(1, 1) err := am.BuildingTask(1)
s.NoError(err) s.NoError(err)
s.Equal(indexpb.JobState_JobStateInProgress, am.GetTask(1).State) s.Equal(indexpb.JobState_JobStateInProgress, am.GetTask(1).State)
}) })
@ -218,19 +218,19 @@ func (s *AnalyzeMetaSuite) Test_failCase() {
}) })
s.Run("UpdateVersion", func() { s.Run("UpdateVersion", func() {
err := am.UpdateVersion(777) err := am.UpdateVersion(777, 1)
s.Error(err) s.Error(err)
err = am.UpdateVersion(1) err = am.UpdateVersion(1, 1)
s.Error(err) s.Error(err)
s.Equal(int64(0), am.GetTask(1).Version) s.Equal(int64(0), am.GetTask(1).Version)
}) })
s.Run("BuildingTask", func() { s.Run("BuildingTask", func() {
err := am.BuildingTask(777, 1) err := am.BuildingTask(777)
s.Error(err) s.Error(err)
err = am.BuildingTask(1, 1) err = am.BuildingTask(1)
s.Error(err) s.Error(err)
s.Equal(int64(0), am.GetTask(1).NodeID) s.Equal(int64(0), am.GetTask(1).NodeID)
s.Equal(indexpb.JobState_JobStateInit, am.GetTask(1).State) s.Equal(indexpb.JobState_JobStateInit, am.GetTask(1).State)

View File

@ -697,11 +697,11 @@ func (m *indexMeta) IsIndexExist(collID, indexID UniqueID) bool {
} }
// UpdateVersion updates the version and nodeID of the index meta. Each time the task is built, the version is incremented once. // UpdateVersion updates the version and nodeID of the index meta. Each time the task is built, the version is incremented once.
func (m *indexMeta) UpdateVersion(buildID UniqueID) error { func (m *indexMeta) UpdateVersion(buildID, nodeID UniqueID) error {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
log.Debug("IndexCoord metaTable UpdateVersion receive", zap.Int64("buildID", buildID)) log.Info("IndexCoord metaTable UpdateVersion receive", zap.Int64("buildID", buildID), zap.Int64("nodeID", nodeID))
segIdx, ok := m.buildID2SegmentIndex[buildID] segIdx, ok := m.buildID2SegmentIndex[buildID]
if !ok { if !ok {
return fmt.Errorf("there is no index with buildID: %d", buildID) return fmt.Errorf("there is no index with buildID: %d", buildID)
@ -709,6 +709,7 @@ func (m *indexMeta) UpdateVersion(buildID UniqueID) error {
updateFunc := func(segIdx *model.SegmentIndex) error { updateFunc := func(segIdx *model.SegmentIndex) error {
segIdx.IndexVersion++ segIdx.IndexVersion++
segIdx.NodeID = nodeID
return m.alterSegmentIndexes([]*model.SegmentIndex{segIdx}) return m.alterSegmentIndexes([]*model.SegmentIndex{segIdx})
} }
@ -771,7 +772,7 @@ func (m *indexMeta) DeleteTask(buildID int64) error {
} }
// BuildIndex sets the index state to InProgress. It means an IndexNode is building the index. // BuildIndex sets the index state to InProgress. It means an IndexNode is building the index.
func (m *indexMeta) BuildIndex(buildID, nodeID UniqueID) error { func (m *indexMeta) BuildIndex(buildID UniqueID) error {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
@ -781,7 +782,6 @@ func (m *indexMeta) BuildIndex(buildID, nodeID UniqueID) error {
} }
updateFunc := func(segIdx *model.SegmentIndex) error { updateFunc := func(segIdx *model.SegmentIndex) error {
segIdx.NodeID = nodeID
segIdx.IndexState = commonpb.IndexState_InProgress segIdx.IndexState = commonpb.IndexState_InProgress
err := m.alterSegmentIndexes([]*model.SegmentIndex{segIdx}) err := m.alterSegmentIndexes([]*model.SegmentIndex{segIdx})

View File

@ -1247,18 +1247,18 @@ func TestMeta_UpdateVersion(t *testing.T) {
).Return(errors.New("fail")) ).Return(errors.New("fail"))
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
err := m.UpdateVersion(buildID) err := m.UpdateVersion(buildID, nodeID)
assert.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("fail", func(t *testing.T) { t.Run("fail", func(t *testing.T) {
m.catalog = ec m.catalog = ec
err := m.UpdateVersion(buildID) err := m.UpdateVersion(buildID, nodeID)
assert.Error(t, err) assert.Error(t, err)
}) })
t.Run("not exist", func(t *testing.T) { t.Run("not exist", func(t *testing.T) {
err := m.UpdateVersion(buildID + 1) err := m.UpdateVersion(buildID+1, nodeID)
assert.Error(t, err) assert.Error(t, err)
}) })
} }
@ -1315,18 +1315,18 @@ func TestMeta_BuildIndex(t *testing.T) {
).Return(errors.New("fail")) ).Return(errors.New("fail"))
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
err := m.BuildIndex(buildID, nodeID) err := m.BuildIndex(buildID)
assert.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("fail", func(t *testing.T) { t.Run("fail", func(t *testing.T) {
m.catalog = ec m.catalog = ec
err := m.BuildIndex(buildID, nodeID) err := m.BuildIndex(buildID)
assert.Error(t, err) assert.Error(t, err)
}) })
t.Run("not exist", func(t *testing.T) { t.Run("not exist", func(t *testing.T) {
err := m.BuildIndex(buildID+1, nodeID) err := m.BuildIndex(buildID + 1)
assert.Error(t, err) assert.Error(t, err)
}) })
} }

View File

@ -155,7 +155,7 @@ func (stm *statsTaskMeta) DropStatsTask(taskID int64) error {
return nil return nil
} }
func (stm *statsTaskMeta) UpdateVersion(taskID int64) error { func (stm *statsTaskMeta) UpdateVersion(taskID, nodeID int64) error {
stm.Lock() stm.Lock()
defer stm.Unlock() defer stm.Unlock()
@ -166,23 +166,25 @@ func (stm *statsTaskMeta) UpdateVersion(taskID int64) error {
cloneT := proto.Clone(t).(*indexpb.StatsTask) cloneT := proto.Clone(t).(*indexpb.StatsTask)
cloneT.Version++ cloneT.Version++
cloneT.NodeID = nodeID
if err := stm.catalog.SaveStatsTask(stm.ctx, cloneT); err != nil { if err := stm.catalog.SaveStatsTask(stm.ctx, cloneT); err != nil {
log.Warn("update stats task version failed", log.Warn("update stats task version failed",
zap.Int64("taskID", t.GetTaskID()), zap.Int64("taskID", t.GetTaskID()),
zap.Int64("segmentID", t.GetSegmentID()), zap.Int64("segmentID", t.GetSegmentID()),
zap.Int64("nodeID", nodeID),
zap.Error(err)) zap.Error(err))
return err return err
} }
stm.tasks[t.TaskID] = cloneT stm.tasks[t.TaskID] = cloneT
stm.updateMetrics() stm.updateMetrics()
log.Info("update stats task version success", zap.Int64("taskID", taskID), log.Info("update stats task version success", zap.Int64("taskID", taskID), zap.Int64("nodeID", nodeID),
zap.Int64("newVersion", cloneT.GetVersion())) zap.Int64("newVersion", cloneT.GetVersion()))
return nil return nil
} }
func (stm *statsTaskMeta) UpdateBuildingTask(taskID, nodeID int64) error { func (stm *statsTaskMeta) UpdateBuildingTask(taskID int64) error {
stm.Lock() stm.Lock()
defer stm.Unlock() defer stm.Unlock()
@ -192,7 +194,6 @@ func (stm *statsTaskMeta) UpdateBuildingTask(taskID, nodeID int64) error {
} }
cloneT := proto.Clone(t).(*indexpb.StatsTask) cloneT := proto.Clone(t).(*indexpb.StatsTask)
cloneT.NodeID = nodeID
cloneT.State = indexpb.JobState_JobStateInProgress cloneT.State = indexpb.JobState_JobStateInProgress
if err := stm.catalog.SaveStatsTask(stm.ctx, cloneT); err != nil { if err := stm.catalog.SaveStatsTask(stm.ctx, cloneT); err != nil {
@ -206,7 +207,7 @@ func (stm *statsTaskMeta) UpdateBuildingTask(taskID, nodeID int64) error {
stm.tasks[t.TaskID] = cloneT stm.tasks[t.TaskID] = cloneT
stm.updateMetrics() stm.updateMetrics()
log.Info("update building stats task success", zap.Int64("taskID", taskID), zap.Int64("nodeID", nodeID)) log.Info("update building stats task success", zap.Int64("taskID", taskID))
return nil return nil
} }

View File

@ -131,7 +131,7 @@ func (s *statsTaskMetaSuite) Test_Method() {
s.Run("normal case", func() { s.Run("normal case", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil).Once() catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil).Once()
s.NoError(m.UpdateVersion(1)) s.NoError(m.UpdateVersion(1, 1180))
task, ok := m.tasks[1] task, ok := m.tasks[1]
s.True(ok) s.True(ok)
s.Equal(int64(1), task.GetVersion()) s.Equal(int64(1), task.GetVersion())
@ -141,13 +141,13 @@ func (s *statsTaskMetaSuite) Test_Method() {
_, ok := m.tasks[100] _, ok := m.tasks[100]
s.False(ok) s.False(ok)
s.Error(m.UpdateVersion(100)) s.Error(m.UpdateVersion(100, 1180))
}) })
s.Run("failed case", func() { s.Run("failed case", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once() catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once()
s.Error(m.UpdateVersion(1)) s.Error(m.UpdateVersion(1, 1180))
task, ok := m.tasks[1] task, ok := m.tasks[1]
s.True(ok) s.True(ok)
// still 1 // still 1
@ -159,17 +159,17 @@ func (s *statsTaskMetaSuite) Test_Method() {
s.Run("failed case", func() { s.Run("failed case", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once() catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once()
s.Error(m.UpdateBuildingTask(1, 1180)) s.Error(m.UpdateBuildingTask(1))
task, ok := m.tasks[1] task, ok := m.tasks[1]
s.True(ok) s.True(ok)
s.Equal(indexpb.JobState_JobStateInit, task.GetState()) s.Equal(indexpb.JobState_JobStateInit, task.GetState())
s.Equal(int64(0), task.GetNodeID()) s.Equal(int64(1180), task.GetNodeID())
}) })
s.Run("normal case", func() { s.Run("normal case", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil).Once() catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil).Once()
s.NoError(m.UpdateBuildingTask(1, 1180)) s.NoError(m.UpdateBuildingTask(1))
task, ok := m.tasks[1] task, ok := m.tasks[1]
s.True(ok) s.True(ok)
s.Equal(indexpb.JobState_JobStateInProgress, task.GetState()) s.Equal(indexpb.JobState_JobStateInProgress, task.GetState())
@ -180,7 +180,7 @@ func (s *statsTaskMetaSuite) Test_Method() {
_, ok := m.tasks[100] _, ok := m.tasks[100]
s.False(ok) s.False(ok)
s.Error(m.UpdateBuildingTask(100, 1180)) s.Error(m.UpdateBuildingTask(100))
}) })
}) })

View File

@ -118,18 +118,21 @@ func (at *analyzeTask) GetFailReason() string {
return at.taskInfo.GetFailReason() return at.taskInfo.GetFailReason()
} }
func (at *analyzeTask) UpdateVersion(ctx context.Context, meta *meta) error { func (at *analyzeTask) UpdateVersion(ctx context.Context, nodeID int64, meta *meta) error {
return meta.analyzeMeta.UpdateVersion(at.GetTaskID()) if err := meta.analyzeMeta.UpdateVersion(at.GetTaskID(), nodeID); err != nil {
}
func (at *analyzeTask) UpdateMetaBuildingState(nodeID int64, meta *meta) error {
if err := meta.analyzeMeta.BuildingTask(at.GetTaskID(), nodeID); err != nil {
return err return err
} }
at.nodeID = nodeID at.nodeID = nodeID
return nil return nil
} }
func (at *analyzeTask) UpdateMetaBuildingState(meta *meta) error {
if err := meta.analyzeMeta.BuildingTask(at.GetTaskID()); err != nil {
return err
}
return nil
}
func (at *analyzeTask) PreCheck(ctx context.Context, dependency *taskScheduler) bool { func (at *analyzeTask) PreCheck(ctx context.Context, dependency *taskScheduler) bool {
t := dependency.meta.analyzeMeta.GetTask(at.GetTaskID()) t := dependency.meta.analyzeMeta.GetTask(at.GetTaskID())
if t == nil { if t == nil {

View File

@ -118,13 +118,16 @@ func (it *indexBuildTask) GetFailReason() string {
return it.taskInfo.FailReason return it.taskInfo.FailReason
} }
func (it *indexBuildTask) UpdateVersion(ctx context.Context, meta *meta) error { func (it *indexBuildTask) UpdateVersion(ctx context.Context, nodeID int64, meta *meta) error {
return meta.indexMeta.UpdateVersion(it.taskID) if err := meta.indexMeta.UpdateVersion(it.taskID, nodeID); err != nil {
return err
}
it.nodeID = nodeID
return nil
} }
func (it *indexBuildTask) UpdateMetaBuildingState(nodeID int64, meta *meta) error { func (it *indexBuildTask) UpdateMetaBuildingState(meta *meta) error {
it.nodeID = nodeID return meta.indexMeta.BuildIndex(it.taskID)
return meta.indexMeta.BuildIndex(it.taskID, nodeID)
} }
func (it *indexBuildTask) PreCheck(ctx context.Context, dependency *taskScheduler) bool { func (it *indexBuildTask) PreCheck(ctx context.Context, dependency *taskScheduler) bool {

View File

@ -384,7 +384,7 @@ func (s *taskScheduler) processInit(task Task) bool {
log.Ctx(s.ctx).Info("pick client success", zap.Int64("taskID", task.GetTaskID()), zap.Int64("nodeID", nodeID)) log.Ctx(s.ctx).Info("pick client success", zap.Int64("taskID", task.GetTaskID()), zap.Int64("nodeID", nodeID))
// 2. update version // 2. update version
if err := task.UpdateVersion(s.ctx, s.meta); err != nil { if err := task.UpdateVersion(s.ctx, nodeID, s.meta); err != nil {
log.Ctx(s.ctx).Warn("update task version failed", zap.Int64("taskID", task.GetTaskID()), zap.Error(err)) log.Ctx(s.ctx).Warn("update task version failed", zap.Int64("taskID", task.GetTaskID()), zap.Error(err))
return false return false
} }
@ -402,7 +402,7 @@ func (s *taskScheduler) processInit(task Task) bool {
log.Ctx(s.ctx).Info("assign task to client success", zap.Int64("taskID", task.GetTaskID()), zap.Int64("nodeID", nodeID)) log.Ctx(s.ctx).Info("assign task to client success", zap.Int64("taskID", task.GetTaskID()), zap.Int64("nodeID", nodeID))
// 4. update meta state // 4. update meta state
if err := task.UpdateMetaBuildingState(nodeID, s.meta); err != nil { if err := task.UpdateMetaBuildingState(s.meta); err != nil {
log.Ctx(s.ctx).Warn("update meta building state failed", zap.Int64("taskID", task.GetTaskID()), zap.Error(err)) log.Ctx(s.ctx).Warn("update meta building state failed", zap.Int64("taskID", task.GetTaskID()), zap.Error(err))
task.SetState(indexpb.JobState_JobStateRetry, "update meta building state failed") task.SetState(indexpb.JobState_JobStateRetry, "update meta building state failed")
return false return false

View File

@ -1256,6 +1256,7 @@ func (s *taskSchedulerSuite) Test_analyzeTaskFailCase() {
func (s *taskSchedulerSuite) Test_indexTaskFailCase() { func (s *taskSchedulerSuite) Test_indexTaskFailCase() {
s.Run("HNSW", func() { s.Run("HNSW", func() {
ctx := context.Background() ctx := context.Background()
indexNodeTasks := make(map[int64]int)
catalog := catalogmocks.NewDataCoordCatalog(s.T()) catalog := catalogmocks.NewDataCoordCatalog(s.T())
in := mocks.NewMockIndexNodeClient(s.T()) in := mocks.NewMockIndexNodeClient(s.T())
@ -1353,10 +1354,19 @@ func (s *taskSchedulerSuite) Test_indexTaskFailCase() {
// assign failed --> retry // assign failed --> retry
workerManager.EXPECT().PickClient().Return(s.nodeID, in).Once() workerManager.EXPECT().PickClient().Return(s.nodeID, in).Once()
catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil).Once() catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil).Once()
in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")).Once() in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *workerpb.CreateJobV2Request, option ...grpc.CallOption) (*commonpb.Status, error) {
indexNodeTasks[request.GetTaskID()]++
return nil, errors.New("mock error")
}).Once()
// retry --> init // retry --> init
workerManager.EXPECT().GetClientByID(mock.Anything).Return(nil, false).Once() workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once()
in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *workerpb.DropJobsV2Request, option ...grpc.CallOption) (*commonpb.Status, error) {
for _, taskID := range request.GetTaskIDs() {
indexNodeTasks[taskID]--
}
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
}).Once()
// init --> inProgress // init --> inProgress
workerManager.EXPECT().PickClient().Return(s.nodeID, in).Once() workerManager.EXPECT().PickClient().Return(s.nodeID, in).Once()
@ -1370,7 +1380,10 @@ func (s *taskSchedulerSuite) Test_indexTaskFailCase() {
}, },
}, },
}, nil).Once() }, nil).Once()
in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil).Once() in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *workerpb.CreateJobV2Request, option ...grpc.CallOption) (*commonpb.Status, error) {
indexNodeTasks[request.GetTaskID()]++
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
}).Once()
// inProgress --> Finished // inProgress --> Finished
workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once() workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once()
@ -1393,7 +1406,13 @@ func (s *taskSchedulerSuite) Test_indexTaskFailCase() {
// finished --> done // finished --> done
catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil).Once() catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil).Once()
workerManager.EXPECT().GetClientByID(mock.Anything).Return(nil, false).Once() workerManager.EXPECT().GetClientByID(mock.Anything).Return(in, true).Once()
in.EXPECT().DropJobsV2(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *workerpb.DropJobsV2Request, option ...grpc.CallOption) (*commonpb.Status, error) {
for _, taskID := range request.GetTaskIDs() {
indexNodeTasks[taskID]--
}
return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
}).Once()
for { for {
scheduler.RLock() scheduler.RLock()
@ -1411,6 +1430,10 @@ func (s *taskSchedulerSuite) Test_indexTaskFailCase() {
indexJob, exist := mt.indexMeta.GetIndexJob(buildID) indexJob, exist := mt.indexMeta.GetIndexJob(buildID)
s.True(exist) s.True(exist)
s.Equal(commonpb.IndexState_Finished, indexJob.IndexState) s.Equal(commonpb.IndexState_Finished, indexJob.IndexState)
for _, v := range indexNodeTasks {
s.Zero(v)
}
}) })
} }

View File

@ -127,7 +127,7 @@ func (st *statsTask) GetFailReason() string {
return st.taskInfo.GetFailReason() return st.taskInfo.GetFailReason()
} }
func (st *statsTask) UpdateVersion(ctx context.Context, meta *meta) error { func (st *statsTask) UpdateVersion(ctx context.Context, nodeID int64, meta *meta) error {
// mark compacting // mark compacting
if exist, canDo := meta.CheckAndSetSegmentsCompacting([]UniqueID{st.segmentID}); !exist || !canDo { if exist, canDo := meta.CheckAndSetSegmentsCompacting([]UniqueID{st.segmentID}); !exist || !canDo {
log.Warn("segment is not exist or is compacting, skip stats", log.Warn("segment is not exist or is compacting, skip stats",
@ -136,12 +136,15 @@ func (st *statsTask) UpdateVersion(ctx context.Context, meta *meta) error {
return fmt.Errorf("mark segment compacting failed, isCompacting: %v", !canDo) return fmt.Errorf("mark segment compacting failed, isCompacting: %v", !canDo)
} }
return meta.statsTaskMeta.UpdateVersion(st.taskID) if err := meta.statsTaskMeta.UpdateVersion(st.taskID, nodeID); err != nil {
return err
}
st.nodeID = nodeID
return nil
} }
func (st *statsTask) UpdateMetaBuildingState(nodeID int64, meta *meta) error { func (st *statsTask) UpdateMetaBuildingState(meta *meta) error {
st.nodeID = nodeID return meta.statsTaskMeta.UpdateBuildingTask(st.taskID)
return meta.statsTaskMeta.UpdateBuildingTask(st.taskID, nodeID)
} }
func (st *statsTask) PreCheck(ctx context.Context, dependency *taskScheduler) bool { func (st *statsTask) PreCheck(ctx context.Context, dependency *taskScheduler) bool {

View File

@ -163,21 +163,21 @@ func (s *statsTaskSuite) TestTaskStats_PreCheck() {
s.Run("segment is compacting", func() { s.Run("segment is compacting", func() {
s.mt.segments.segments[s.segID].isCompacting = true s.mt.segments.segments[s.segID].isCompacting = true
s.Error(st.UpdateVersion(context.Background(), s.mt)) s.Error(st.UpdateVersion(context.Background(), 1, s.mt))
}) })
s.Run("normal case", func() { s.Run("normal case", func() {
s.mt.segments.segments[s.segID].isCompacting = false s.mt.segments.segments[s.segID].isCompacting = false
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil).Once() catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil).Once()
s.NoError(st.UpdateVersion(context.Background(), s.mt)) s.NoError(st.UpdateVersion(context.Background(), 1, s.mt))
}) })
s.Run("failed case", func() { s.Run("failed case", func() {
s.mt.segments.segments[s.segID].isCompacting = false s.mt.segments.segments[s.segID].isCompacting = false
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(fmt.Errorf("error")).Once() catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(fmt.Errorf("error")).Once()
s.Error(st.UpdateVersion(context.Background(), s.mt)) s.Error(st.UpdateVersion(context.Background(), 1, s.mt))
}) })
}) })
@ -187,12 +187,12 @@ func (s *statsTaskSuite) TestTaskStats_PreCheck() {
s.Run("normal case", func() { s.Run("normal case", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil).Once() catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil).Once()
s.NoError(st.UpdateMetaBuildingState(1, s.mt)) s.NoError(st.UpdateMetaBuildingState(s.mt))
}) })
s.Run("update error", func() { s.Run("update error", func() {
catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(fmt.Errorf("error")).Once() catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(fmt.Errorf("error")).Once()
s.Error(st.UpdateMetaBuildingState(1, s.mt)) s.Error(st.UpdateMetaBuildingState(s.mt))
}) })
}) })

View File

@ -33,8 +33,8 @@ type Task interface {
SetState(state indexpb.JobState, failReason string) SetState(state indexpb.JobState, failReason string)
GetState() indexpb.JobState GetState() indexpb.JobState
GetFailReason() string GetFailReason() string
UpdateVersion(ctx context.Context, meta *meta) error UpdateVersion(ctx context.Context, nodeID int64, meta *meta) error
UpdateMetaBuildingState(nodeID int64, meta *meta) error UpdateMetaBuildingState(meta *meta) error
AssignTask(ctx context.Context, client types.IndexNodeClient) bool AssignTask(ctx context.Context, client types.IndexNodeClient) bool
QueryResult(ctx context.Context, client types.IndexNodeClient) QueryResult(ctx context.Context, client types.IndexNodeClient)
DropTaskOnWorker(ctx context.Context, client types.IndexNodeClient) bool DropTaskOnWorker(ctx context.Context, client types.IndexNodeClient) bool