diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 02b5279d02..94e4965919 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "math/rand" "strconv" "sync" "time" @@ -120,12 +121,22 @@ type shardLeadersReader struct { // Shuffle returns the shuffled shard leader list. func (it shardLeadersReader) Shuffle() map[string][]nodeInfo { result := make(map[string][]nodeInfo) + rand.Seed(time.Now().UnixNano()) for channel, leaders := range it.leaders.shardLeaders { l := len(leaders) - shuffled := make([]nodeInfo, 0, len(leaders)) - for i := 0; i < l; i++ { - shuffled = append(shuffled, leaders[(i+int(it.idx))%l]) + // shuffle all replica at random order + shuffled := make([]nodeInfo, l) + for i, randIndex := range rand.Perm(l) { + shuffled[i] = leaders[randIndex] } + + // make each copy has same probability to be first replica + for index, leader := range shuffled { + if leader == leaders[int(it.idx)%l] { + shuffled[0], shuffled[index] = shuffled[index], shuffled[0] + } + } + result[channel] = shuffled } return result diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index 8b92f98d2e..bbe5ab4ba7 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -26,6 +26,7 @@ import ( "time" "github.com/milvus-io/milvus/internal/util/funcutil" + uatomic "go.uber.org/atomic" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -842,3 +843,42 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) { return len(nodeInfos["channel-1"]) == 3 && len(nodeInfos["channel-2"]) == 3 }, 3*time.Second, 1*time.Second) } + +func TestGlobalMetaCache_ShuffleShardLeaders(t *testing.T) { + shards := map[string][]nodeInfo{ + "channel-1": { + { + nodeID: 1, + address: "localhost:9000", + }, + { + nodeID: 2, + address: "localhost:9000", + }, + { + nodeID: 3, + address: "localhost:9000", + }, + }, + } + sl := &shardLeaders{ + deprecated: uatomic.NewBool(false), + idx: uatomic.NewInt64(5), + shardLeaders: shards, + } + + reader := sl.GetReader() + result := reader.Shuffle() + assert.Len(t, result["channel-1"], 3) + assert.Equal(t, int64(1), result["channel-1"][0].nodeID) + + reader = sl.GetReader() + result = reader.Shuffle() + assert.Len(t, result["channel-1"], 3) + assert.Equal(t, int64(2), result["channel-1"][0].nodeID) + + reader = sl.GetReader() + result = reader.Shuffle() + assert.Len(t, result["channel-1"], 3) + assert.Equal(t, int64(3), result["channel-1"][0].nodeID) +}