diff --git a/internal/datacoord/session_manager.go b/internal/datacoord/session_manager.go index e4449433a2..7d6485a9e7 100644 --- a/internal/datacoord/session_manager.go +++ b/internal/datacoord/session_manager.go @@ -31,6 +31,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/tsoutil" @@ -249,15 +250,12 @@ func (c *SessionManager) GetCompactionState() map[int64]*datapb.CompactionStateR commonpbutil.WithSourceID(paramtable.GetNodeID()), ), }) - if err != nil { + + if err := merr.CheckRPCCall(resp, err); err != nil { log.Info("Get State failed", zap.Error(err)) return } - if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - log.Info("Get State failed", zap.String("Reason", resp.GetStatus().GetReason())) - return - } for _, rst := range resp.GetResults() { plans.Insert(rst.PlanID, rst) } @@ -296,6 +294,46 @@ func (c *SessionManager) FlushChannels(ctx context.Context, nodeID int64, req *d return nil } +func (c *SessionManager) NotifyChannelOperation(ctx context.Context, nodeID int64, req *datapb.ChannelOperationsRequest) error { + log := log.Ctx(ctx).With(zap.Int64("nodeID", nodeID)) + cli, err := c.getClient(ctx, nodeID) + if err != nil { + log.Info("failed to get dataNode client", zap.Error(err)) + return err + } + ctx, cancel := context.WithTimeout(ctx, Params.DataCoordCfg.ChannelOperationRPCTimeout.GetAsDuration(time.Second)) + defer cancel() + resp, err := cli.NotifyChannelOperation(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("Notify channel operations failed", zap.Error(err)) + return err + } + return nil +} + +func (c *SessionManager) CheckChannelOperationProgress(ctx context.Context, nodeID int64, info *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error) { + log := log.With( + zap.Int64("nodeID", nodeID), + zap.String("channel", info.GetVchan().GetChannelName()), + zap.String("operation", info.GetState().String()), + ) + cli, err := c.getClient(ctx, nodeID) + if err != nil { + log.Info("failed to get dataNode client", zap.Error(err)) + return nil, err + } + + ctx, cancel := context.WithTimeout(ctx, Params.DataCoordCfg.ChannelOperationRPCTimeout.GetAsDuration(time.Second)) + defer cancel() + resp, err := cli.CheckChannelOperationProgress(ctx, info) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("Check channel operation failed", zap.Error(err)) + return nil, err + } + + return resp, nil +} + func (c *SessionManager) getClient(ctx context.Context, nodeID int64) (types.DataNodeClient, error) { c.sessions.RLock() session, ok := c.sessions.data[nodeID] diff --git a/internal/datacoord/session_manager_test.go b/internal/datacoord/session_manager_test.go new file mode 100644 index 0000000000..0229eec359 --- /dev/null +++ b/internal/datacoord/session_manager_test.go @@ -0,0 +1,117 @@ +package datacoord + +import ( + "context" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +func TestSessionManagerSuite(t *testing.T) { + suite.Run(t, new(SessionManagerSuite)) +} + +type SessionManagerSuite struct { + suite.Suite + + dn *mocks.MockDataNodeClient + + m *SessionManager +} + +func (s *SessionManagerSuite) SetupTest() { + s.dn = mocks.NewMockDataNodeClient(s.T()) + + s.m = NewSessionManager(withSessionCreator(func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { + return s.dn, nil + })) + + s.m.AddSession(&NodeInfo{1000, "addr-1"}) +} + +func (s *SessionManagerSuite) TestNotifyChannelOperation() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + info := &datapb.ChannelWatchInfo{ + Vchan: &datapb.VchannelInfo{}, + State: datapb.ChannelWatchState_ToWatch, + OpID: 1, + } + + req := &datapb.ChannelOperationsRequest{ + Infos: []*datapb.ChannelWatchInfo{info}, + } + s.Run("no node", func() { + err := s.m.NotifyChannelOperation(ctx, 100, req) + s.Error(err) + }) + + s.Run("fail", func() { + s.SetupTest() + s.dn.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything).Return(nil, errors.New("mock")) + + err := s.m.NotifyChannelOperation(ctx, 1000, req) + s.Error(err) + }) + + s.Run("normal", func() { + s.SetupTest() + s.dn.EXPECT().NotifyChannelOperation(mock.Anything, mock.Anything).Return(merr.Status(nil), nil) + + err := s.m.NotifyChannelOperation(ctx, 1000, req) + s.NoError(err) + }) +} + +func (s *SessionManagerSuite) TestCheckCHannelOperationProgress() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + info := &datapb.ChannelWatchInfo{ + Vchan: &datapb.VchannelInfo{}, + State: datapb.ChannelWatchState_ToWatch, + OpID: 1, + } + + s.Run("no node", func() { + resp, err := s.m.CheckChannelOperationProgress(ctx, 100, info) + s.Error(err) + s.Nil(resp) + }) + + s.Run("fail", func() { + s.SetupTest() + s.dn.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("mock")) + + resp, err := s.m.CheckChannelOperationProgress(ctx, 1000, info) + s.Error(err) + s.Nil(resp) + }) + + s.Run("normal", func() { + s.SetupTest() + s.dn.EXPECT().CheckChannelOperationProgress(mock.Anything, mock.Anything, mock.Anything). + Return( + &datapb.ChannelOperationProgressResponse{ + Status: merr.Status(nil), + OpID: info.OpID, + State: info.State, + Progress: 100, + }, + nil) + + resp, err := s.m.CheckChannelOperationProgress(ctx, 1000, info) + s.NoError(err) + s.Equal(resp.GetState(), info.State) + s.Equal(resp.OpID, info.OpID) + s.EqualValues(100, resp.Progress) + }) +}