diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index 4943ae9886..918a2e3a62 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -3312,43 +3312,96 @@ func TestDataCoord_UnsetIsImportingState(t *testing.T) { func TestDataCoordServer_UpdateChannelCheckpoint(t *testing.T) { mockVChannel := "fake-by-dev-rootcoord-dml-1-testchannelcp-v0" - mockPChannel := "fake-by-dev-rootcoord-dml-1" - t.Run("UpdateChannelCheckpoint", func(t *testing.T) { + t.Run("UpdateChannelCheckpoint_Success", func(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) + datanodeID := int64(1) + channelManager := NewMockChannelManager(t) + channelManager.EXPECT().Match(datanodeID, mockVChannel).Return(true) + + svr.channelManager = channelManager req := &datapb.UpdateChannelCheckpointRequest{ Base: &commonpb.MsgBase{ - SourceID: paramtable.GetNodeID(), + SourceID: datanodeID, }, VChannel: mockVChannel, Position: &msgpb.MsgPosition{ - ChannelName: mockPChannel, + ChannelName: mockVChannel, Timestamp: 1000, MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0}, }, } resp, err := svr.UpdateChannelCheckpoint(context.TODO(), req) - assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.ErrorCode) + assert.NoError(t, merr.CheckRPCCall(resp, err)) + + cp := svr.meta.GetChannelCheckpoint(mockVChannel) + assert.NotNil(t, cp) + svr.meta.DropChannelCheckpoint(mockVChannel) req = &datapb.UpdateChannelCheckpointRequest{ Base: &commonpb.MsgBase{ - SourceID: paramtable.GetNodeID(), + SourceID: datanodeID, }, VChannel: mockVChannel, ChannelCheckpoints: []*msgpb.MsgPosition{{ - ChannelName: mockPChannel, + ChannelName: mockVChannel, Timestamp: 1000, MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0}, }}, } resp, err = svr.UpdateChannelCheckpoint(context.TODO(), req) - assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.ErrorCode) + assert.NoError(t, merr.CheckRPCCall(resp, err)) + cp = svr.meta.GetChannelCheckpoint(mockVChannel) + assert.NotNil(t, cp) + }) + + t.Run("UpdateChannelCheckpoint_NodeNotMatch", func(t *testing.T) { + svr := newTestServer(t, nil) + defer closeTestServer(t, svr) + + datanodeID := int64(1) + channelManager := NewMockChannelManager(t) + channelManager.EXPECT().Match(datanodeID, mockVChannel).Return(false) + + svr.channelManager = channelManager + req := &datapb.UpdateChannelCheckpointRequest{ + Base: &commonpb.MsgBase{ + SourceID: datanodeID, + }, + VChannel: mockVChannel, + Position: &msgpb.MsgPosition{ + ChannelName: mockVChannel, + Timestamp: 1000, + MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + }, + } + + resp, err := svr.UpdateChannelCheckpoint(context.TODO(), req) + assert.Error(t, merr.CheckRPCCall(resp, err)) + assert.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrChannelNotFound) + cp := svr.meta.GetChannelCheckpoint(mockVChannel) + assert.Nil(t, cp) + + req = &datapb.UpdateChannelCheckpointRequest{ + Base: &commonpb.MsgBase{ + SourceID: datanodeID, + }, + VChannel: mockVChannel, + ChannelCheckpoints: []*msgpb.MsgPosition{{ + ChannelName: mockVChannel, + Timestamp: 1000, + MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + }}, + } + + resp, err = svr.UpdateChannelCheckpoint(context.TODO(), req) + assert.NoError(t, merr.CheckRPCCall(resp, err)) + cp = svr.meta.GetChannelCheckpoint(mockVChannel) + assert.Nil(t, cp) }) } diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index af413fca34..7ca58f1fd2 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -1465,8 +1465,14 @@ func (s *Server) UpdateChannelCheckpoint(ctx context.Context, req *datapb.Update return merr.Status(err), nil } + nodeID := req.GetBase().GetSourceID() // For compatibility with old client if req.GetVChannel() != "" && req.GetPosition() != nil { + channel := req.GetVChannel() + if !s.channelManager.Match(nodeID, channel) { + log.Warn("node is not matched with channel", zap.String("channel", channel), zap.Int64("nodeID", nodeID)) + return merr.Status(merr.WrapErrChannelNotFound(channel, fmt.Sprintf("from node %d", nodeID))), nil + } err := s.meta.UpdateChannelCheckpoint(req.GetVChannel(), req.GetPosition()) if err != nil { log.Warn("failed to UpdateChannelCheckpoint", zap.String("vChannel", req.GetVChannel()), zap.Error(err)) @@ -1475,7 +1481,16 @@ func (s *Server) UpdateChannelCheckpoint(ctx context.Context, req *datapb.Update return merr.Success(), nil } - err := s.meta.UpdateChannelCheckpoints(req.GetChannelCheckpoints()) + checkpoints := lo.Filter(req.GetChannelCheckpoints(), func(cp *msgpb.MsgPosition, _ int) bool { + channel := cp.GetChannelName() + matched := s.channelManager.Match(nodeID, channel) + if !matched { + log.Warn("node is not matched with channel", zap.String("channel", channel), zap.Int64("nodeID", nodeID)) + } + return matched + }) + + err := s.meta.UpdateChannelCheckpoints(checkpoints) if err != nil { log.Warn("failed to update channel checkpoint", zap.Error(err)) return merr.Status(err), nil