diff --git a/internal/common/common.go b/internal/common/common.go index 025588144e..1ebf4796b6 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -52,6 +52,9 @@ const ( // NotRegisteredID means node is not registered into etcd. NotRegisteredID = int64(-1) + + // InvalidNodeID indicates that node is not valid in querycoord replica or shard cluster. + InvalidNodeID = int64(-1) ) // Endian is type alias of binary.LittleEndian. diff --git a/internal/querynode/impl_test.go b/internal/querynode/impl_test.go index cf6ce28b79..9c2add7f26 100644 --- a/internal/querynode/impl_test.go +++ b/internal/querynode/impl_test.go @@ -27,6 +27,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" @@ -749,7 +750,7 @@ func TestImpl_SyncReplicaSegments(t *testing.T) { require.True(t, ok) segment, ok := cs.getSegment(1) require.True(t, ok) - assert.Equal(t, int64(1), segment.nodeID) + assert.Equal(t, common.InvalidNodeID, segment.nodeID) assert.Equal(t, defaultPartitionID, segment.partitionID) assert.Equal(t, segmentStateLoaded, segment.state) diff --git a/internal/querynode/shard_cluster.go b/internal/querynode/shard_cluster.go index d8d7ce545d..c0ce1de857 100644 --- a/internal/querynode/shard_cluster.go +++ b/internal/querynode/shard_cluster.go @@ -26,6 +26,7 @@ import ( "go.uber.org/atomic" "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" @@ -292,15 +293,24 @@ func (sc *ShardCluster) updateSegment(evt shardSegmentInfo) { // SyncSegments synchronize segment distribution in batch func (sc *ShardCluster) SyncSegments(distribution []*querypb.ReplicaSegmentsInfo, state segmentState) { + log := log.With(zap.Int64("collectionID", sc.collectionID), zap.String("vchannel", sc.vchannelName), zap.Int64("replicaID", sc.replicaID)) log.Info("ShardCluster sync segments", zap.Any("replica segments", distribution), zap.Int32("state", int32(state))) sc.mut.Lock() for _, line := range distribution { for _, segmentID := range line.GetSegmentIds() { + nodeID := line.GetNodeId() + // if node id not in replica node list, this line shall be placeholder for segment offline + _, ok := sc.nodes[nodeID] + if !ok { + log.Warn("Sync segment with invalid nodeID", zap.Int64("segmentID", segmentID), zap.Int64("nodeID", line.NodeId)) + nodeID = common.InvalidNodeID + } + old, ok := sc.segments[segmentID] if !ok { // newly add sc.segments[segmentID] = shardSegmentInfo{ - nodeID: line.GetNodeId(), + nodeID: nodeID, partitionID: line.GetPartitionId(), segmentID: segmentID, state: state, @@ -309,7 +319,7 @@ func (sc *ShardCluster) SyncSegments(distribution []*querypb.ReplicaSegmentsInfo } sc.transferSegment(old, shardSegmentInfo{ - nodeID: line.GetNodeId(), + nodeID: nodeID, partitionID: line.GetPartitionId(), segmentID: segmentID, state: state, @@ -388,6 +398,7 @@ func (sc *ShardCluster) removeSegment(evt shardSegmentInfo) { } delete(sc.segments, evt.segmentID) + sc.healthCheck() } // init list all nodes and semgent states ant start watching @@ -455,7 +466,8 @@ func (sc *ShardCluster) updateShardClusterState(state shardClusterState) { // healthCheck iterate all segments to to check cluster could provide service. func (sc *ShardCluster) healthCheck() { for _, segment := range sc.segments { - if segment.state != segmentStateLoaded { // TODO check hand-off or load balance + if segment.state != segmentStateLoaded || + segment.nodeID == common.InvalidNodeID { // segment in offline nodes sc.updateShardClusterState(unavailable) return } @@ -600,11 +612,15 @@ func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error { } nodeID, has := sc.selectNodeInReplica(seg.NodeIds) if !has { - continue + // remove segment placeholder + nodeID = common.InvalidNodeID } sc.removeSegment(shardSegmentInfo{segmentID: seg.GetSegmentID(), nodeID: nodeID}) - removes[nodeID] = append(removes[nodeID], seg.SegmentID) + // only add remove operations when node is valid + if nodeID != common.InvalidNodeID { + removes[nodeID] = append(removes[nodeID], seg.SegmentID) + } } var errs errorutil.ErrorList diff --git a/internal/querynode/shard_cluster_service_test.go b/internal/querynode/shard_cluster_service_test.go index 8c5a73ced5..f665367d9a 100644 --- a/internal/querynode/shard_cluster_service_test.go +++ b/internal/querynode/shard_cluster_service_test.go @@ -5,6 +5,7 @@ import ( "errors" "testing" + "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/stretchr/testify/assert" @@ -84,7 +85,7 @@ func TestShardClusterService_SyncReplicaSegments(t *testing.T) { require.True(t, ok) segment, ok := cs.getSegment(1) assert.True(t, ok) - assert.Equal(t, int64(1), segment.nodeID) + assert.Equal(t, common.InvalidNodeID, segment.nodeID) assert.Equal(t, defaultPartitionID, segment.partitionID) assert.Equal(t, segmentStateLoaded, segment.state) }) diff --git a/internal/querynode/shard_cluster_test.go b/internal/querynode/shard_cluster_test.go index 31486a658b..960e176b11 100644 --- a/internal/querynode/shard_cluster_test.go +++ b/internal/querynode/shard_cluster_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" @@ -982,6 +983,46 @@ func TestShardCluster_SyncSegments(t *testing.T) { }, time.Second, time.Millisecond) }) + t.Run("sync segments with offline nodes", func(t *testing.T) { + nodeEvents := []nodeEvent{} + + segmentEvents := []segmentEvent{} + + evtCh := make(chan segmentEvent, 10) + sc := NewShardCluster(collectionID, replicaID, vchannelName, + &mockNodeDetector{initNodes: nodeEvents}, &mockSegmentDetector{ + initSegments: segmentEvents, + evtCh: evtCh, + }, buildMockQueryNode) + defer sc.Close() + + sc.SyncSegments([]*querypb.ReplicaSegmentsInfo{ + { + NodeId: 1, + SegmentIds: []int64{1}, + }, + { + NodeId: 2, + SegmentIds: []int64{2}, + }, + { + NodeId: 3, + SegmentIds: []int64{3}, + }, + }, segmentStateLoaded) + assert.Eventually(t, func() bool { + seg, has := sc.getSegment(1) + return has && seg.nodeID == common.InvalidNodeID && seg.state == segmentStateLoaded + }, time.Second, time.Millisecond) + assert.Eventually(t, func() bool { + seg, has := sc.getSegment(2) + return has && seg.nodeID == common.InvalidNodeID && seg.state == segmentStateLoaded + }, time.Second, time.Millisecond) + assert.Eventually(t, func() bool { + seg, has := sc.getSegment(3) + return has && seg.nodeID == common.InvalidNodeID && seg.state == segmentStateLoaded + }, time.Second, time.Millisecond) + }) } var streamingDoNothing = func(context.Context) error { return nil }