diff --git a/internal/querycoord/cluster.go b/internal/querycoord/cluster.go
index cb4484f960..f6e7989b23 100644
--- a/internal/querycoord/cluster.go
+++ b/internal/querycoord/cluster.go
@@ -64,6 +64,7 @@ type Cluster interface {
 	getSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error)
 	getSegmentInfoByNode(ctx context.Context, nodeID int64, in *querypb.GetSegmentInfoRequest) ([]*querypb.SegmentInfo, error)
 	getSegmentInfoByID(ctx context.Context, segmentID UniqueID) (*querypb.SegmentInfo, error)
+	syncReplicaSegments(ctx context.Context, leaderID UniqueID, in *querypb.SyncReplicaSegmentsRequest) error
 
 	registerNode(ctx context.Context, session *sessionutil.Session, id UniqueID, state nodeState) error
 	getNodeInfoByID(nodeID int64) (Node, error)
@@ -494,6 +495,17 @@ func (c *queryNodeCluster) getSegmentInfoByNode(ctx context.Context, nodeID int6
 	return res.GetInfos(), nil
 }
 
+func (c *queryNodeCluster) syncReplicaSegments(ctx context.Context, leaderID UniqueID, in *querypb.SyncReplicaSegmentsRequest) error {
+	c.RLock()
+	leader, ok := c.nodes[leaderID]
+	c.RUnlock()
+
+	if !ok {
+		return fmt.Errorf("syncReplicaSegments: can't find leader query node, leaderID = %d", leaderID)
+	}
+	return leader.syncReplicaSegments(ctx, in)
+}
+
 type queryNodeGetMetricsResponse struct {
 	resp *milvuspb.GetMetricsResponse
 	err  error
diff --git a/internal/querycoord/mock_querynode_server_test.go b/internal/querycoord/mock_querynode_server_test.go
index 774c4d572e..124c6e49ea 100644
--- a/internal/querycoord/mock_querynode_server_test.go
+++ b/internal/querycoord/mock_querynode_server_test.go
@@ -271,6 +271,12 @@ func (qs *queryNodeServerMock) GetSegmentInfo(ctx context.Context, req *querypb.
 	return res, err
 }
 
+func (qs *queryNodeServerMock) SyncReplicaSegments(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) {
+	return &commonpb.Status{
+		ErrorCode: commonpb.ErrorCode_Success,
+	}, nil
+}
+
 func (qs *queryNodeServerMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
 	response, err := qs.getMetrics()
 	if err != nil {
diff --git a/internal/querycoord/querynode.go b/internal/querycoord/querynode.go
index ebc7b56324..0ad8fc080c 100644
--- a/internal/querycoord/querynode.go
+++ b/internal/querycoord/querynode.go
@@ -47,6 +47,7 @@ type Node interface {
 
 	watchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest) error
 	watchDeltaChannels(ctx context.Context, in *querypb.WatchDeltaChannelsRequest) error
+	syncReplicaSegments(ctx context.Context, in *querypb.SyncReplicaSegmentsRequest) error
 	//removeDmChannel(collectionID UniqueID, channels []string) error
 	hasWatchedDeltaChannel(collectionID UniqueID) bool
 
@@ -445,3 +446,19 @@ func (qn *queryNode) getNodeInfo() (Node, error) {
 		cpuUsage: qn.cpuUsage,
 	}, nil
 }
+
+func (qn *queryNode) syncReplicaSegments(ctx context.Context, in *querypb.SyncReplicaSegmentsRequest) error {
+	if !qn.isOnline() {
+		return errors.New("syncReplicaSegments: queryNode is offline")
+	}
+
+	status, err := qn.client.SyncReplicaSegments(ctx, in)
+	if err != nil {
+		return err
+	}
+	if status.ErrorCode != commonpb.ErrorCode_Success {
+		return errors.New(status.Reason)
+	}
+
+	return nil
+}
diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go
index 1f473fd868..8752f9e328 100644
--- a/internal/querycoord/task.go
+++ b/internal/querycoord/task.go
@@ -340,11 +340,23 @@ func (lct *loadCollectionTask) updateTaskProcess() {
 		}
 	}
 	if allDone {
-		err := lct.meta.setLoadPercentage(collectionID, 0, 100, querypb.LoadType_LoadCollection)
+		err := syncReplicaSegments(lct.ctx, lct.cluster, childTasks)
+		if err != nil {
+			log.Error("loadCollectionTask: failed to sync replica segments to shard leader",
+				zap.Int64("taskID", lct.getTaskID()),
+				zap.Int64("collectionID", collectionID),
+				zap.Error(err))
+			lct.setResultInfo(err)
+			return
+		}
+
+		err = lct.meta.setLoadPercentage(collectionID, 0, 100, querypb.LoadType_LoadCollection)
 		if err != nil {
 			log.Error("loadCollectionTask: set load percentage to meta's collectionInfo", zap.Int64("collectionID", collectionID))
 			lct.setResultInfo(err)
+			return
 		}
+
 		lct.once.Do(func() {
 			metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
 			metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(lct.elapseSpan().Milliseconds()))
@@ -781,11 +793,22 @@ func (lpt *loadPartitionTask) updateTaskProcess() {
 		}
 	}
 	if allDone {
+		err := syncReplicaSegments(lpt.ctx, lpt.cluster, childTasks)
+		if err != nil {
+			log.Error("loadPartitionTask: failed to sync replica segments to shard leader",
+				zap.Int64("taskID", lpt.getTaskID()),
+				zap.Int64("collectionID", collectionID),
+				zap.Error(err))
+			lpt.setResultInfo(err)
+			return
+		}
+
 		for _, id := range partitionIDs {
 			err := lpt.meta.setLoadPercentage(collectionID, id, 100, querypb.LoadType_LoadPartition)
 			if err != nil {
 				log.Error("loadPartitionTask: set load percentage to meta's collectionInfo", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", id))
 				lpt.setResultInfo(err)
+				return
 			}
 		}
 		lpt.once.Do(func() {
@@ -2198,6 +2221,13 @@ func (lbt *loadBalanceTask) getReplica(nodeID, collectionID int64) (*milvuspb.Re
 }
 
 func (lbt *loadBalanceTask) postExecute(context.Context) error {
+	err := syncReplicaSegments(lbt.ctx, lbt.cluster, lbt.getChildTask())
+	if err != nil {
+		log.Error("loadBalanceTask: failed to sync replica segments to shard leaders",
+			zap.Int64("taskID", lbt.getTaskID()),
+			zap.Error(err))
+	}
+
 	if lbt.getResultInfo().ErrorCode != commonpb.ErrorCode_Success {
 		lbt.clearChildTasks()
 	}
diff --git a/internal/querycoord/util.go b/internal/querycoord/util.go
index 0de55a0da1..aecedc375d 100644
--- a/internal/querycoord/util.go
+++ b/internal/querycoord/util.go
@@ -17,9 +17,12 @@
 package querycoord
 
 import (
+	"context"
+
 	"github.com/milvus-io/milvus/internal/proto/commonpb"
 	"github.com/milvus-io/milvus/internal/proto/datapb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
+	"github.com/milvus-io/milvus/internal/util/typeutil"
 )
 
 func getCompareMapFromSlice(sliceData []int64) map[int64]struct{} {
@@ -104,3 +107,97 @@ func getDstNodeIDByTask(t task) int64 {
 
 	return nodeID
 }
+
+// syncReplicaSegments groups the segments loaded by the given child tasks by
+// DM channel and replica, then syncs each group to the shard leader of that replica.
+func syncReplicaSegments(ctx context.Context, cluster Cluster, childTasks []task) error {
+	type SegmentIndex struct {
+		NodeID      UniqueID
+		PartitionID UniqueID
+		ReplicaID   UniqueID
+	}
+
+	type ShardLeader struct {
+		ReplicaID UniqueID
+		LeaderID  UniqueID
+	}
+
+	shardSegments := make(map[string]map[SegmentIndex]typeutil.UniqueSet) // DMC -> set[Segment]
+	shardLeaders := make(map[string][]*ShardLeader)                       // DMC -> leaders
+	for _, childTask := range childTasks {
+		switch task := childTask.(type) {
+		case *loadSegmentTask:
+			nodeID := getDstNodeIDByTask(task)
+			for _, segment := range task.Infos {
+				segments, ok := shardSegments[segment.InsertChannel]
+				if !ok {
+					segments = make(map[SegmentIndex]typeutil.UniqueSet)
+				}
+
+				index := SegmentIndex{
+					NodeID:      nodeID,
+					PartitionID: segment.PartitionID,
+					ReplicaID:   task.ReplicaID,
+				}
+
+				_, ok = segments[index]
+				if !ok {
+					segments[index] = make(typeutil.UniqueSet)
+				}
+				segments[index].Insert(segment.SegmentID)
+
+				shardSegments[segment.InsertChannel] = segments
+			}
+
+		case *watchDmChannelTask:
+			leaderID := getDstNodeIDByTask(task)
+			leader := &ShardLeader{
+				ReplicaID: task.ReplicaID,
+				LeaderID:  leaderID,
+			}
+
+			for _, dmc := range task.Infos {
+				leaders, ok := shardLeaders[dmc.ChannelName]
+				if !ok {
+					leaders = make([]*ShardLeader, 0)
+				}
+
+				leaders = append(leaders, leader)
+
+				shardLeaders[dmc.ChannelName] = leaders
+			}
+		}
+	}
+
+	for dmc, leaders := range shardLeaders {
+		segments, ok := shardSegments[dmc]
+		if !ok {
+			continue
+		}
+
+		for _, leader := range leaders {
+			req := querypb.SyncReplicaSegmentsRequest{
+				VchannelName:    dmc,
+				ReplicaSegments: make([]*querypb.ReplicaSegmentsInfo, 0, len(segments)),
+			}
+
+			for index, segmentSet := range segments {
+				if index.ReplicaID == leader.ReplicaID {
+					req.ReplicaSegments = append(req.ReplicaSegments,
+						&querypb.ReplicaSegmentsInfo{
+							NodeId:      index.NodeID,
+							PartitionId: index.PartitionID,
+							SegmentIds:  segmentSet.Collect(),
+						})
+				}
+			}
+
+			err := cluster.syncReplicaSegments(ctx, leader.LeaderID, &req)
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
diff --git a/internal/util/typeutil/set.go b/internal/util/typeutil/set.go
new file mode 100644
index 0000000000..18974dca59
--- /dev/null
+++ b/internal/util/typeutil/set.go
@@ -0,0 +1,61 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package typeutil
+
+// UniqueSet is a set type that contains only UniqueIDs;
+// the underlying type is map[UniqueID]struct{}.
+// Create a UniqueSet instance with make(UniqueSet), as with a map.
+type UniqueSet map[UniqueID]struct{}
+
+// Insert inserts the given elements into the set,
+// doing nothing for elements that already exist
+func (set UniqueSet) Insert(ids ...UniqueID) {
+	for i := range ids {
+		set[ids[i]] = struct{}{}
+	}
+}
+
+// Contain checks whether all of the given elements exist in the set
+func (set UniqueSet) Contain(ids ...UniqueID) bool {
+	for i := range ids {
+		_, ok := set[ids[i]]
+		if !ok {
+			return false
+		}
+	}
+
+	return true
+}
+
+// Remove removes the given elements from the set,
+// doing nothing if the set is nil or an element does not exist
+func (set UniqueSet) Remove(ids ...UniqueID) {
+	for i := range ids {
+		delete(set, ids[i])
+	}
+}
+
+// Collect returns all elements in the set as a slice
+func (set UniqueSet) Collect() []UniqueID {
+	ids := make([]UniqueID, 0, len(set))
+
+	for id := range set {
+		ids = append(ids, id)
+	}
+
+	return ids
+}
diff --git a/internal/util/typeutil/set_test.go b/internal/util/typeutil/set_test.go
new file mode 100644
index 0000000000..eb33342c19
--- /dev/null
+++ b/internal/util/typeutil/set_test.go
@@ -0,0 +1,38 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package typeutil
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestUniqueSet(t *testing.T) {
+	set := make(UniqueSet)
+	set.Insert(5, 7, 9)
+	assert.True(t, set.Contain(5))
+	assert.True(t, set.Contain(7))
+	assert.True(t, set.Contain(9))
+	assert.True(t, set.Contain(5, 7, 9))
+
+	set.Remove(7)
+	assert.True(t, set.Contain(5))
+	assert.False(t, set.Contain(7))
+	assert.True(t, set.Contain(9))
+	assert.False(t, set.Contain(5, 7, 9))
+}
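
For reviewers: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the grouping that syncReplicaSegments in util.go performs, written with hypothetical stand-in types (a local UniqueID alias, a local UniqueSet, segmentIndex) instead of the real querypb and typeutil packages. Segments are keyed per DM channel by (node, partition, replica), and a shard leader is synced only the entries of its own replica.

package main

import "fmt"

// Stand-ins for the real types; for illustration only.
type UniqueID = int64

type UniqueSet map[UniqueID]struct{}

func (s UniqueSet) Insert(ids ...UniqueID) {
	for _, id := range ids {
		s[id] = struct{}{}
	}
}

func (s UniqueSet) Collect() []UniqueID {
	ids := make([]UniqueID, 0, len(s))
	for id := range s {
		ids = append(ids, id)
	}
	return ids
}

// segmentIndex mirrors the SegmentIndex key used by the helper.
type segmentIndex struct {
	NodeID      UniqueID
	PartitionID UniqueID
	ReplicaID   UniqueID
}

func main() {
	// DM channel -> (node, partition, replica) -> set of segment IDs,
	// the same shape the helper builds from loadSegmentTasks.
	shardSegments := make(map[string]map[segmentIndex]UniqueSet)

	add := func(channel string, idx segmentIndex, segmentID UniqueID) {
		segments, ok := shardSegments[channel]
		if !ok {
			segments = make(map[segmentIndex]UniqueSet)
			shardSegments[channel] = segments
		}
		set, ok := segments[idx]
		if !ok {
			set = make(UniqueSet)
			segments[idx] = set
		}
		set.Insert(segmentID)
	}

	add("dmc-0", segmentIndex{NodeID: 1, PartitionID: 10, ReplicaID: 100}, 1001)
	add("dmc-0", segmentIndex{NodeID: 1, PartitionID: 10, ReplicaID: 100}, 1002)
	add("dmc-0", segmentIndex{NodeID: 2, PartitionID: 10, ReplicaID: 101}, 1003)

	// A leader of replica 100 on dmc-0 receives only the first two
	// segments; the replica 101 entry goes to that replica's own leader.
	leaderReplica := UniqueID(100)
	for idx, set := range shardSegments["dmc-0"] {
		if idx.ReplicaID == leaderReplica {
			fmt.Printf("node=%d partition=%d segments=%v\n",
				idx.NodeID, idx.PartitionID, set.Collect())
		}
	}
}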