diff --git a/internal/distributed/rootcoord/service.go b/internal/distributed/rootcoord/service.go
index 7fbf18e1cf..983986cb1c 100644
--- a/internal/distributed/rootcoord/service.go
+++ b/internal/distributed/rootcoord/service.go
@@ -347,6 +347,12 @@ func (s *Server) Stop() (err error) {
 		defer s.tikvCli.Close()
 	}
 
+	if s.rootCoord != nil {
+		log.Info("graceful stop rootCoord")
+		s.rootCoord.GracefulStop()
+		log.Info("graceful stop rootCoord done")
+	}
+
 	if s.grpcServer != nil {
 		utils.GracefulStopGRPCServer(s.grpcServer)
 	}
diff --git a/internal/distributed/rootcoord/service_test.go b/internal/distributed/rootcoord/service_test.go
index 9e7b5115b7..917fa9d836 100644
--- a/internal/distributed/rootcoord/service_test.go
+++ b/internal/distributed/rootcoord/service_test.go
@@ -118,6 +118,9 @@ func (m *mockCore) Stop() error {
 	return fmt.Errorf("stop error")
 }
 
+func (m *mockCore) GracefulStop() {
+}
+
 func TestRun(t *testing.T) {
 	paramtable.Init()
 	parameters := []string{"tikv", "etcd"}
diff --git a/internal/metastore/kv/streamingnode/kv_catalog.go b/internal/metastore/kv/streamingnode/kv_catalog.go
index 9ba73784df..1ac42dd39f 100644
--- a/internal/metastore/kv/streamingnode/kv_catalog.go
+++ b/internal/metastore/kv/streamingnode/kv_catalog.go
@@ -110,7 +110,7 @@ func (c *catalog) GetConsumeCheckpoint(ctx context.Context, pchannelName string)
 		return nil, err
 	}
 	val := &streamingpb.WALCheckpoint{}
-	if err = proto.Unmarshal([]byte(value), &streamingpb.WALCheckpoint{}); err != nil {
+	if err = proto.Unmarshal([]byte(value), val); err != nil {
 		return nil, err
 	}
 	return val, nil
diff --git a/internal/metastore/kv/streamingnode/kv_catalog_test.go b/internal/metastore/kv/streamingnode/kv_catalog_test.go
index 177726585f..2c6117b44f 100644
--- a/internal/metastore/kv/streamingnode/kv_catalog_test.go
+++ b/internal/metastore/kv/streamingnode/kv_catalog_test.go
@@ -4,15 +4,52 @@ import (
 	"context"
 	"testing"
 
+	"github.com/cockroachdb/errors"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"
 	"google.golang.org/protobuf/proto"
 
 	"github.com/milvus-io/milvus/internal/kv/mocks"
 	"github.com/milvus-io/milvus/pkg/proto/streamingpb"
+	"github.com/milvus-io/milvus/pkg/util/merr"
 )
 
-func TestCatalog(t *testing.T) {
+func TestCatalogConsumeCheckpoint(t *testing.T) {
+	kv := mocks.NewMetaKv(t)
+	v := streamingpb.WALCheckpoint{}
+	vs, err := proto.Marshal(&v)
+	assert.NoError(t, err)
+
+	kv.EXPECT().Load(mock.Anything, mock.Anything).Return(string(vs), nil)
+	catalog := NewCataLog(kv)
+	ctx := context.Background()
+	checkpoint, err := catalog.GetConsumeCheckpoint(ctx, "p1")
+	assert.NotNil(t, checkpoint)
+	assert.NoError(t, err)
+
+	kv.EXPECT().Load(mock.Anything, mock.Anything).Unset()
+	kv.EXPECT().Load(mock.Anything, mock.Anything).Return("", errors.New("err"))
+	checkpoint, err = catalog.GetConsumeCheckpoint(ctx, "p1")
+	assert.Nil(t, checkpoint)
+	assert.Error(t, err)
+
+	kv.EXPECT().Load(mock.Anything, mock.Anything).Unset()
+	kv.EXPECT().Load(mock.Anything, mock.Anything).Return("", merr.ErrIoKeyNotFound)
+	checkpoint, err = catalog.GetConsumeCheckpoint(ctx, "p1")
+	assert.Nil(t, checkpoint)
+	assert.Nil(t, err)
+
+	kv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).Return(nil)
+	err = catalog.SaveConsumeCheckpoint(ctx, "p1", &streamingpb.WALCheckpoint{})
+	assert.NoError(t, err)
+
+	kv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).Unset()
+	kv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("err"))
+	err = catalog.SaveConsumeCheckpoint(ctx, "p1", &streamingpb.WALCheckpoint{})
+	assert.Error(t, err)
+}
+
+func TestCatalogSegmentAssignments(t *testing.T) {
 	kv := mocks.NewMetaKv(t)
 	k := "p1"
 	v := streamingpb.SegmentAssignmentMeta{}
diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go
index 75e61cd46c..a7af848cd5 100644
--- a/internal/rootcoord/root_coord.go
+++ b/internal/rootcoord/root_coord.go
@@ -826,15 +826,18 @@ func (c *Core) revokeSession() {
 	}
 }
 
+func (c *Core) GracefulStop() {
+	if c.streamingCoord != nil {
+		c.streamingCoord.Stop()
+	}
+}
+
 // Stop stops rootCoord.
 func (c *Core) Stop() error {
 	c.UpdateStateCode(commonpb.StateCode_Abnormal)
 	c.stopExecutor()
 	c.stopScheduler()
 
-	if c.streamingCoord != nil {
-		c.streamingCoord.Stop()
-	}
 	if c.proxyWatcher != nil {
 		c.proxyWatcher.Stop()
 	}
diff --git a/internal/streamingcoord/client/broadcast/watcher_resuming.go b/internal/streamingcoord/client/broadcast/watcher_resuming.go
index 2f99c238d1..077f3786c5 100644
--- a/internal/streamingcoord/client/broadcast/watcher_resuming.go
+++ b/internal/streamingcoord/client/broadcast/watcher_resuming.go
@@ -67,7 +67,8 @@ func (r *resumingWatcher) Close() {
 
 func (r *resumingWatcher) execute(backoffConfig *typeutil.BackoffTimerConfig) {
 	backoff := typeutil.NewBackoffTimer(backoffConfig)
-	nextTimer := time.After(0)
+	var nextTimer <-chan time.Time
+	var initialized bool
 	var watcher Watcher
 	defer func() {
 		if watcher != nil {
@@ -92,6 +93,12 @@ func (r *resumingWatcher) execute(backoffConfig *typeutil.BackoffTimerConfig) {
 					watcher = nil
 				}
 			}
+			if !initialized {
+				// try to initialize the watcher in the next loop iteration.
+				// avoid creating a grpc stream channel if the watch operation is never used.
+				nextTimer = time.After(0)
+				initialized = true
+			}
 		case ev, ok := <-eventChan:
 			if !ok {
 				watcher.Close()
@@ -101,15 +108,15 @@ func (r *resumingWatcher) execute(backoffConfig *typeutil.BackoffTimerConfig) {
 			r.evs.Notify(ev)
 		case <-nextTimer:
 			var err error
+			nextTimer = nil
 			if watcher, err = r.createNewWatcher(); err != nil {
 				r.Logger().Warn("create new watcher failed", zap.Error(err))
 				break
 			}
 			r.Logger().Info("create new watcher successful")
 			backoff.DisableBackoff()
-			nextTimer = nil
 		}
-		if watcher == nil {
+		if watcher == nil && nextTimer == nil {
 			backoff.EnableBackoff()
 			var interval time.Duration
 			nextTimer, interval = backoff.NextTimer()
diff --git a/internal/streamingcoord/server/balancer/balancer_impl.go b/internal/streamingcoord/server/balancer/balancer_impl.go
index 90f54e9f41..7d827121cc 100644
--- a/internal/streamingcoord/server/balancer/balancer_impl.go
+++ b/internal/streamingcoord/server/balancer/balancer_impl.go
@@ -65,7 +65,8 @@ func (b *balancerImpl) WatchChannelAssignments(ctx context.Context, cb func(vers
 	}
 	defer b.lifetime.Done()
 
-	ctx, _ = contextutil.MergeContext(ctx, b.ctx)
+	ctx, cancel := contextutil.MergeContext(ctx, b.ctx)
+	defer cancel()
 	return b.channelMetaManager.WatchAssignmentResult(ctx, cb)
 }
 
@@ -75,6 +76,8 @@ func (b *balancerImpl) MarkAsUnavailable(ctx context.Context, pChannels []types.
 	}
 	defer b.lifetime.Done()
 
+	ctx, cancel := contextutil.MergeContext(ctx, b.ctx)
+	defer cancel()
 	return b.sendRequestAndWaitFinish(ctx, newOpMarkAsUnavailable(ctx, pChannels))
 }
 
@@ -85,6 +88,8 @@ func (b *balancerImpl) Trigger(ctx context.Context) error {
 	}
 	defer b.lifetime.Done()
 
+	ctx, cancel := contextutil.MergeContext(ctx, b.ctx)
+	defer cancel()
 	return b.sendRequestAndWaitFinish(ctx, newOpTrigger(ctx))
 }
 
diff --git a/internal/streamingnode/server/flusher/flusherimpl/pchannel_checkpoint_test.go b/internal/streamingnode/server/flusher/flusherimpl/pchannel_checkpoint_test.go
index e75a73bb8f..d18c742bda 100644
--- a/internal/streamingnode/server/flusher/flusherimpl/pchannel_checkpoint_test.go
+++ b/internal/streamingnode/server/flusher/flusherimpl/pchannel_checkpoint_test.go
@@ -43,7 +43,9 @@ func TestPChannelCheckpointManager(t *testing.T) {
 	p.AddVChannel("vchannel-999", rmq.NewRmqID(1000000))
 	p.DropVChannel("vchannel-1000")
 
-	p.Update(vchannel, rmq.NewRmqID(1000001))
+	for _, vchannel := range vchannel {
+		p.Update(vchannel, rmq.NewRmqID(1000001))
+	}
 
 	assert.Eventually(t, func() bool {
 		newMinimum := minimumOne.Load()
diff --git a/internal/streamingnode/server/flusher/flusherimpl/vchannel_checkpoint_test.go b/internal/streamingnode/server/flusher/flusherimpl/vchannel_checkpoint_test.go
index 3ce2a66dda..f48c10b2d5 100644
--- a/internal/streamingnode/server/flusher/flusherimpl/vchannel_checkpoint_test.go
+++ b/internal/streamingnode/server/flusher/flusherimpl/vchannel_checkpoint_test.go
@@ -12,7 +12,7 @@ import (
 )
 
 func TestVChannelCheckpointManager(t *testing.T) {
-	exists, vchannel, minimumX := generateRandomExistsMessageID()
+	exists, vchannels, minimumX := generateRandomExistsMessageID()
 	m := newVChannelCheckpointManager(exists)
 
 	assert.True(t, m.MinimumCheckpoint().EQ(minimumX))
@@ -32,17 +32,31 @@ func TestVChannelCheckpointManager(t *testing.T) {
 	assert.NoError(t, err)
 	assert.True(t, m.MinimumCheckpoint().EQ(minimumX))
 
-	err = m.Update(vchannel, rmq.NewRmqID(1000001))
-	assert.NoError(t, err)
+	for _, vchannel := range vchannels {
+		err = m.Update(vchannel, rmq.NewRmqID(1000001))
+		assert.NoError(t, err)
+	}
 	assert.False(t, m.MinimumCheckpoint().EQ(minimumX))
 
-	err = m.Update(vchannel, minimumX)
+	err = m.Update(vchannels[0], minimumX)
 	assert.Error(t, err)
 
 	err = m.Drop("vchannel-501")
 	assert.NoError(t, err)
+	lastMinimum := m.MinimumCheckpoint()
+	for i := 0; i < 1001; i++ {
+		m.Update(fmt.Sprintf("vchannel-%d", i), rmq.NewRmqID(rand.Int63n(9999999)+2))
+		newMinimum := m.MinimumCheckpoint()
+		assert.True(t, lastMinimum.LTE(newMinimum))
+		lastMinimum = newMinimum
+	}
 	for i := 0; i < 1001; i++ {
 		m.Drop(fmt.Sprintf("vchannel-%d", i))
+		newMinimum := m.MinimumCheckpoint()
+		if newMinimum != nil {
+			assert.True(t, lastMinimum.LTE(newMinimum))
+			lastMinimum = newMinimum
+		}
 	}
 	assert.Len(t, m.index, 0)
 	assert.Len(t, m.checkpointHeap, 0)
@@ -50,17 +64,21 @@ func TestVChannelCheckpointManager(t *testing.T) {
 	assert.Nil(t, m.MinimumCheckpoint())
 }
 
-func generateRandomExistsMessageID() (map[string]message.MessageID, string, message.MessageID) {
+func generateRandomExistsMessageID() (map[string]message.MessageID, []string, message.MessageID) {
 	minimumX := int64(10000000)
-	var vchannel string
+	var vchannel []string
 	exists := make(map[string]message.MessageID)
 	for i := 0; i < 1000; i++ {
 		x := rand.Int63n(999999) + 2
 		exists[fmt.Sprintf("vchannel-%d", i)] = rmq.NewRmqID(x)
 		if x < minimumX {
 			minimumX = x
-			vchannel = fmt.Sprintf("vchannel-%d", i)
+			vchannel = []string{fmt.Sprintf("vchannel-%d", i)}
+		} else if x == minimumX {
+			vchannel = append(vchannel, fmt.Sprintf("vchannel-%d", i))
 		}
 	}
+	vchannel = append(vchannel, "vchannel-1")
+	exists["vchannel-1"] = rmq.NewRmqID(minimumX)
 	return exists, vchannel, rmq.NewRmqID(minimumX)
 }
diff --git a/internal/types/types.go b/internal/types/types.go
index 2d52a8e6dc..d694bae3c9 100644
--- a/internal/types/types.go
+++ b/internal/types/types.go
@@ -213,6 +213,8 @@ type RootCoordComponent interface {
 	GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)
 
 	RegisterStreamingCoordGRPCService(server *grpc.Server)
+
+	GracefulStop()
 }
 
 // ProxyClient is the client interface for proxy server
diff --git a/tests/_helm/values/e2e/distributed b/tests/_helm/values/e2e/distributed
index 2175becb18..2b8208009c 100644
--- a/tests/_helm/values/e2e/distributed
+++ b/tests/_helm/values/e2e/distributed
@@ -20,7 +20,7 @@ dataCoordinator:
 dataNode:
   resources:
     limits:
-      cpu: "2"
+      cpu: "1"
     requests:
       cpu: "0.5"
      memory: 500Mi
@@ -249,7 +249,21 @@ queryNode:
       cpu: "2"
     requests:
       cpu: "0.5"
-      memory: 500Mi
+      memory: 512Mi
+streamingNode:
+  resources:
+    limits:
+      cpu: "2"
+    requests:
+      cpu: "0.5"
+      memory: 512Mi
+mixCoordinator:
+  resources:
+    limits:
+      cpu: "1"
+    requests:
+      cpu: "0.2"
+      memory: 256Mi
 rootCoordinator:
   resources:
     limits:
diff --git a/tests/go_client/testcases/helper/helper.go b/tests/go_client/testcases/helper/helper.go
index 6881e9c812..173e863693 100644
--- a/tests/go_client/testcases/helper/helper.go
+++ b/tests/go_client/testcases/helper/helper.go
@@ -153,6 +153,12 @@ func (chainTask *CollectionPrepare) CreateCollection(ctx context.Context, t *tes
 	common.CheckErr(t, err, true)
 
 	t.Cleanup(func() {
+		// The collection will be cleaned up after the test.
+		// But some ctx is set with a timeout that covers only part of the test,
+		// which would cause the drop collection to fail with a timeout.
+		ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Second*10)
+		defer cancel()
+
 		err := mc.DropCollection(ctx, clientv2.NewDropCollectionOption(schema.CollectionName))
 		common.CheckErr(t, err, true)
 	})
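
A minimal, standalone sketch (not part of the patch) of the lazy-timer idiom that watcher_resuming.go adopts above: a nil timer channel keeps its select case disabled, and the timer is only armed after the first real request arrives, so no gRPC stream is created for a watcher that is never used. All names below are illustrative only.

package main

import (
	"fmt"
	"time"
)

func main() {
	requests := make(chan int, 1)
	requests <- 1

	var nextTimer <-chan time.Time // nil channel: this select case can never fire
	initialized := false

	for i := 0; i < 2; i++ {
		select {
		case req := <-requests:
			fmt.Println("handle request", req)
			if !initialized {
				// Arm the connect timer only once a real request has arrived,
				// so an unused watcher never opens a stream.
				nextTimer = time.After(0)
				initialized = true
			}
		case <-nextTimer:
			nextTimer = nil
			fmt.Println("timer fired: create the watcher / open the stream here")
		}
	}
}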
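
Likewise, a small sketch (again not part of the patch) of the cleanup-context pattern used in helper.go above: context.WithoutCancel (Go 1.21+) keeps the values of the test context but drops its cancellation and deadline, and a fresh timeout then gives the drop call its own budget. The surrounding main function and durations are purely illustrative.

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	// Simulate a per-test context whose deadline has already expired.
	testCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	defer cancel()
	time.Sleep(2 * time.Millisecond)

	// Cleanup context: detached from testCtx's cancellation, with its own 10s budget.
	cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(testCtx), 10*time.Second)
	defer cleanupCancel()

	fmt.Println(testCtx.Err())    // context.DeadlineExceeded
	fmt.Println(cleanupCtx.Err()) // <nil>: the cleanup call still has time to run
}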