Fix SyncSegment generate source (#19929)

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>

congqixia 2022-10-20 16:35:28 +08:00 committed by GitHub
parent b15e97a61a
commit 3de74165ac
8 changed files with 139 additions and 1089 deletions
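The diff below removes the legacy etcd-based change-info/handoff path and makes SyncSegments generate its next allocation view from the shard cluster's current version instead of the raw segment table. First-version creation is pulled into a dedicated SetupFirstVersion method, and SyncSegments now warns and returns early when no version has been set up yet. The following is a minimal, self-contained sketch of that call pattern; the types and field names are simplified stand-ins for illustration, not the actual Milvus definitions.

```go
// Sketch of the SetupFirstVersion / SyncSegments pattern introduced here.
// Types are illustrative stand-ins, not the real ShardCluster structures.
package main

import (
	"fmt"
	"sync"
)

type shardSegmentInfo struct {
	segmentID int64
	nodeID    int64
}

type shardClusterVersion struct {
	versionID int64
	segments  map[int64]shardSegmentInfo
}

type shardCluster struct {
	mutVersion     sync.RWMutex
	nextVersionID  int64
	currentVersion *shardClusterVersion
}

// SetupFirstVersion installs an empty first version; callers such as
// WatchDmChannels invoke it once when the shard cluster is created.
func (sc *shardCluster) SetupFirstVersion() {
	sc.mutVersion.Lock()
	defer sc.mutVersion.Unlock()
	sc.nextVersionID++
	sc.currentVersion = &shardClusterVersion{
		versionID: sc.nextVersionID,
		segments:  make(map[int64]shardSegmentInfo),
	}
}

// SyncSegments derives the next version from the current one instead of from
// the raw segment table; without a first version it simply warns and returns.
func (sc *shardCluster) SyncSegments(distribution map[int64]int64) {
	sc.mutVersion.Lock()
	defer sc.mutVersion.Unlock()
	if sc.currentVersion == nil {
		fmt.Println("received SyncSegments call before version setup")
		return
	}
	next := &shardClusterVersion{
		versionID: sc.nextVersionID + 1,
		segments:  make(map[int64]shardSegmentInfo, len(sc.currentVersion.segments)),
	}
	// start from the current version's allocation view
	for id, info := range sc.currentVersion.segments {
		next.segments[id] = info
	}
	// overlay the synced distribution
	for segmentID, nodeID := range distribution {
		next.segments[segmentID] = shardSegmentInfo{segmentID: segmentID, nodeID: nodeID}
	}
	sc.nextVersionID = next.versionID
	sc.currentVersion = next
}

func main() {
	sc := &shardCluster{}
	sc.SyncSegments(map[int64]int64{1: 100}) // ignored: no version set up yet
	sc.SetupFirstVersion()
	sc.SyncSegments(map[int64]int64{1: 100, 2: 101})
	fmt.Printf("version %d has %d segments\n", sc.currentVersion.versionID, len(sc.currentVersion.segments))
}
```

Callers that previously seeded the first version with SyncSegments(nil, segmentStateLoaded) (WatchDmChannels and the tests in the hunks below) now call SetupFirstVersion instead.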

View File

@ -330,11 +330,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
}
sc, _ := node.ShardClusterService.getShardCluster(in.Infos[0].GetChannelName())
sc.mutVersion.Lock()
defer sc.mutVersion.Unlock()
version := NewShardClusterVersion(sc.nextVersionID.Inc(), make(SegmentsStatus), nil)
sc.versions.Store(version.versionID, version)
sc.currentVersion = version
sc.SetupFirstVersion()
log.Info("watchDmChannelsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
return &commonpb.Status{

View File

@ -620,7 +620,7 @@ func TestImpl_Search(t *testing.T) {
// shard cluster sync segments
sc, ok := node.ShardClusterService.getShardCluster(defaultDMLChannel)
assert.True(t, ok)
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
_, err = node.Search(ctx, &queryPb.SearchRequest{
Req: req,
@ -645,7 +645,7 @@ func TestImpl_searchWithDmlChannel(t *testing.T) {
node.ShardClusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
sc, ok := node.ShardClusterService.getShardCluster(defaultDMLChannel)
assert.True(t, ok)
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
_, err = node.searchWithDmlChannel(ctx, &queryPb.SearchRequest{
Req: req,
@ -730,7 +730,7 @@ func TestImpl_Query(t *testing.T) {
// sync cluster segments
sc, ok := node.ShardClusterService.getShardCluster(defaultDMLChannel)
assert.True(t, ok)
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
_, err = node.Query(ctx, &queryPb.QueryRequest{
Req: req,
@ -756,7 +756,7 @@ func TestImpl_queryWithDmlChannel(t *testing.T) {
node.ShardClusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
sc, ok := node.ShardClusterService.getShardCluster(defaultDMLChannel)
assert.True(t, ok)
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
_, err = node.queryWithDmlChannel(ctx, &queryPb.QueryRequest{
Req: req,
@ -820,6 +820,9 @@ func TestImpl_SyncReplicaSegments(t *testing.T) {
assert.NoError(t, err)
node.ShardClusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
cs, ok := node.ShardClusterService.getShardCluster(defaultDMLChannel)
require.True(t, ok)
cs.SetupFirstVersion()
resp, err := node.SyncReplicaSegments(ctx, &querypb.SyncReplicaSegmentsRequest{
VchannelName: defaultDMLChannel,
@ -837,8 +840,6 @@ func TestImpl_SyncReplicaSegments(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
t.Log(resp.GetReason())
cs, ok := node.ShardClusterService.getShardCluster(defaultDMLChannel)
require.True(t, ok)
segment, ok := cs.getSegment(1)
require.True(t, ok)
assert.Equal(t, common.InvalidNodeID, segment.nodeID)
@ -921,6 +922,9 @@ func TestSyncDistribution(t *testing.T) {
assert.NoError(t, err)
node.ShardClusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
cs, ok := node.ShardClusterService.getShardCluster(defaultDMLChannel)
require.True(t, ok)
cs.SetupFirstVersion()
resp, err := node.SyncDistribution(ctx, &querypb.SyncDistributionRequest{
Base: &commonpb.MsgBase{TargetID: node.session.ServerID},
@ -939,8 +943,6 @@ func TestSyncDistribution(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
cs, ok := node.ShardClusterService.getShardCluster(defaultDMLChannel)
require.True(t, ok)
segment, ok := cs.getSegment(defaultSegmentID)
require.True(t, ok)
assert.Equal(t, common.InvalidNodeID, segment.nodeID)

View File

@ -33,7 +33,6 @@ import (
"math"
"os"
"path"
"path/filepath"
"runtime"
"strconv"
"sync"
@ -44,19 +43,14 @@ import (
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/golang/protobuf/proto"
"github.com/panjf2000/ants/v2"
"go.etcd.io/etcd/api/v3/mvccpb"
v3rpc "go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/zap"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/concurrency"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/initcore"
@ -308,9 +302,6 @@ func (node *QueryNode) Start() error {
// start task scheduler
go node.scheduler.Start()
// start services
go node.watchChangeInfo()
// create shardClusterService for shardLeader functions.
node.ShardClusterService = newShardClusterService(node.etcdCli, node.session, node)
// create shard-level query service
@ -366,104 +357,3 @@ func (node *QueryNode) UpdateStateCode(code commonpb.StateCode) {
func (node *QueryNode) SetEtcdClient(client *clientv3.Client) {
node.etcdCli = client
}
func (node *QueryNode) watchChangeInfo() {
log.Info("query node watchChangeInfo start")
watchChan := node.etcdKV.WatchWithPrefix(util.ChangeInfoMetaPrefix)
for {
select {
case <-node.queryNodeLoopCtx.Done():
log.Info("query node watchChangeInfo close")
return
case resp, ok := <-watchChan:
if !ok {
log.Warn("querynode failed to watch channel, return")
return
}
if err := resp.Err(); err != nil {
log.Warn("query watch channel canceled", zap.Error(resp.Err()))
// https://github.com/etcd-io/etcd/issues/8980
if resp.Err() == v3rpc.ErrCompacted {
go node.watchChangeInfo()
return
}
// if watch loop return due to event canceled, the datanode is not functional anymore
log.Panic("querynode is not functional for event canceled", zap.Error(err))
return
}
for _, event := range resp.Events {
switch event.Type {
case mvccpb.PUT:
infoID, err := strconv.ParseInt(filepath.Base(string(event.Kv.Key)), 10, 64)
if err != nil {
log.Warn("Parse SealedSegmentsChangeInfo id failed", zap.Any("error", err.Error()))
continue
}
log.Info("get SealedSegmentsChangeInfo from etcd",
zap.Any("infoID", infoID),
)
info := &querypb.SealedSegmentsChangeInfo{}
err = proto.Unmarshal(event.Kv.Value, info)
if err != nil {
log.Warn("Unmarshal SealedSegmentsChangeInfo failed", zap.Any("error", err.Error()))
continue
}
go node.handleSealedSegmentsChangeInfo(info)
default:
// do nothing
}
}
}
}
}
func (node *QueryNode) handleSealedSegmentsChangeInfo(info *querypb.SealedSegmentsChangeInfo) {
for _, line := range info.GetInfos() {
result := splitSegmentsChange(line)
for vchannel, changeInfo := range result {
err := node.ShardClusterService.HandoffVChannelSegments(vchannel, changeInfo)
if err != nil {
log.Warn("failed to handle vchannel segments", zap.String("vchannel", vchannel))
}
}
}
}
// splitSegmentsChange returns rearranged segmentChangeInfo in vchannel dimension
func splitSegmentsChange(changeInfo *querypb.SegmentChangeInfo) map[string]*querypb.SegmentChangeInfo {
result := make(map[string]*querypb.SegmentChangeInfo)
for _, segment := range changeInfo.GetOnlineSegments() {
dmlChannel := segment.GetDmChannel()
info, has := result[dmlChannel]
if !has {
info = &querypb.SegmentChangeInfo{
OnlineNodeID: changeInfo.OnlineNodeID,
OfflineNodeID: changeInfo.OfflineNodeID,
}
result[dmlChannel] = info
}
info.OnlineSegments = append(info.OnlineSegments, segment)
}
for _, segment := range changeInfo.GetOfflineSegments() {
dmlChannel := segment.GetDmChannel()
info, has := result[dmlChannel]
if !has {
info = &querypb.SegmentChangeInfo{
OnlineNodeID: changeInfo.OnlineNodeID,
OfflineNodeID: changeInfo.OfflineNodeID,
}
result[dmlChannel] = info
}
info.OfflineSegments = append(info.OfflineSegments, segment)
}
return result
}

View File

@ -26,16 +26,13 @@ import (
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/server/v3/embed"
"github.com/milvus-io/milvus/internal/util/concurrency"
"github.com/milvus-io/milvus/internal/util/dependency"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/etcd"
)
@ -253,259 +250,3 @@ func TestQueryNode_adjustByChangeInfo(t *testing.T) {
})
wg.Wait()
}
func TestQueryNode_watchChangeInfo(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var wg sync.WaitGroup
wg.Add(1)
t.Run("test watchChangeInfo", func(t *testing.T) {
defer wg.Done()
node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
assert.NoError(t, err)
go node.watchChangeInfo()
info := genSimpleSegmentInfo()
value, err := proto.Marshal(info)
assert.NoError(t, err)
err = saveChangeInfo("0", string(value))
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
})
wg.Add(1)
t.Run("test watchChangeInfo key error", func(t *testing.T) {
defer wg.Done()
node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
assert.NoError(t, err)
go node.watchChangeInfo()
err = saveChangeInfo("*$&#%^^", "%EUY%&#^$%&@")
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
})
wg.Add(1)
t.Run("test watchChangeInfo unmarshal error", func(t *testing.T) {
defer wg.Done()
node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
assert.NoError(t, err)
go node.watchChangeInfo()
err = saveChangeInfo("0", "$%^$*&%^#$&*")
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
})
wg.Add(1)
t.Run("test watchChangeInfo adjustByChangeInfo error", func(t *testing.T) {
defer wg.Done()
node, err := genSimpleQueryNodeToTestWatchChangeInfo(ctx)
assert.NoError(t, err)
node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed)
segmentChangeInfos := genSimpleChangeInfo()
segmentChangeInfos.Infos[0].OnlineSegments = nil
segmentChangeInfos.Infos[0].OfflineNodeID = Params.QueryNodeCfg.GetNodeID()
/*
qc, err := node.queryService.getQueryCollection(defaultCollectionID)
assert.NoError(t, err)
qc.globalSegmentManager.removeGlobalSealedSegmentInfo(defaultSegmentID)*/
go node.watchChangeInfo()
value, err := proto.Marshal(segmentChangeInfos)
assert.NoError(t, err)
err = saveChangeInfo("0", string(value))
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
})
wg.Wait()
}
func TestQueryNode_splitChangeChannel(t *testing.T) {
type testCase struct {
name string
info *querypb.SegmentChangeInfo
expectedResult map[string]*querypb.SegmentChangeInfo
}
cases := []testCase{
{
name: "empty info",
info: &querypb.SegmentChangeInfo{},
expectedResult: map[string]*querypb.SegmentChangeInfo{},
},
{
name: "normal segment change info",
info: &querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
OfflineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
},
expectedResult: map[string]*querypb.SegmentChangeInfo{
defaultDMLChannel: {
OnlineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
OfflineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
},
},
},
{
name: "empty offline change info",
info: &querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
},
expectedResult: map[string]*querypb.SegmentChangeInfo{
defaultDMLChannel: {
OnlineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
},
},
},
{
name: "empty online change info",
info: &querypb.SegmentChangeInfo{
OfflineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
},
expectedResult: map[string]*querypb.SegmentChangeInfo{
defaultDMLChannel: {
OfflineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
},
},
},
{
name: "different channel in online",
info: &querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
{DmChannel: "other_channel"},
},
},
expectedResult: map[string]*querypb.SegmentChangeInfo{
defaultDMLChannel: {
OnlineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
},
"other_channel": {
OnlineSegments: []*querypb.SegmentInfo{
{DmChannel: "other_channel"},
},
},
},
},
{
name: "different channel in offline",
info: &querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
OfflineSegments: []*querypb.SegmentInfo{
{DmChannel: "other_channel"},
},
},
expectedResult: map[string]*querypb.SegmentChangeInfo{
defaultDMLChannel: {
OnlineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
},
"other_channel": {
OfflineSegments: []*querypb.SegmentInfo{
{DmChannel: "other_channel"},
},
},
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := splitSegmentsChange(tc.info)
assert.Equal(t, len(tc.expectedResult), len(result))
for k, v := range tc.expectedResult {
r := assert.True(t, proto.Equal(v, result[k]))
if !r {
t.Log(v)
t.Log(result[k])
}
}
})
}
}
func TestQueryNode_handleSealedSegmentsChangeInfo(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
qn, err := genSimpleQueryNode(ctx)
require.NoError(t, err)
t.Run("empty info", func(t *testing.T) {
assert.NotPanics(t, func() {
qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{})
})
assert.NotPanics(t, func() {
qn.handleSealedSegmentsChangeInfo(nil)
})
})
t.Run("normal segment change info", func(t *testing.T) {
assert.NotPanics(t, func() {
qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{
Infos: []*querypb.SegmentChangeInfo{
{
OnlineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
OfflineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
},
},
})
})
})
t.Run("multple vchannel change info", func(t *testing.T) {
assert.NotPanics(t, func() {
qn.handleSealedSegmentsChangeInfo(&querypb.SealedSegmentsChangeInfo{
Infos: []*querypb.SegmentChangeInfo{
{
OnlineSegments: []*querypb.SegmentInfo{
{DmChannel: defaultDMLChannel},
},
OfflineSegments: []*querypb.SegmentInfo{
{DmChannel: "other_channel"},
},
},
},
})
})
})
}

View File

@ -32,7 +32,6 @@ import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/errorutil"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
@ -301,11 +300,29 @@ func (sc *ShardCluster) updateSegment(evt shardSegmentInfo) {
sc.transferSegment(old, evt)
}
// SetupFirstVersion initializes the first version for the shard cluster.
func (sc *ShardCluster) SetupFirstVersion() {
sc.mutVersion.Lock()
defer sc.mutVersion.Unlock()
version := NewShardClusterVersion(sc.nextVersionID.Inc(), make(SegmentsStatus), nil)
sc.versions.Store(version.versionID, version)
sc.currentVersion = version
}
// SyncSegments synchronizes segment distribution in batch
func (sc *ShardCluster) SyncSegments(distribution []*querypb.ReplicaSegmentsInfo, state segmentState) {
log := sc.getLogger()
log.Info("ShardCluster sync segments", zap.Any("replica segments", distribution), zap.Int32("state", int32(state)))
var currentVersion *ShardClusterVersion
sc.mutVersion.RLock()
currentVersion = sc.currentVersion
sc.mutVersion.RUnlock()
if currentVersion == nil {
log.Warn("received SyncSegments call before version setup")
return
}
sc.mut.Lock()
for _, line := range distribution {
for i, segmentID := range line.GetSegmentIds() {
@ -340,7 +357,7 @@ func (sc *ShardCluster) SyncSegments(distribution []*querypb.ReplicaSegmentsInfo
}
}
allocations := sc.segments.Clone(filterNothing)
// allocations := sc.segments.Clone(filterNothing)
sc.mut.Unlock()
// notify handoff wait online if any
@ -350,6 +367,15 @@ func (sc *ShardCluster) SyncSegments(distribution []*querypb.ReplicaSegmentsInfo
sc.mutVersion.Lock()
defer sc.mutVersion.Unlock()
// update shardleader allocation view
allocations := sc.currentVersion.segments.Clone(filterNothing)
for _, line := range distribution {
for _, segmentID := range line.GetSegmentIds() {
allocations[segmentID] = shardSegmentInfo{nodeID: line.GetNodeId(), segmentID: segmentID, partitionID: line.GetPartitionId(), state: state}
}
}
version := NewShardClusterVersion(sc.nextVersionID.Inc(), allocations, sc.currentVersion)
sc.versions.Store(version.versionID, version)
sc.currentVersion = version
@ -597,86 +623,6 @@ func (sc *ShardCluster) finishUsage(versionID int64) {
}
}
// HandoffSegments processes the handoff/load balance segments update procedure.
func (sc *ShardCluster) HandoffSegments(info *querypb.SegmentChangeInfo) error {
log := sc.getLogger()
// wait for all OnlineSegment is loaded
onlineSegmentIDs := make([]int64, 0, len(info.OnlineSegments))
onlineSegments := make([]shardSegmentInfo, 0, len(info.OnlineSegments))
for _, seg := range info.OnlineSegments {
// filter out segments not maintained in this cluster
if seg.GetCollectionID() != sc.collectionID || seg.GetDmChannel() != sc.vchannelName {
continue
}
nodeID, has := sc.selectNodeInReplica(seg.NodeIds)
if !has {
continue
}
onlineSegments = append(onlineSegments, shardSegmentInfo{
nodeID: nodeID,
segmentID: seg.GetSegmentID(),
})
onlineSegmentIDs = append(onlineSegmentIDs, seg.GetSegmentID())
}
sc.waitSegmentsOnline(onlineSegments)
// now online segment can provide service, generate a new version
// add segmentChangeInfo to pending list
versionID := sc.applySegmentChange(info, onlineSegmentIDs)
removes := make(map[int64][]int64) // nodeID => []segmentIDs
// remove offline segments record
for _, seg := range info.OfflineSegments {
// filter out segments not maintained in this cluster
if seg.GetCollectionID() != sc.collectionID || seg.GetDmChannel() != sc.vchannelName {
continue
}
nodeID, has := sc.selectNodeInReplica(seg.NodeIds)
if !has {
// remove segment placeholder
nodeID = common.InvalidNodeID
}
sc.removeSegment(shardSegmentInfo{segmentID: seg.GetSegmentID(), nodeID: nodeID})
// only add remove operations when node is valid
if nodeID != common.InvalidNodeID {
removes[nodeID] = append(removes[nodeID], seg.SegmentID)
}
}
var errs errorutil.ErrorList
// notify querynode(s) to release segments
for nodeID, segmentIDs := range removes {
node, ok := sc.getNode(nodeID)
if !ok {
log.Warn("node not in cluster", zap.Int64("nodeID", nodeID))
errs = append(errs, fmt.Errorf("node not in cluster nodeID %d", nodeID))
continue
}
state, err := node.client.ReleaseSegments(context.Background(), &querypb.ReleaseSegmentsRequest{
CollectionID: sc.collectionID,
SegmentIDs: segmentIDs,
Scope: querypb.DataScope_Historical,
})
if err != nil {
errs = append(errs, err)
continue
}
if state.GetErrorCode() != commonpb.ErrorCode_Success {
errs = append(errs, fmt.Errorf("Release segments failed with reason: %s", state.GetReason()))
}
}
sc.cleanupVersion(versionID)
// return err if release fail, however the whole segment change is completed
if len(errs) > 0 {
return errs
}
return nil
}
// LoadSegments loads segments with shardCluster.
// shard cluster shall try to loadSegments in the follower then update the allocation.
func (sc *ShardCluster) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error {
@ -837,80 +783,6 @@ func (sc *ShardCluster) ReleaseSegments(ctx context.Context, req *querypb.Releas
return err
}
// appendHandoff adds the change info into pending list and returns the token.
func (sc *ShardCluster) applySegmentChange(info *querypb.SegmentChangeInfo, onlineSegmentIDs []UniqueID) int64 {
// the suspects growing segment ids
// first all online segment shall be tried, for flush-handoff only puts segment in onlineSegments
// and we need to try all offlineSegments in case flush-compact-handoff case
possibleGrowingToRemove := make([]UniqueID, 0, len(info.OfflineSegments)+len(onlineSegmentIDs))
offlineNodes := make(map[int64]int64)
for _, offline := range info.OfflineSegments {
offlineNodes[offline.GetSegmentID()] = offline.GetNodeID()
possibleGrowingToRemove = append(possibleGrowingToRemove, offline.GetSegmentID())
}
// add online segment ids to suspect list
possibleGrowingToRemove = append(possibleGrowingToRemove, onlineSegmentIDs...)
// generate next version allocation
sc.mut.RLock()
allocations := sc.segments.Clone(func(segmentID int64, nodeID int64) bool {
offlineNodeID, ok := offlineNodes[segmentID]
return ok && offlineNodeID == nodeID
})
sc.mut.RUnlock()
sc.mutVersion.Lock()
defer sc.mutVersion.Unlock()
// generate a new version
versionID := sc.nextVersionID.Inc()
// remove offline segments in next version
// so incoming request will not have allocation of these segments
version := NewShardClusterVersion(versionID, allocations, sc.currentVersion)
sc.versions.Store(versionID, version)
var lastVersionID int64
/*
----------------------------------------------------------------------------
T0 |T1(S2 online)| T2(change version)|T3(remove G2)|
----------------------------------------------------------------------------
G2, G3 |G2, G3 | G2, G3 | G3
----------------------------------------------------------------------------
S1 |S1, S2 | S1, S2 | S1,S2
----------------------------------------------------------------------------
v0=[S1] |v0=[S1] | v1=[S1,S2] | v1=[S1,S2]
There is no method to ensure search after T2 does not search G2 so that it
could be removed safely
Currently, the only safe method is to block incoming allocation, so there is no
search will be dispatch to G2.
After shard cluster is able to maintain growing semgents, this version change could
reduce the lock range
*/
// currentVersion shall be not nil
if sc.currentVersion != nil {
// wait for last version search done
<-sc.currentVersion.Expire()
lastVersionID = sc.currentVersion.versionID
// remove growing segments if any
// handles the case for Growing to Sealed Handoff(which does not has offline segment info)
if sc.leader != nil {
// error ignored here
sc.leader.client.ReleaseSegments(context.Background(), &querypb.ReleaseSegmentsRequest{
CollectionID: sc.collectionID,
SegmentIDs: possibleGrowingToRemove,
Scope: querypb.DataScope_Streaming,
})
}
}
// set current version to new one
sc.currentVersion = version
return lastVersionID
}
// cleanupVersion clean up version from map
func (sc *ShardCluster) cleanupVersion(versionID int64) {
sc.mutVersion.RLock()

View File

@ -113,24 +113,6 @@ func (s *ShardClusterService) releaseCollection(collectionID int64) {
log.Info("successfully release collection", zap.Int64("collectionID", collectionID))
}
// HandoffSegments dispatch segmentChangeInfo to related shardClusters
func (s *ShardClusterService) HandoffSegments(collectionID int64, info *querypb.SegmentChangeInfo) {
var wg sync.WaitGroup
s.clusters.Range(func(k, v interface{}) bool {
cs := v.(*ShardCluster)
if cs.collectionID == collectionID {
wg.Add(1)
go func() {
defer wg.Done()
cs.HandoffSegments(info)
}()
}
return true
})
wg.Wait()
log.Info("successfully handoff segments", zap.Int64("collectionID", collectionID))
}
// SyncReplicaSegments dispatches nodeID segments distribution to ShardCluster.
func (s *ShardClusterService) SyncReplicaSegments(vchannelName string, distribution []*querypb.ReplicaSegmentsInfo) error {
sc, ok := s.getShardCluster(vchannelName)
@ -143,23 +125,6 @@ func (s *ShardClusterService) SyncReplicaSegments(vchannelName string, distribut
return nil
}
// HandoffVChannelSegments dispatches SegmentChangeInfo to related ShardCluster with VChannel
func (s *ShardClusterService) HandoffVChannelSegments(vchannel string, info *querypb.SegmentChangeInfo) error {
raw, ok := s.clusters.Load(vchannel)
if !ok {
// not leader for this channel, ignore without error
return nil
}
sc := raw.(*ShardCluster)
err := sc.HandoffSegments(info)
if err == nil {
log.Info("successfully handoff", zap.String("channel", vchannel), zap.Any("segment", info))
} else {
log.Warn("failed to handoff", zap.String("channel", vchannel), zap.Any("segment", info), zap.Error(err))
}
return err
}
func (s *ShardClusterService) GetShardClusters() []*ShardCluster {
ret := make([]*ShardCluster, 0)
s.clusters.Range(func(key, value any) bool {

View File

@ -2,7 +2,6 @@ package querynode
import (
"context"
"errors"
"testing"
"github.com/milvus-io/milvus/internal/common"
@ -37,23 +36,6 @@ func TestShardClusterService(t *testing.T) {
assert.Error(t, err)
}
func TestShardClusterService_HandoffSegments(t *testing.T) {
qn, err := genSimpleQueryNode(context.Background())
require.NoError(t, err)
client := v3client.New(embedetcdServer.Server)
defer client.Close()
session := sessionutil.NewSession(context.Background(), "/by-dev/sessions/unittest/querynode/", client)
clusterService := newShardClusterService(client, session, qn)
clusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
//TODO change shardCluster to interface to mock test behavior
assert.NotPanics(t, func() {
clusterService.HandoffSegments(defaultCollectionID, &querypb.SegmentChangeInfo{})
})
clusterService.releaseShardCluster(defaultDMLChannel)
}
func TestShardClusterService_SyncReplicaSegments(t *testing.T) {
qn, err := genSimpleQueryNode(context.Background())
require.NoError(t, err)
@ -68,9 +50,34 @@ func TestShardClusterService_SyncReplicaSegments(t *testing.T) {
assert.Error(t, err)
})
t.Run("sync initailizing shard cluster", func(t *testing.T) {
clusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
sc, ok := clusterService.getShardCluster(defaultDMLChannel)
require.True(t, ok)
assert.NotPanics(t, func() {
err := clusterService.SyncReplicaSegments(defaultDMLChannel, []*querypb.ReplicaSegmentsInfo{
{
NodeId: 1,
PartitionId: defaultPartitionID,
SegmentIds: []int64{1},
Versions: []int64{1},
},
})
assert.NoError(t, err)
assert.Nil(t, sc.currentVersion)
})
})
t.Run("sync shard cluster", func(t *testing.T) {
clusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
sc, ok := clusterService.getShardCluster(defaultDMLChannel)
require.True(t, ok)
sc.SetupFirstVersion()
err := clusterService.SyncReplicaSegments(defaultDMLChannel, []*querypb.ReplicaSegmentsInfo{
{
NodeId: 1,
@ -92,57 +99,3 @@ func TestShardClusterService_SyncReplicaSegments(t *testing.T) {
assert.Equal(t, segmentStateLoaded, segment.state)
})
}
func TestShardClusterService_HandoffVChannelSegments(t *testing.T) {
qn, err := genSimpleQueryNode(context.Background())
require.NoError(t, err)
client := v3client.New(embedetcdServer.Server)
defer client.Close()
session := sessionutil.NewSession(context.Background(), "/by-dev/sessions/unittest/querynode/", client)
clusterService := newShardClusterService(client, session, qn)
err = clusterService.HandoffVChannelSegments(defaultDMLChannel, &querypb.SegmentChangeInfo{})
assert.NoError(t, err)
clusterService.addShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel)
//TODO change shardCluster to interface to mock test behavior
t.Run("normal case", func(t *testing.T) {
assert.NotPanics(t, func() {
err = clusterService.HandoffVChannelSegments(defaultDMLChannel, &querypb.SegmentChangeInfo{})
assert.NoError(t, err)
})
})
t.Run("error case", func(t *testing.T) {
mqn := &mockShardQueryNode{}
nodeEvents := []nodeEvent{
{
nodeID: 3,
nodeAddr: "addr_3",
},
}
sc := NewShardCluster(defaultCollectionID, defaultReplicaID, defaultDMLChannel,
&mockNodeDetector{initNodes: nodeEvents}, &mockSegmentDetector{}, func(nodeID int64, addr string) shardQueryNode {
return mqn
})
defer sc.Close()
mqn.releaseSegmentsErr = errors.New("mocked error")
// set mocked shard cluster
clusterService.clusters.Store(defaultDMLChannel, sc)
assert.NotPanics(t, func() {
err = clusterService.HandoffVChannelSegments(defaultDMLChannel, &querypb.SegmentChangeInfo{
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 1, NodeID: 3, CollectionID: defaultCollectionID, DmChannel: defaultDMLChannel, NodeIds: []UniqueID{3}},
},
})
assert.Error(t, err)
})
})
}

View File

@ -891,6 +891,7 @@ func TestShardCluster_SyncSegments(t *testing.T) {
evtCh: evtCh,
}, buildMockQueryNode)
defer sc.Close()
sc.SetupFirstVersion()
sc.SyncSegments([]*querypb.ReplicaSegmentsInfo{
{
@ -965,6 +966,7 @@ func TestShardCluster_SyncSegments(t *testing.T) {
evtCh: evtCh,
}, buildMockQueryNode)
defer sc.Close()
sc.SetupFirstVersion()
sc.SyncSegments([]*querypb.ReplicaSegmentsInfo{
{
@ -1009,6 +1011,7 @@ func TestShardCluster_SyncSegments(t *testing.T) {
evtCh: evtCh,
}, buildMockQueryNode)
defer sc.Close()
sc.SetupFirstVersion()
sc.SyncSegments([]*querypb.ReplicaSegmentsInfo{
{
@ -1150,7 +1153,8 @@ func TestShardCluster_Search(t *testing.T) {
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
setupSegmentForShardCluster(sc, segmentEvents)
require.EqualValues(t, available, sc.state.Load())
@ -1200,7 +1204,8 @@ func TestShardCluster_Search(t *testing.T) {
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
setupSegmentForShardCluster(sc, segmentEvents)
require.EqualValues(t, available, sc.state.Load())
@ -1257,7 +1262,8 @@ func TestShardCluster_Search(t *testing.T) {
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
setupSegmentForShardCluster(sc, segmentEvents)
require.EqualValues(t, available, sc.state.Load())
@ -1297,19 +1303,19 @@ func TestShardCluster_Search(t *testing.T) {
}, &mockSegmentDetector{
initSegments: segmentEvents,
}, buildMockQueryNode)
// setup first version
sc.SetupFirstVersion()
setupSegmentForShardCluster(sc, segmentEvents)
//mock meta error
sc.mut.Lock()
sc.segments[3] = shardSegmentInfo{
sc.mutVersion.Lock()
sc.currentVersion.segments[3] = shardSegmentInfo{
segmentID: 3,
nodeID: 3, // node does not exist
state: segmentStateLoaded,
}
sc.mut.Unlock()
sc.mutVersion.Unlock()
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
require.EqualValues(t, available, sc.state.Load())
@ -1367,7 +1373,8 @@ func TestShardCluster_Query(t *testing.T) {
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
setupSegmentForShardCluster(sc, segmentEvents)
require.EqualValues(t, unavailable, sc.state.Load())
@ -1382,7 +1389,7 @@ func TestShardCluster_Query(t *testing.T) {
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
_, err := sc.Query(ctx, &querypb.QueryRequest{
DmlChannels: []string{vchannelName + "_suffix"},
@ -1428,7 +1435,8 @@ func TestShardCluster_Query(t *testing.T) {
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
setupSegmentForShardCluster(sc, segmentEvents)
require.EqualValues(t, available, sc.state.Load())
@ -1477,7 +1485,8 @@ func TestShardCluster_Query(t *testing.T) {
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
setupSegmentForShardCluster(sc, segmentEvents)
require.EqualValues(t, available, sc.state.Load())
@ -1534,7 +1543,8 @@ func TestShardCluster_Query(t *testing.T) {
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
setupSegmentForShardCluster(sc, segmentEvents)
require.EqualValues(t, available, sc.state.Load())
@ -1575,18 +1585,19 @@ func TestShardCluster_Query(t *testing.T) {
initSegments: segmentEvents,
}, buildMockQueryNode)
// setup first version
sc.SetupFirstVersion()
setupSegmentForShardCluster(sc, segmentEvents)
//mock meta error
sc.mut.Lock()
sc.segments[3] = shardSegmentInfo{
sc.mutVersion.Lock()
sc.currentVersion.segments[3] = shardSegmentInfo{
segmentID: 3,
nodeID: 3, // node does not exist
state: segmentStateLoaded,
}
sc.mut.Unlock()
sc.mutVersion.Unlock()
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
require.EqualValues(t, available, sc.state.Load())
@ -1703,7 +1714,8 @@ func TestShardCluster_GetStatistics(t *testing.T) {
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
setupSegmentForShardCluster(sc, segmentEvents)
require.EqualValues(t, available, sc.state.Load())
@ -1753,7 +1765,8 @@ func TestShardCluster_GetStatistics(t *testing.T) {
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
setupSegmentForShardCluster(sc, segmentEvents)
require.EqualValues(t, available, sc.state.Load())
@ -1811,7 +1824,8 @@ func TestShardCluster_GetStatistics(t *testing.T) {
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
setupSegmentForShardCluster(sc, segmentEvents)
require.EqualValues(t, available, sc.state.Load())
@ -1851,19 +1865,18 @@ func TestShardCluster_GetStatistics(t *testing.T) {
}, &mockSegmentDetector{
initSegments: segmentEvents,
}, buildMockQueryNode)
// setup first version
sc.SetupFirstVersion()
setupSegmentForShardCluster(sc, segmentEvents)
//mock meta error
sc.mut.Lock()
sc.segments[3] = shardSegmentInfo{
sc.mutVersion.Lock()
sc.currentVersion.segments[3] = shardSegmentInfo{
segmentID: 3,
nodeID: 3, // node does not exist
state: segmentStateLoaded,
}
sc.mut.Unlock()
sc.mutVersion.Unlock()
defer sc.Close()
// setup first version
sc.SyncSegments(nil, segmentStateLoaded)
require.EqualValues(t, available, sc.state.Load())
@ -1920,7 +1933,7 @@ func TestShardCluster_Version(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
_, version := sc.segmentAllocations(nil)
sc.mut.RLock()
@ -1969,7 +1982,7 @@ func TestShardCluster_Version(t *testing.T) {
}, buildMockQueryNode)
defer sc.Close()
sc.SyncSegments(nil, segmentStateLoaded)
sc.SetupFirstVersion()
assert.True(t, sc.segmentsOnline([]shardSegmentInfo{{nodeID: 1, segmentID: 1}, {nodeID: 2, segmentID: 2}}))
assert.False(t, sc.segmentsOnline([]shardSegmentInfo{{nodeID: 1, segmentID: 1}, {nodeID: 2, segmentID: 2}, {nodeID: 1, segmentID: 3}}))
@ -1992,417 +2005,17 @@ func TestShardCluster_Version(t *testing.T) {
})
}
func TestShardCluster_HandoffSegments(t *testing.T) {
collectionID := int64(1)
otherCollectionID := int64(2)
vchannelName := "dml_1_1_v0"
otherVchannelName := "dml_1_2_v0"
replicaID := int64(0)
t.Run("handoff without using segments", func(t *testing.T) {
nodeEvents := []nodeEvent{
func setupSegmentForShardCluster(sc *ShardCluster, segmentEvents []segmentEvent) {
for _, evt := range segmentEvents {
sc.SyncSegments([]*querypb.ReplicaSegmentsInfo{
{
nodeID: 1,
nodeAddr: "addr_1",
isLeader: true,
NodeId: evt.nodeIDs[0],
PartitionId: evt.partitionID,
SegmentIds: []int64{evt.segmentID},
Versions: []int64{0},
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeIDs: []int64{1},
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeIDs: []int64{2},
state: segmentStateLoaded,
},
}
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
}, buildMockQueryNode)
defer sc.Close()
sc.SyncSegments(segmentEventsToSyncInfo(nil), segmentStateLoaded)
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{2}},
},
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{1}},
},
})
if err != nil {
t.Log(err.Error())
}
assert.NoError(t, err)
sc.mut.RLock()
_, has := sc.segments[1]
sc.mut.RUnlock()
assert.False(t, has)
})
t.Run("handoff with growing segment(segment not recorded)", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
isLeader: true,
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeIDs: []int64{1},
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeIDs: []int64{2},
state: segmentStateLoaded,
},
}
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
}, buildMockQueryNode)
defer sc.Close()
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{2}},
{SegmentID: 4, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName, NodeIds: []UniqueID{2}},
},
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{1}},
{SegmentID: 5, NodeID: 2, CollectionID: otherCollectionID, DmChannel: otherVchannelName, NodeIds: []UniqueID{2}},
},
})
assert.NoError(t, err)
sc.mut.RLock()
_, has := sc.segments[3]
sc.mut.RUnlock()
assert.False(t, has)
})
t.Run("handoff wait online and usage", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
isLeader: true,
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeIDs: []int64{1},
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeIDs: []int64{2},
state: segmentStateLoaded,
},
}
evtCh := make(chan segmentEvent, 10)
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
evtCh: evtCh,
}, buildMockQueryNode)
defer sc.Close()
sc.SyncSegments(segmentEventsToSyncInfo(nil), segmentStateLoaded)
//add in-use count
_, versionID := sc.segmentAllocations(nil)
sig := make(chan struct{})
go func() {
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 3, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{1}},
},
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{1}},
},
})
assert.NoError(t, err)
close(sig)
}()
sc.finishUsage(versionID)
evtCh <- segmentEvent{
eventType: segmentAdd,
segmentID: 3,
nodeIDs: []int64{1},
state: segmentStateLoaded,
}
// wait for handoff appended into list
assert.Eventually(t, func() bool {
sc.mut.RLock()
defer sc.mut.RUnlock()
return sc.currentVersion.versionID != versionID
}, time.Second, time.Millisecond*10)
tmpAllocs, nVersionID := sc.segmentAllocations(nil)
found := false
for _, segments := range tmpAllocs {
if inList(segments, int64(1)) {
found = true
break
}
}
// segment 1 shall not be allocated again!
assert.False(t, found)
sc.finishUsage(nVersionID)
// rc shall be 0 now
sc.finishUsage(versionID)
// wait handoff finished
<-sig
sc.mut.RLock()
_, has := sc.segments[1]
sc.mut.RUnlock()
assert.False(t, has)
})
t.Run("load balance wait online and usage", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
isLeader: true,
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeIDs: []int64{1},
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeIDs: []int64{2},
state: segmentStateLoaded,
},
}
evtCh := make(chan segmentEvent, 10)
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
evtCh: evtCh,
}, buildMockQueryNode)
defer sc.Close()
// init first version
sc.SyncSegments(segmentEventsToSyncInfo(nil), segmentStateLoaded)
// add rc to current version
_, versionID := sc.segmentAllocations(nil)
sig := make(chan struct{})
go func() {
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 1, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{2}},
},
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{1}},
},
})
assert.NoError(t, err)
close(sig)
}()
evtCh <- segmentEvent{
eventType: segmentAdd,
segmentID: 1,
nodeIDs: []int64{2},
state: segmentStateLoaded,
}
sc.finishUsage(versionID)
// after handoff, the version id shall be changed
assert.Eventually(t, func() bool {
sc.mut.RLock()
defer sc.mut.RUnlock()
return sc.currentVersion.versionID != versionID
}, time.Second, time.Millisecond*10)
tmpAllocs, tmpVersionID := sc.segmentAllocations(nil)
for nodeID, segments := range tmpAllocs {
for _, segment := range segments {
if segment == int64(1) {
assert.Equal(t, int64(2), nodeID)
}
}
}
sc.finishUsage(tmpVersionID)
// wait handoff finished
<-sig
sc.mut.RLock()
info, has := sc.segments[1]
sc.mut.RUnlock()
assert.True(t, has)
assert.Equal(t, int64(2), info.nodeID)
})
t.Run("handoff from non-exist node", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
isLeader: true,
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeIDs: []int64{1},
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeIDs: []int64{2},
state: segmentStateLoaded,
},
}
evtCh := make(chan segmentEvent, 10)
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
evtCh: evtCh,
}, buildMockQueryNode)
defer sc.Close()
sc.SyncSegments(segmentEventsToSyncInfo(nil), segmentStateLoaded)
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
OnlineSegments: []*querypb.SegmentInfo{
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{2}},
},
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 1, NodeID: 3, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{3}},
},
})
assert.NoError(t, err)
})
t.Run("release failed", func(t *testing.T) {
nodeEvents := []nodeEvent{
{
nodeID: 1,
nodeAddr: "addr_1",
isLeader: true,
},
{
nodeID: 2,
nodeAddr: "addr_2",
},
}
segmentEvents := []segmentEvent{
{
segmentID: 1,
nodeIDs: []int64{1},
state: segmentStateLoaded,
},
{
segmentID: 2,
nodeIDs: []int64{2},
state: segmentStateLoaded,
},
}
evtCh := make(chan segmentEvent, 10)
mqn := &mockShardQueryNode{}
sc := NewShardCluster(collectionID, replicaID, vchannelName,
&mockNodeDetector{
initNodes: nodeEvents,
}, &mockSegmentDetector{
initSegments: segmentEvents,
evtCh: evtCh,
}, func(nodeID int64, addr string) shardQueryNode {
return mqn
})
defer sc.Close()
mqn.releaseSegmentsErr = errors.New("mocked error")
err := sc.HandoffSegments(&querypb.SegmentChangeInfo{
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 1, NodeID: 1, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{1}},
},
})
assert.Error(t, err)
mqn.releaseSegmentsErr = nil
mqn.releaseSegmentsResult = &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "mocked error",
}
err = sc.HandoffSegments(&querypb.SegmentChangeInfo{
OfflineSegments: []*querypb.SegmentInfo{
{SegmentID: 2, NodeID: 2, CollectionID: collectionID, DmChannel: vchannelName, NodeIds: []UniqueID{2}},
},
})
assert.Error(t, err)
})
}, evt.state)
}
}
type ShardClusterSuite struct {
@ -2457,6 +2070,17 @@ func (suite *ShardClusterSuite) SetupTest() {
}, &mockSegmentDetector{
initSegments: segmentEvents,
}, buildMockQueryNode)
suite.sc.SetupFirstVersion()
for _, evt := range segmentEvents {
suite.sc.SyncSegments([]*querypb.ReplicaSegmentsInfo{
{
NodeId: evt.nodeIDs[0],
PartitionId: evt.partitionID,
SegmentIds: []int64{evt.segmentID},
Versions: []int64{0},
},
}, segmentStateLoaded)
}
}
func (suite *ShardClusterSuite) TearDownTest() {
@ -2471,16 +2095,20 @@ func (suite *ShardClusterSuite) TestReleaseSegments() {
nodeID int64
scope querypb.DataScope
expectAlloc map[int64][]int64
expectError bool
force bool
}
cases := []TestCase{
{
tag: "normal release",
segmentIDs: []int64{2},
nodeID: 2,
scope: querypb.DataScope_All,
tag: "normal release",
segmentIDs: []int64{2},
nodeID: 2,
scope: querypb.DataScope_All,
expectAlloc: map[int64][]int64{
1: {1},
},
expectError: false,
force: false,
},
@ -2501,6 +2129,9 @@ func (suite *ShardClusterSuite) TestReleaseSegments() {
suite.Error(err)
} else {
suite.NoError(err)
alloc, vid := suite.sc.segmentAllocations(nil)
suite.sc.finishUsage(vid)
suite.Equal(test.expectAlloc, alloc)
}
})
}